bulk rewrite using this script

set -euo pipefail

    unset CDPATH

    cd "$(dirname "$0")"

    for f in $(git grep '\brequire := require\.New(' | cut -d':' -f1 | sort -u); do
        echo "=== require: $f ==="
        sed -i '/require := require.New(t)/d' $f
        # require.XXX(blah) but not require.XXX(tblah) or require.XXX(rblah)
        sed -i 's/\brequire\.\([a-zA-Z0-9_]*\)(\([^tr]\)/require.\1(t,\2/g' $f
        # require.XXX(tblah) but not require.XXX(t, blah)
        sed -i 's/\brequire\.\([a-zA-Z0-9_]*\)(\(t[^,]\)/require.\1(t,\2/g' $f
        # require.XXX(rblah) but not require.XXX(r, blah)
        sed -i 's/\brequire\.\([a-zA-Z0-9_]*\)(\(r[^,]\)/require.\1(t,\2/g' $f
        gofmt -s -w $f
    done

    for f in $(git grep '\bassert := assert\.New(' | cut -d':' -f1 | sort -u); do
        echo "=== assert: $f ==="
        sed -i '/assert := assert.New(t)/d' $f
        # assert.XXX(blah) but not assert.XXX(tblah) or assert.XXX(rblah)
        sed -i 's/\bassert\.\([a-zA-Z0-9_]*\)(\([^tr]\)/assert.\1(t,\2/g' $f
        # assert.XXX(tblah) but not assert.XXX(t, blah)
        sed -i 's/\bassert\.\([a-zA-Z0-9_]*\)(\(t[^,]\)/assert.\1(t,\2/g' $f
        # assert.XXX(rblah) but not assert.XXX(r, blah)
        sed -i 's/\bassert\.\([a-zA-Z0-9_]*\)(\(r[^,]\)/assert.\1(t,\2/g' $f
        gofmt -s -w $f
    done
This commit is contained in:
R.B. Boyer 2022-01-20 10:46:23 -06:00
parent c12b0ee3d2
commit 05c7373a28
97 changed files with 2523 additions and 3030 deletions

View File

@ -1516,30 +1516,28 @@ func TestMergePolicies(t *testing.T) {
}, },
} }
require := require.New(t)
for _, tcase := range tests { for _, tcase := range tests {
t.Run(tcase.name, func(t *testing.T) { t.Run(tcase.name, func(t *testing.T) {
act := MergePolicies(tcase.input) act := MergePolicies(tcase.input)
exp := tcase.expected exp := tcase.expected
require.Equal(exp.ACL, act.ACL) require.Equal(t, exp.ACL, act.ACL)
require.Equal(exp.Keyring, act.Keyring) require.Equal(t, exp.Keyring, act.Keyring)
require.Equal(exp.Operator, act.Operator) require.Equal(t, exp.Operator, act.Operator)
require.Equal(exp.Mesh, act.Mesh) require.Equal(t, exp.Mesh, act.Mesh)
require.ElementsMatch(exp.Agents, act.Agents) require.ElementsMatch(t, exp.Agents, act.Agents)
require.ElementsMatch(exp.AgentPrefixes, act.AgentPrefixes) require.ElementsMatch(t, exp.AgentPrefixes, act.AgentPrefixes)
require.ElementsMatch(exp.Events, act.Events) require.ElementsMatch(t, exp.Events, act.Events)
require.ElementsMatch(exp.EventPrefixes, act.EventPrefixes) require.ElementsMatch(t, exp.EventPrefixes, act.EventPrefixes)
require.ElementsMatch(exp.Keys, act.Keys) require.ElementsMatch(t, exp.Keys, act.Keys)
require.ElementsMatch(exp.KeyPrefixes, act.KeyPrefixes) require.ElementsMatch(t, exp.KeyPrefixes, act.KeyPrefixes)
require.ElementsMatch(exp.Nodes, act.Nodes) require.ElementsMatch(t, exp.Nodes, act.Nodes)
require.ElementsMatch(exp.NodePrefixes, act.NodePrefixes) require.ElementsMatch(t, exp.NodePrefixes, act.NodePrefixes)
require.ElementsMatch(exp.PreparedQueries, act.PreparedQueries) require.ElementsMatch(t, exp.PreparedQueries, act.PreparedQueries)
require.ElementsMatch(exp.PreparedQueryPrefixes, act.PreparedQueryPrefixes) require.ElementsMatch(t, exp.PreparedQueryPrefixes, act.PreparedQueryPrefixes)
require.ElementsMatch(exp.Services, act.Services) require.ElementsMatch(t, exp.Services, act.Services)
require.ElementsMatch(exp.ServicePrefixes, act.ServicePrefixes) require.ElementsMatch(t, exp.ServicePrefixes, act.ServicePrefixes)
require.ElementsMatch(exp.Sessions, act.Sessions) require.ElementsMatch(t, exp.Sessions, act.Sessions)
require.ElementsMatch(exp.SessionPrefixes, act.SessionPrefixes) require.ElementsMatch(t, exp.SessionPrefixes, act.SessionPrefixes)
}) })
} }

File diff suppressed because it is too large Load Diff

View File

@ -1855,7 +1855,6 @@ func TestAgent_AddCheck_Alias(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
@ -1869,19 +1868,19 @@ func TestAgent_AddCheck_Alias(t *testing.T) {
AliasService: "foo", AliasService: "foo",
} }
err := a.AddCheck(health, chk, false, "", ConfigSourceLocal) err := a.AddCheck(health, chk, false, "", ConfigSourceLocal)
require.NoError(err) require.NoError(t, err)
// Ensure we have a check mapping // Ensure we have a check mapping
sChk := requireCheckExists(t, a, "aliashealth") sChk := requireCheckExists(t, a, "aliashealth")
require.Equal(api.HealthCritical, sChk.Status) require.Equal(t, api.HealthCritical, sChk.Status)
chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)] chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)]
require.True(ok, "missing aliashealth check") require.True(t, ok, "missing aliashealth check")
require.Equal("", chkImpl.RPCReq.Token) require.Equal(t, "", chkImpl.RPCReq.Token)
cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil)) cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil))
require.NotNil(cs) require.NotNil(t, cs)
require.Equal("", cs.Token) require.Equal(t, "", cs.Token)
} }
func TestAgent_AddCheck_Alias_setToken(t *testing.T) { func TestAgent_AddCheck_Alias_setToken(t *testing.T) {
@ -1891,7 +1890,6 @@ func TestAgent_AddCheck_Alias_setToken(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
@ -1905,15 +1903,15 @@ func TestAgent_AddCheck_Alias_setToken(t *testing.T) {
AliasService: "foo", AliasService: "foo",
} }
err := a.AddCheck(health, chk, false, "foo", ConfigSourceLocal) err := a.AddCheck(health, chk, false, "foo", ConfigSourceLocal)
require.NoError(err) require.NoError(t, err)
cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil)) cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil))
require.NotNil(cs) require.NotNil(t, cs)
require.Equal("foo", cs.Token) require.Equal(t, "foo", cs.Token)
chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)] chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)]
require.True(ok, "missing aliashealth check") require.True(t, ok, "missing aliashealth check")
require.Equal("foo", chkImpl.RPCReq.Token) require.Equal(t, "foo", chkImpl.RPCReq.Token)
} }
func TestAgent_AddCheck_Alias_userToken(t *testing.T) { func TestAgent_AddCheck_Alias_userToken(t *testing.T) {
@ -1923,7 +1921,6 @@ func TestAgent_AddCheck_Alias_userToken(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := NewTestAgent(t, ` a := NewTestAgent(t, `
acl_token = "hello" acl_token = "hello"
`) `)
@ -1939,15 +1936,15 @@ acl_token = "hello"
AliasService: "foo", AliasService: "foo",
} }
err := a.AddCheck(health, chk, false, "", ConfigSourceLocal) err := a.AddCheck(health, chk, false, "", ConfigSourceLocal)
require.NoError(err) require.NoError(t, err)
cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil)) cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil))
require.NotNil(cs) require.NotNil(t, cs)
require.Equal("", cs.Token) // State token should still be empty require.Equal(t, "", cs.Token) // State token should still be empty
chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)] chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)]
require.True(ok, "missing aliashealth check") require.True(t, ok, "missing aliashealth check")
require.Equal("hello", chkImpl.RPCReq.Token) // Check should use the token require.Equal(t, "hello", chkImpl.RPCReq.Token) // Check should use the token
} }
func TestAgent_AddCheck_Alias_userAndSetToken(t *testing.T) { func TestAgent_AddCheck_Alias_userAndSetToken(t *testing.T) {
@ -1957,7 +1954,6 @@ func TestAgent_AddCheck_Alias_userAndSetToken(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := NewTestAgent(t, ` a := NewTestAgent(t, `
acl_token = "hello" acl_token = "hello"
`) `)
@ -1973,15 +1969,15 @@ acl_token = "hello"
AliasService: "foo", AliasService: "foo",
} }
err := a.AddCheck(health, chk, false, "goodbye", ConfigSourceLocal) err := a.AddCheck(health, chk, false, "goodbye", ConfigSourceLocal)
require.NoError(err) require.NoError(t, err)
cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil)) cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil))
require.NotNil(cs) require.NotNil(t, cs)
require.Equal("goodbye", cs.Token) require.Equal(t, "goodbye", cs.Token)
chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)] chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)]
require.True(ok, "missing aliashealth check") require.True(t, ok, "missing aliashealth check")
require.Equal("goodbye", chkImpl.RPCReq.Token) require.Equal(t, "goodbye", chkImpl.RPCReq.Token)
} }
func TestAgent_RemoveCheck(t *testing.T) { func TestAgent_RemoveCheck(t *testing.T) {

View File

@ -11,7 +11,6 @@ import (
) )
func TestCatalogServices(t *testing.T) { func TestCatalogServices(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &CatalogServices{RPC: rpc} typ := &CatalogServices{RPC: rpc}
@ -22,10 +21,10 @@ func TestCatalogServices(t *testing.T) {
rpc.On("RPC", "Catalog.ServiceNodes", mock.Anything, mock.Anything).Return(nil). rpc.On("RPC", "Catalog.ServiceNodes", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) { Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.ServiceSpecificRequest) req := args.Get(1).(*structs.ServiceSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex) require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime) require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal("web", req.ServiceName) require.Equal(t, "web", req.ServiceName)
require.True(req.AllowStale) require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.IndexedServiceNodes) reply := args.Get(2).(*structs.IndexedServiceNodes)
reply.ServiceNodes = []*structs.ServiceNode{ reply.ServiceNodes = []*structs.ServiceNode{
@ -44,15 +43,14 @@ func TestCatalogServices(t *testing.T) {
ServiceName: "web", ServiceName: "web",
ServiceTags: []string{"tag1", "tag2"}, ServiceTags: []string{"tag1", "tag2"},
}) })
require.NoError(err) require.NoError(t, err)
require.Equal(cache.FetchResult{ require.Equal(t, cache.FetchResult{
Value: resp, Value: resp,
Index: 48, Index: 48,
}, resultA) }, resultA)
} }
func TestCatalogServices_badReqType(t *testing.T) { func TestCatalogServices_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &CatalogServices{RPC: rpc} typ := &CatalogServices{RPC: rpc}
@ -60,7 +58,7 @@ func TestCatalogServices_badReqType(t *testing.T) {
// Fetch // Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest( _, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64})) t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "wrong type") require.Contains(t, err.Error(), "wrong type")
} }

View File

@ -123,23 +123,22 @@ func TestCalculateSoftExpire(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
now, err := time.Parse("2006-01-02 15:04:05", tc.now) now, err := time.Parse("2006-01-02 15:04:05", tc.now)
require.NoError(err) require.NoError(t, err)
issued, err := time.Parse("2006-01-02 15:04:05", tc.issued) issued, err := time.Parse("2006-01-02 15:04:05", tc.issued)
require.NoError(err) require.NoError(t, err)
wantMin, err := time.Parse("2006-01-02 15:04:05", tc.wantMin) wantMin, err := time.Parse("2006-01-02 15:04:05", tc.wantMin)
require.NoError(err) require.NoError(t, err)
wantMax, err := time.Parse("2006-01-02 15:04:05", tc.wantMax) wantMax, err := time.Parse("2006-01-02 15:04:05", tc.wantMax)
require.NoError(err) require.NoError(t, err)
min, max := calculateSoftExpiry(now, &structs.IssuedCert{ min, max := calculateSoftExpiry(now, &structs.IssuedCert{
ValidAfter: issued, ValidAfter: issued,
ValidBefore: issued.Add(tc.lifetime), ValidBefore: issued.Add(tc.lifetime),
}) })
require.Equal(wantMin, min) require.Equal(t, wantMin, min)
require.Equal(wantMax, max) require.Equal(t, wantMax, max)
}) })
} }
} }
@ -156,7 +155,6 @@ func TestConnectCALeaf_changingRoots(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
@ -211,8 +209,8 @@ func TestConnectCALeaf_changingRoots(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch") t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -244,9 +242,9 @@ func TestConnectCALeaf_changingRoots(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch") t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
// 3 since the second CA "update" used up 2 // 3 since the second CA "update" used up 2
require.Equal(uint64(3), v.Index) require.Equal(t, uint64(3), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
opts.MinIndex = 3 opts.MinIndex = 3
@ -267,7 +265,6 @@ func TestConnectCALeaf_changingRoots(t *testing.T) {
func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) { func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
@ -323,8 +320,8 @@ func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch") t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -378,24 +375,24 @@ func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
if v.Index > uint64(1) { if v.Index > uint64(1) {
// Got a new cert // Got a new cert
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
require.Equal(uint64(3), v.Index) require.Equal(t, uint64(3), v.Index)
// Should not have been delivered before the delay // Should not have been delivered before the delay
require.True(time.Since(earliestRootDelivery) > typ.TestOverrideCAChangeInitialDelay) require.True(t, time.Since(earliestRootDelivery) > typ.TestOverrideCAChangeInitialDelay)
// All good. We are done! // All good. We are done!
rootsDelivered = true rootsDelivered = true
} else { } else {
// Should be the cached cert // Should be the cached cert
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Sanity check we blocked for the whole timeout // Sanity check we blocked for the whole timeout
require.Truef(timeTaken > opts.Timeout, require.Truef(t, timeTaken > opts.Timeout,
"should block for at least %s, returned after %s", "should block for at least %s, returned after %s",
opts.Timeout, timeTaken) opts.Timeout, timeTaken)
// Sanity check that the forceExpireAfter state was set correctly // Sanity check that the forceExpireAfter state was set correctly
shouldExpireAfter = v.State.(*fetchState).forceExpireAfter shouldExpireAfter = v.State.(*fetchState).forceExpireAfter
require.True(shouldExpireAfter.After(time.Now())) require.True(t, shouldExpireAfter.After(time.Now()))
require.True(shouldExpireAfter.Before(time.Now().Add(typ.TestOverrideCAChangeInitialDelay))) require.True(t, shouldExpireAfter.Before(time.Now().Add(typ.TestOverrideCAChangeInitialDelay)))
} }
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
@ -406,7 +403,7 @@ func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
// Sanity check that we've not gone way beyond the deadline without a // Sanity check that we've not gone way beyond the deadline without a
// new cert. We give some leeway to make it less brittle. // new cert. We give some leeway to make it less brittle.
require.Falsef(time.Now().After(shouldExpireAfter.Add(100*time.Millisecond)), require.Falsef(t, time.Now().After(shouldExpireAfter.Add(100*time.Millisecond)),
"waited extra 100ms and delayed CA rotate renew didn't happen") "waited extra 100ms and delayed CA rotate renew didn't happen")
} }
} }
@ -415,7 +412,6 @@ func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) { func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
@ -460,8 +456,8 @@ func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch") t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -474,11 +470,11 @@ func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) {
t.Fatal("shouldn't block for too long waiting for fetch") t.Fatal("shouldn't block for too long waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
// Still the initial cached result // Still the initial cached result
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Sanity check that it waited // Sanity check that it waited
require.True(time.Since(start) > opts.Timeout) require.True(t, time.Since(start) > opts.Timeout)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -506,11 +502,11 @@ func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) {
t.Fatal("shouldn't block too long waiting for fetch") t.Fatal("shouldn't block too long waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
// Index should be 3 since root change consumed 2 // Index should be 3 since root change consumed 2
require.Equal(uint64(3), v.Index) require.Equal(t, uint64(3), v.Index)
// Sanity check that we didn't wait too long // Sanity check that we didn't wait too long
require.True(time.Since(earliestRootDelivery) < opts.Timeout) require.True(t, time.Since(earliestRootDelivery) < opts.Timeout)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -524,7 +520,6 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
@ -593,8 +588,8 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
case result := <-fetchCh: case result := <-fetchCh:
switch v := result.(type) { switch v := result.(type) {
case error: case error:
require.Error(v) require.Error(t, v)
require.Equal(consul.ErrRateLimited.Error(), v.Error()) require.Equal(t, consul.ErrRateLimited.Error(), v.Error())
case cache.FetchResult: case cache.FetchResult:
t.Fatalf("Expected error") t.Fatalf("Expected error")
} }
@ -607,8 +602,8 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch") t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
// Set MinIndex // Set MinIndex
@ -632,7 +627,7 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
earliestRootDelivery := time.Now() earliestRootDelivery := time.Now()
// Sanity check state // Sanity check state
require.Equal(uint64(1), atomic.LoadUint64(&rateLimitedRPCs)) require.Equal(t, uint64(1), atomic.LoadUint64(&rateLimitedRPCs))
// After root rotation jitter has been waited out, a new CSR will // After root rotation jitter has been waited out, a new CSR will
// be attempted but will fail and return the previous cached result with no // be attempted but will fail and return the previous cached result with no
@ -645,14 +640,14 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
// We should block for _at least_ one jitter period since we set that to // We should block for _at least_ one jitter period since we set that to
// 100ms and in test override mode we always pick the max jitter not a // 100ms and in test override mode we always pick the max jitter not a
// random amount. // random amount.
require.True(time.Since(earliestRootDelivery) > 100*time.Millisecond) require.True(t, time.Since(earliestRootDelivery) > 100*time.Millisecond)
require.Equal(uint64(2), atomic.LoadUint64(&rateLimitedRPCs)) require.Equal(t, uint64(2), atomic.LoadUint64(&rateLimitedRPCs))
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
// 1 since this should still be the original cached result as we failed to // 1 since this should still be the original cached result as we failed to
// get a new cert. // get a new cert.
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -666,14 +661,14 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
t.Fatal("shouldn't block too long waiting for fetch") t.Fatal("shouldn't block too long waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
// We should block for _at least_ two jitter periods now. // We should block for _at least_ two jitter periods now.
require.True(time.Since(earliestRootDelivery) > 200*time.Millisecond) require.True(t, time.Since(earliestRootDelivery) > 200*time.Millisecond)
require.Equal(uint64(3), atomic.LoadUint64(&rateLimitedRPCs)) require.Equal(t, uint64(3), atomic.LoadUint64(&rateLimitedRPCs))
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
// 1 since this should still be the original cached result as we failed to // 1 since this should still be the original cached result as we failed to
// get a new cert. // get a new cert.
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -688,13 +683,13 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
t.Fatal("shouldn't block too long waiting for fetch") t.Fatal("shouldn't block too long waiting for fetch")
case result := <-fetchCh: case result := <-fetchCh:
// We should block for _at least_ three jitter periods now. // We should block for _at least_ three jitter periods now.
require.True(time.Since(earliestRootDelivery) > 300*time.Millisecond) require.True(t, time.Since(earliestRootDelivery) > 300*time.Millisecond)
require.Equal(uint64(3), atomic.LoadUint64(&rateLimitedRPCs)) require.Equal(t, uint64(3), atomic.LoadUint64(&rateLimitedRPCs))
v := mustFetchResult(t, result) v := mustFetchResult(t, result)
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
// 3 since the rootCA change used 2 // 3 since the rootCA change used 2
require.Equal(uint64(3), v.Index) require.Equal(t, uint64(3), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -908,7 +903,6 @@ func TestConnectCALeaf_expiringLeaf(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
@ -962,10 +956,10 @@ func TestConnectCALeaf_expiringLeaf(t *testing.T) {
case result := <-fetchCh: case result := <-fetchCh:
switch v := result.(type) { switch v := result.(type) {
case error: case error:
require.NoError(v) require.NoError(t, v)
case cache.FetchResult: case cache.FetchResult:
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
require.Equal(uint64(1), v.Index) require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -980,10 +974,10 @@ func TestConnectCALeaf_expiringLeaf(t *testing.T) {
case result := <-fetchCh: case result := <-fetchCh:
switch v := result.(type) { switch v := result.(type) {
case error: case error:
require.NoError(v) require.NoError(t, v)
case cache.FetchResult: case cache.FetchResult:
require.Equal(resp, v.Value) require.Equal(t, resp, v.Value)
require.Equal(uint64(2), v.Index) require.Equal(t, uint64(2), v.Index)
// Set the LastResult for subsequent fetches // Set the LastResult for subsequent fetches
opts.LastResult = &v opts.LastResult = &v
} }
@ -1003,7 +997,6 @@ func TestConnectCALeaf_expiringLeaf(t *testing.T) {
func TestConnectCALeaf_DNSSANForService(t *testing.T) { func TestConnectCALeaf_DNSSANForService(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
@ -1039,12 +1032,12 @@ func TestConnectCALeaf_DNSSANForService(t *testing.T) {
DNSSAN: []string{"test.example.com"}, DNSSAN: []string{"test.example.com"},
} }
_, err := typ.Fetch(opts, req) _, err := typ.Fetch(opts, req)
require.NoError(err) require.NoError(t, err)
pemBlock, _ := pem.Decode([]byte(caReq.CSR)) pemBlock, _ := pem.Decode([]byte(caReq.CSR))
csr, err := x509.ParseCertificateRequest(pemBlock.Bytes) csr, err := x509.ParseCertificateRequest(pemBlock.Bytes)
require.NoError(err) require.NoError(t, err)
require.Equal(csr.DNSNames, []string{"test.example.com"}) require.Equal(t, csr.DNSNames, []string{"test.example.com"})
} }
// testConnectCaRoot wraps ConnectCARoot to disable refresh so that the gated // testConnectCaRoot wraps ConnectCARoot to disable refresh so that the gated

View File

@ -11,7 +11,6 @@ import (
) )
func TestConnectCARoot(t *testing.T) { func TestConnectCARoot(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &ConnectCARoot{RPC: rpc} typ := &ConnectCARoot{RPC: rpc}
@ -22,8 +21,8 @@ func TestConnectCARoot(t *testing.T) {
rpc.On("RPC", "ConnectCA.Roots", mock.Anything, mock.Anything).Return(nil). rpc.On("RPC", "ConnectCA.Roots", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) { Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.DCSpecificRequest) req := args.Get(1).(*structs.DCSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex) require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime) require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
reply := args.Get(2).(*structs.IndexedCARoots) reply := args.Get(2).(*structs.IndexedCARoots)
reply.QueryMeta.Index = 48 reply.QueryMeta.Index = 48
@ -35,15 +34,14 @@ func TestConnectCARoot(t *testing.T) {
MinIndex: 24, MinIndex: 24,
Timeout: 1 * time.Second, Timeout: 1 * time.Second,
}, &structs.DCSpecificRequest{Datacenter: "dc1"}) }, &structs.DCSpecificRequest{Datacenter: "dc1"})
require.Nil(err) require.Nil(t, err)
require.Equal(cache.FetchResult{ require.Equal(t, cache.FetchResult{
Value: resp, Value: resp,
Index: 48, Index: 48,
}, result) }, result)
} }
func TestConnectCARoot_badReqType(t *testing.T) { func TestConnectCARoot_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &ConnectCARoot{RPC: rpc} typ := &ConnectCARoot{RPC: rpc}
@ -51,7 +49,7 @@ func TestConnectCARoot_badReqType(t *testing.T) {
// Fetch // Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest( _, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64})) t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.NotNil(err) require.NotNil(t, err)
require.Contains(err.Error(), "wrong type") require.Contains(t, err.Error(), "wrong type")
} }

View File

@ -11,7 +11,6 @@ import (
) )
func TestHealthServices(t *testing.T) { func TestHealthServices(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &HealthServices{RPC: rpc} typ := &HealthServices{RPC: rpc}
@ -22,10 +21,10 @@ func TestHealthServices(t *testing.T) {
rpc.On("RPC", "Health.ServiceNodes", mock.Anything, mock.Anything).Return(nil). rpc.On("RPC", "Health.ServiceNodes", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) { Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.ServiceSpecificRequest) req := args.Get(1).(*structs.ServiceSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex) require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime) require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal("web", req.ServiceName) require.Equal(t, "web", req.ServiceName)
require.True(req.AllowStale) require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.IndexedCheckServiceNodes) reply := args.Get(2).(*structs.IndexedCheckServiceNodes)
reply.Nodes = []structs.CheckServiceNode{ reply.Nodes = []structs.CheckServiceNode{
@ -44,15 +43,14 @@ func TestHealthServices(t *testing.T) {
ServiceName: "web", ServiceName: "web",
ServiceTags: []string{"tag1", "tag2"}, ServiceTags: []string{"tag1", "tag2"},
}) })
require.NoError(err) require.NoError(t, err)
require.Equal(cache.FetchResult{ require.Equal(t, cache.FetchResult{
Value: resp, Value: resp,
Index: 48, Index: 48,
}, resultA) }, resultA)
} }
func TestHealthServices_badReqType(t *testing.T) { func TestHealthServices_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &HealthServices{RPC: rpc} typ := &HealthServices{RPC: rpc}
@ -60,7 +58,7 @@ func TestHealthServices_badReqType(t *testing.T) {
// Fetch // Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest( _, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64})) t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "wrong type") require.Contains(t, err.Error(), "wrong type")
} }

View File

@ -11,7 +11,6 @@ import (
) )
func TestIntentionMatch(t *testing.T) { func TestIntentionMatch(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &IntentionMatch{RPC: rpc} typ := &IntentionMatch{RPC: rpc}
@ -22,8 +21,8 @@ func TestIntentionMatch(t *testing.T) {
rpc.On("RPC", "Intention.Match", mock.Anything, mock.Anything).Return(nil). rpc.On("RPC", "Intention.Match", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) { Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.IntentionQueryRequest) req := args.Get(1).(*structs.IntentionQueryRequest)
require.Equal(uint64(24), req.MinQueryIndex) require.Equal(t, uint64(24), req.MinQueryIndex)
require.Equal(1*time.Second, req.MaxQueryTime) require.Equal(t, 1*time.Second, req.MaxQueryTime)
reply := args.Get(2).(*structs.IndexedIntentionMatches) reply := args.Get(2).(*structs.IndexedIntentionMatches)
reply.Index = 48 reply.Index = 48
@ -35,15 +34,14 @@ func TestIntentionMatch(t *testing.T) {
MinIndex: 24, MinIndex: 24,
Timeout: 1 * time.Second, Timeout: 1 * time.Second,
}, &structs.IntentionQueryRequest{Datacenter: "dc1"}) }, &structs.IntentionQueryRequest{Datacenter: "dc1"})
require.NoError(err) require.NoError(t, err)
require.Equal(cache.FetchResult{ require.Equal(t, cache.FetchResult{
Value: resp, Value: resp,
Index: 48, Index: 48,
}, result) }, result)
} }
func TestIntentionMatch_badReqType(t *testing.T) { func TestIntentionMatch_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &IntentionMatch{RPC: rpc} typ := &IntentionMatch{RPC: rpc}
@ -51,7 +49,7 @@ func TestIntentionMatch_badReqType(t *testing.T) {
// Fetch // Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest( _, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64})) t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "wrong type") require.Contains(t, err.Error(), "wrong type")
} }

View File

@ -11,7 +11,6 @@ import (
) )
func TestNodeServices(t *testing.T) { func TestNodeServices(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &NodeServices{RPC: rpc} typ := &NodeServices{RPC: rpc}
@ -22,10 +21,10 @@ func TestNodeServices(t *testing.T) {
rpc.On("RPC", "Catalog.NodeServices", mock.Anything, mock.Anything).Return(nil). rpc.On("RPC", "Catalog.NodeServices", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) { Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.NodeSpecificRequest) req := args.Get(1).(*structs.NodeSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex) require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime) require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal("node-01", req.Node) require.Equal(t, "node-01", req.Node)
require.True(req.AllowStale) require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.IndexedNodeServices) reply := args.Get(2).(*structs.IndexedNodeServices)
reply.NodeServices = &structs.NodeServices{ reply.NodeServices = &structs.NodeServices{
@ -49,15 +48,14 @@ func TestNodeServices(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
Node: "node-01", Node: "node-01",
}) })
require.NoError(err) require.NoError(t, err)
require.Equal(cache.FetchResult{ require.Equal(t, cache.FetchResult{
Value: resp, Value: resp,
Index: 48, Index: 48,
}, resultA) }, resultA)
} }
func TestNodeServices_badReqType(t *testing.T) { func TestNodeServices_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &NodeServices{RPC: rpc} typ := &NodeServices{RPC: rpc}
@ -65,7 +63,7 @@ func TestNodeServices_badReqType(t *testing.T) {
// Fetch // Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest( _, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64})) t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "wrong type") require.Contains(t, err.Error(), "wrong type")
} }

View File

@ -10,7 +10,6 @@ import (
) )
func TestPreparedQuery(t *testing.T) { func TestPreparedQuery(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &PreparedQuery{RPC: rpc} typ := &PreparedQuery{RPC: rpc}
@ -21,9 +20,9 @@ func TestPreparedQuery(t *testing.T) {
rpc.On("RPC", "PreparedQuery.Execute", mock.Anything, mock.Anything).Return(nil). rpc.On("RPC", "PreparedQuery.Execute", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) { Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.PreparedQueryExecuteRequest) req := args.Get(1).(*structs.PreparedQueryExecuteRequest)
require.Equal("geo-db", req.QueryIDOrName) require.Equal(t, "geo-db", req.QueryIDOrName)
require.Equal(10, req.Limit) require.Equal(t, 10, req.Limit)
require.True(req.AllowStale) require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.PreparedQueryExecuteResponse) reply := args.Get(2).(*structs.PreparedQueryExecuteResponse)
reply.QueryMeta.Index = 48 reply.QueryMeta.Index = 48
@ -36,15 +35,14 @@ func TestPreparedQuery(t *testing.T) {
QueryIDOrName: "geo-db", QueryIDOrName: "geo-db",
Limit: 10, Limit: 10,
}) })
require.NoError(err) require.NoError(t, err)
require.Equal(cache.FetchResult{ require.Equal(t, cache.FetchResult{
Value: resp, Value: resp,
Index: 48, Index: 48,
}, result) }, result)
} }
func TestPreparedQuery_badReqType(t *testing.T) { func TestPreparedQuery_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &PreparedQuery{RPC: rpc} typ := &PreparedQuery{RPC: rpc}
@ -52,6 +50,6 @@ func TestPreparedQuery_badReqType(t *testing.T) {
// Fetch // Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest( _, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64})) t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "wrong type") require.Contains(t, err.Error(), "wrong type")
} }

View File

@ -11,7 +11,6 @@ import (
) )
func TestResolvedServiceConfig(t *testing.T) { func TestResolvedServiceConfig(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &ResolvedServiceConfig{RPC: rpc} typ := &ResolvedServiceConfig{RPC: rpc}
@ -22,10 +21,10 @@ func TestResolvedServiceConfig(t *testing.T) {
rpc.On("RPC", "ConfigEntry.ResolveServiceConfig", mock.Anything, mock.Anything).Return(nil). rpc.On("RPC", "ConfigEntry.ResolveServiceConfig", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) { Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.ServiceConfigRequest) req := args.Get(1).(*structs.ServiceConfigRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex) require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime) require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal("foo", req.Name) require.Equal(t, "foo", req.Name)
require.True(req.AllowStale) require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.ServiceConfigResponse) reply := args.Get(2).(*structs.ServiceConfigResponse)
reply.ProxyConfig = map[string]interface{}{ reply.ProxyConfig = map[string]interface{}{
@ -49,15 +48,14 @@ func TestResolvedServiceConfig(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
Name: "foo", Name: "foo",
}) })
require.NoError(err) require.NoError(t, err)
require.Equal(cache.FetchResult{ require.Equal(t, cache.FetchResult{
Value: resp, Value: resp,
Index: 48, Index: 48,
}, resultA) }, resultA)
} }
func TestResolvedServiceConfig_badReqType(t *testing.T) { func TestResolvedServiceConfig_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t) rpc := TestRPC(t)
defer rpc.AssertExpectations(t) defer rpc.AssertExpectations(t)
typ := &ResolvedServiceConfig{RPC: rpc} typ := &ResolvedServiceConfig{RPC: rpc}
@ -65,7 +63,7 @@ func TestResolvedServiceConfig_badReqType(t *testing.T) {
// Fetch // Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest( _, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64})) t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "wrong type") require.Contains(t, err.Error(), "wrong type")
} }

View File

@ -24,8 +24,6 @@ import (
func TestCacheGet_noIndex(t *testing.T) { func TestCacheGet_noIndex(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := TestType(t) typ := TestType(t)
defer typ.AssertExpectations(t) defer typ.AssertExpectations(t)
c := New(Options{}) c := New(Options{})
@ -37,15 +35,15 @@ func TestCacheGet_noIndex(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get, should not fetch since we already have a satisfying value // Get, should not fetch since we already have a satisfying value
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.True(meta.Hit) require.True(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call // then verify that we still only got the one call
@ -57,8 +55,6 @@ func TestCacheGet_noIndex(t *testing.T) {
func TestCacheGet_initError(t *testing.T) { func TestCacheGet_initError(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := TestType(t) typ := TestType(t)
defer typ.AssertExpectations(t) defer typ.AssertExpectations(t)
c := New(Options{}) c := New(Options{})
@ -71,15 +67,15 @@ func TestCacheGet_initError(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.Error(err) require.Error(t, err)
require.Nil(result) require.Nil(t, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get, should fetch again since our last fetch was an error // Get, should fetch again since our last fetch was an error
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.Error(err) require.Error(t, err)
require.Nil(result) require.Nil(t, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call // then verify that we still only got the one call
@ -96,8 +92,6 @@ func TestCacheGet_cachedErrorsDontStick(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := TestType(t) typ := TestType(t)
defer typ.AssertExpectations(t) defer typ.AssertExpectations(t)
c := New(Options{}) c := New(Options{})
@ -115,15 +109,15 @@ func TestCacheGet_cachedErrorsDontStick(t *testing.T) {
// Get, should fetch and get error // Get, should fetch and get error
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.Error(err) require.Error(t, err)
require.Nil(result) require.Nil(t, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get, should fetch again since our last fetch was an error, but get success // Get, should fetch again since our last fetch was an error, but get success
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Now get should block until timeout and then get the same response NOT the // Now get should block until timeout and then get the same response NOT the
// cached error. // cached error.
@ -157,8 +151,6 @@ func TestCacheGet_cachedErrorsDontStick(t *testing.T) {
func TestCacheGet_blankCacheKey(t *testing.T) { func TestCacheGet_blankCacheKey(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := TestType(t) typ := TestType(t)
defer typ.AssertExpectations(t) defer typ.AssertExpectations(t)
c := New(Options{}) c := New(Options{})
@ -170,15 +162,15 @@ func TestCacheGet_blankCacheKey(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: ""}) req := TestRequest(t, RequestInfo{Key: ""})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get, should not fetch since we already have a satisfying value // Get, should not fetch since we already have a satisfying value
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call // then verify that we still only got the one call
@ -225,8 +217,6 @@ func TestCacheGet_blockingInitSameKey(t *testing.T) {
func TestCacheGet_blockingInitDiffKeys(t *testing.T) { func TestCacheGet_blockingInitDiffKeys(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := TestType(t) typ := TestType(t)
defer typ.AssertExpectations(t) defer typ.AssertExpectations(t)
c := New(Options{}) c := New(Options{})
@ -269,7 +259,7 @@ func TestCacheGet_blockingInitDiffKeys(t *testing.T) {
// Verify proper keys // Verify proper keys
sort.Strings(keys) sort.Strings(keys)
require.Equal([]string{"goodbye", "hello"}, keys) require.Equal(t, []string{"goodbye", "hello"}, keys)
} }
// Test a get with an index set will wait until an index that is higher // Test a get with an index set will wait until an index that is higher
@ -414,8 +404,6 @@ func TestCacheGet_emptyFetchResult(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := TestType(t) typ := TestType(t)
defer typ.AssertExpectations(t) defer typ.AssertExpectations(t)
c := New(Options{}) c := New(Options{})
@ -429,29 +417,29 @@ func TestCacheGet_emptyFetchResult(t *testing.T) {
typ.Static(FetchResult{Value: nil, State: 32}, nil).Run(func(args mock.Arguments) { typ.Static(FetchResult{Value: nil, State: 32}, nil).Run(func(args mock.Arguments) {
// We should get back the original state // We should get back the original state
opts := args.Get(0).(FetchOptions) opts := args.Get(0).(FetchOptions)
require.NotNil(opts.LastResult) require.NotNil(t, opts.LastResult)
stateCh <- opts.LastResult.State.(int) stateCh <- opts.LastResult.State.(int)
}) })
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get, should not fetch since we already have a satisfying value // Get, should not fetch since we already have a satisfying value
req = TestRequest(t, RequestInfo{ req = TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond}) Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// State delivered to second call should be the result from first call. // State delivered to second call should be the result from first call.
select { select {
case state := <-stateCh: case state := <-stateCh:
require.Equal(31, state) require.Equal(t, 31, state)
case <-time.After(20 * time.Millisecond): case <-time.After(20 * time.Millisecond):
t.Fatal("timed out") t.Fatal("timed out")
} }
@ -461,12 +449,12 @@ func TestCacheGet_emptyFetchResult(t *testing.T) {
req = TestRequest(t, RequestInfo{ req = TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond}) Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
select { select {
case state := <-stateCh: case state := <-stateCh:
require.Equal(32, state) require.Equal(t, 32, state)
case <-time.After(20 * time.Millisecond): case <-time.After(20 * time.Millisecond):
t.Fatal("timed out") t.Fatal("timed out")
} }
@ -737,8 +725,6 @@ func TestCacheGet_noIndexSetsOne(t *testing.T) {
func TestCacheGet_fetchTimeout(t *testing.T) { func TestCacheGet_fetchTimeout(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := &MockType{} typ := &MockType{}
timeout := 10 * time.Minute timeout := 10 * time.Minute
typ.On("RegisterOptions").Return(RegisterOptions{ typ.On("RegisterOptions").Return(RegisterOptions{
@ -761,12 +747,12 @@ func TestCacheGet_fetchTimeout(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Test the timeout // Test the timeout
require.Equal(timeout, actual) require.Equal(t, timeout, actual)
} }
// Test that entries expire // Test that entries expire
@ -777,8 +763,6 @@ func TestCacheGet_expire(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := &MockType{} typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{ typ.On("RegisterOptions").Return(RegisterOptions{
LastGetTTL: 400 * time.Millisecond, LastGetTTL: 400 * time.Millisecond,
@ -795,9 +779,9 @@ func TestCacheGet_expire(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Wait for a non-trivial amount of time to sanity check the age increases at // Wait for a non-trivial amount of time to sanity check the age increases at
// least this amount. Note that this is not a fudge for some timing-dependent // least this amount. Note that this is not a fudge for some timing-dependent
@ -808,10 +792,10 @@ func TestCacheGet_expire(t *testing.T) {
// Get, should not fetch, verified via the mock assertions above // Get, should not fetch, verified via the mock assertions above
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.True(meta.Hit) require.True(t, meta.Hit)
require.True(meta.Age > 5*time.Millisecond) require.True(t, meta.Age > 5*time.Millisecond)
// Sleep for the expiry // Sleep for the expiry
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
@ -819,9 +803,9 @@ func TestCacheGet_expire(t *testing.T) {
// Get, should fetch // Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen then verify // Sleep a tiny bit just to let maybe some background calls happen then verify
// that we still only got the one call // that we still only got the one call
@ -837,8 +821,6 @@ func TestCacheGet_expire(t *testing.T) {
func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) { func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := &MockType{} typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{ typ.On("RegisterOptions").Return(RegisterOptions{
LastGetTTL: 400 * time.Millisecond, LastGetTTL: 400 * time.Millisecond,
@ -879,18 +861,18 @@ func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.Equal(uint64(4), meta.Index) require.Equal(t, uint64(4), meta.Index)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get, should not fetch, verified via the mock assertions above // Get, should not fetch, verified via the mock assertions above
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.Equal(uint64(4), meta.Index) require.Equal(t, uint64(4), meta.Index)
require.True(meta.Hit) require.True(t, meta.Hit)
// Sleep for the expiry // Sleep for the expiry
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
@ -898,10 +880,10 @@ func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) {
// Get, should fetch // Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.Equal(uint64(4), meta.Index) require.Equal(t, uint64(4), meta.Index)
require.False(meta.Hit, "the fetch should not have re-populated the cache "+ require.False(t, meta.Hit, "the fetch should not have re-populated the cache "+
"entry after it expired so this get should be a miss") "entry after it expired so this get should be a miss")
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
@ -915,8 +897,6 @@ func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) {
func TestCacheGet_expireBackgroudRefresh(t *testing.T) { func TestCacheGet_expireBackgroudRefresh(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := &MockType{} typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{ typ.On("RegisterOptions").Return(RegisterOptions{
LastGetTTL: 400 * time.Millisecond, LastGetTTL: 400 * time.Millisecond,
@ -948,18 +928,18 @@ func TestCacheGet_expireBackgroudRefresh(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.Equal(uint64(4), meta.Index) require.Equal(t, uint64(4), meta.Index)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get, should not fetch, verified via the mock assertions above // Get, should not fetch, verified via the mock assertions above
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.Equal(uint64(4), meta.Index) require.Equal(t, uint64(4), meta.Index)
require.True(meta.Hit) require.True(t, meta.Hit)
// Sleep for the expiry // Sleep for the expiry
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
@ -971,10 +951,10 @@ func TestCacheGet_expireBackgroudRefresh(t *testing.T) {
// re-insert the value back into the cache and make it live forever). // re-insert the value back into the cache and make it live forever).
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.Equal(uint64(4), meta.Index) require.Equal(t, uint64(4), meta.Index)
require.False(meta.Hit, "the fetch should not have re-populated the cache "+ require.False(t, meta.Hit, "the fetch should not have re-populated the cache "+
"entry after it expired so this get should be a miss") "entry after it expired so this get should be a miss")
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
@ -991,8 +971,6 @@ func TestCacheGet_expireResetGet(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := &MockType{} typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{ typ.On("RegisterOptions").Return(RegisterOptions{
LastGetTTL: 150 * time.Millisecond, LastGetTTL: 150 * time.Millisecond,
@ -1009,9 +987,9 @@ func TestCacheGet_expireResetGet(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Fetch multiple times, where the total time is well beyond // Fetch multiple times, where the total time is well beyond
// the TTL. We should not trigger any fetches during this time. // the TTL. We should not trigger any fetches during this time.
@ -1022,9 +1000,9 @@ func TestCacheGet_expireResetGet(t *testing.T) {
// Get, should not fetch // Get, should not fetch
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.True(meta.Hit) require.True(t, meta.Hit)
} }
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
@ -1032,9 +1010,9 @@ func TestCacheGet_expireResetGet(t *testing.T) {
// Get, should fetch // Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call // then verify that we still only got the one call
@ -1046,8 +1024,6 @@ func TestCacheGet_expireResetGet(t *testing.T) {
func TestCacheGet_expireResetGetNoChange(t *testing.T) { func TestCacheGet_expireResetGetNoChange(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
// Create a closer so we can tell if the entry gets evicted. // Create a closer so we can tell if the entry gets evicted.
closer := &testCloser{} closer := &testCloser{}
@ -1080,19 +1056,19 @@ func TestCacheGet_expireResetGetNoChange(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.Equal(uint64(10), meta.Index) require.Equal(t, uint64(10), meta.Index)
require.False(meta.Hit) require.False(t, meta.Hit)
// Do a blocking watch of the value that won't time out until after the TTL. // Do a blocking watch of the value that won't time out until after the TTL.
start := time.Now() start := time.Now()
req = TestRequest(t, RequestInfo{Key: "hello", MinIndex: 10, Timeout: 300 * time.Millisecond}) req = TestRequest(t, RequestInfo{Key: "hello", MinIndex: 10, Timeout: 300 * time.Millisecond})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.Equal(uint64(10), meta.Index) require.Equal(t, uint64(10), meta.Index)
require.GreaterOrEqual(time.Since(start).Milliseconds(), int64(300)) require.GreaterOrEqual(t, time.Since(start).Milliseconds(), int64(300))
// This is the point of this test! Even though we waited for a change for // This is the point of this test! Even though we waited for a change for
// longer than the TTL, we should have been updating the TTL so that the cache // longer than the TTL, we should have been updating the TTL so that the cache
@ -1100,7 +1076,7 @@ func TestCacheGet_expireResetGetNoChange(t *testing.T) {
// since that is not set for blocking Get calls but we can assert that the // since that is not set for blocking Get calls but we can assert that the
// entry was never closed (which assuming the test for eviction closing is // entry was never closed (which assuming the test for eviction closing is
// also passing is a reliable signal). // also passing is a reliable signal).
require.False(closer.isClosed(), "cache entry should not have been evicted") require.False(t, closer.isClosed(), "cache entry should not have been evicted")
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call // then verify that we still only got the one call
@ -1116,8 +1092,6 @@ func TestCacheGet_expireClose(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := &MockType{} typ := &MockType{}
defer typ.AssertExpectations(t) defer typ.AssertExpectations(t)
c := New(Options{}) c := New(Options{})
@ -1137,16 +1111,16 @@ func TestCacheGet_expireClose(t *testing.T) {
ctx := context.Background() ctx := context.Background()
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(ctx, "t", req) result, meta, err := c.Get(ctx, "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
require.False(state.isClosed()) require.False(t, state.isClosed())
// Sleep for the expiry // Sleep for the expiry
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
// state.Close() should have been called // state.Close() should have been called
require.True(state.isClosed()) require.True(t, state.isClosed())
} }
type testCloser struct { type testCloser struct {
@ -1171,8 +1145,6 @@ func (t *testCloser) isClosed() bool {
func TestCacheGet_duplicateKeyDifferentType(t *testing.T) { func TestCacheGet_duplicateKeyDifferentType(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := TestType(t) typ := TestType(t)
defer typ.AssertExpectations(t) defer typ.AssertExpectations(t)
typ2 := TestType(t) typ2 := TestType(t)
@ -1189,23 +1161,23 @@ func TestCacheGet_duplicateKeyDifferentType(t *testing.T) {
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "foo"}) req := TestRequest(t, RequestInfo{Key: "foo"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(100, result) require.Equal(t, 100, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get from t2 with same key, should fetch // Get from t2 with same key, should fetch
req = TestRequest(t, RequestInfo{Key: "foo"}) req = TestRequest(t, RequestInfo{Key: "foo"})
result, meta, err = c.Get(context.Background(), "t2", req) result, meta, err = c.Get(context.Background(), "t2", req)
require.NoError(err) require.NoError(t, err)
require.Equal(200, result) require.Equal(t, 200, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get from t again with same key, should cache // Get from t again with same key, should cache
req = TestRequest(t, RequestInfo{Key: "foo"}) req = TestRequest(t, RequestInfo{Key: "foo"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(100, result) require.Equal(t, 100, result)
require.True(meta.Hit) require.True(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call // then verify that we still only got the one call
@ -1283,8 +1255,6 @@ func TestCacheGet_refreshAge(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
typ := &MockType{} typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{ typ.On("RegisterOptions").Return(RegisterOptions{
Refresh: true, Refresh: true,
@ -1330,11 +1300,11 @@ func TestCacheGet_refreshAge(t *testing.T) {
// Fetch again, non-blocking // Fetch again, non-blocking
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"})) result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.True(meta.Hit) require.True(t, meta.Hit)
// Age should be zero since background refresh was "active" // Age should be zero since background refresh was "active"
require.Equal(time.Duration(0), meta.Age) require.Equal(t, time.Duration(0), meta.Age)
} }
// Now fail the next background sync // Now fail the next background sync
@ -1350,21 +1320,21 @@ func TestCacheGet_refreshAge(t *testing.T) {
var lastAge time.Duration var lastAge time.Duration
{ {
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"})) result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.True(meta.Hit) require.True(t, meta.Hit)
// Age should be non-zero since background refresh was "active" // Age should be non-zero since background refresh was "active"
require.True(meta.Age > 0) require.True(t, meta.Age > 0)
lastAge = meta.Age lastAge = meta.Age
} }
// Wait a bit longer - age should increase by at least this much // Wait a bit longer - age should increase by at least this much
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
{ {
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"})) result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.True(meta.Hit) require.True(t, meta.Hit)
require.True(meta.Age > (lastAge + (1 * time.Millisecond))) require.True(t, meta.Age > (lastAge+(1*time.Millisecond)))
} }
// Now unfail the background refresh // Now unfail the background refresh
@ -1384,18 +1354,18 @@ func TestCacheGet_refreshAge(t *testing.T) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"})) result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
// Should never error even if background is failing as we have cached value // Should never error even if background is failing as we have cached value
require.NoError(err) require.NoError(t, err)
require.True(meta.Hit) require.True(t, meta.Hit)
// Got the new value! // Got the new value!
if result == 10 { if result == 10 {
// Age should be zero since background refresh is "active" again // Age should be zero since background refresh is "active" again
t.Logf("Succeeded after %d attempts", attempts) t.Logf("Succeeded after %d attempts", attempts)
require.Equal(time.Duration(0), meta.Age) require.Equal(t, time.Duration(0), meta.Age)
timeout = false timeout = false
break break
} }
} }
require.False(timeout, "failed to observe update after %s", time.Since(t0)) require.False(t, timeout, "failed to observe update after %s", time.Since(t0))
} }
func TestCacheGet_nonRefreshAge(t *testing.T) { func TestCacheGet_nonRefreshAge(t *testing.T) {
@ -1405,8 +1375,6 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
typ := &MockType{} typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{ typ.On("RegisterOptions").Return(RegisterOptions{
Refresh: false, Refresh: false,
@ -1440,10 +1408,10 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
// Fetch again, non-blocking // Fetch again, non-blocking
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"})) result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.True(meta.Hit) require.True(t, meta.Hit)
require.True(meta.Age > (5 * time.Millisecond)) require.True(t, meta.Age > (5*time.Millisecond))
lastAge = meta.Age lastAge = meta.Age
} }
@ -1452,11 +1420,11 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
{ {
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"})) result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Age should smaller again // Age should smaller again
require.True(meta.Age < lastAge) require.True(t, meta.Age < lastAge)
} }
{ {
@ -1468,10 +1436,10 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
// Fetch again, non-blocking // Fetch again, non-blocking
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"})) result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.True(meta.Hit) require.True(t, meta.Hit)
require.True(meta.Age > (5 * time.Millisecond)) require.True(t, meta.Age > (5*time.Millisecond))
lastAge = meta.Age lastAge = meta.Age
} }
@ -1481,11 +1449,11 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
Key: "hello", Key: "hello",
MaxAge: 1 * time.Millisecond, MaxAge: 1 * time.Millisecond,
})) }))
require.NoError(err) require.NoError(t, err)
require.Equal(8, result) require.Equal(t, 8, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Age should smaller again // Age should smaller again
require.True(meta.Age < lastAge) require.True(t, meta.Age < lastAge)
} }
} }
@ -1505,21 +1473,19 @@ func TestCacheGet_nonBlockingType(t *testing.T) {
require.Equal(t, uint64(0), opts.MinIndex) require.Equal(t, uint64(0), opts.MinIndex)
}) })
require := require.New(t)
// Get, should fetch // Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"}) req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req) result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get, should not fetch since we have a cached value // Get, should not fetch since we have a cached value
req = TestRequest(t, RequestInfo{Key: "hello"}) req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.True(meta.Hit) require.True(t, meta.Hit)
// Get, should not attempt to fetch with blocking even if requested. The // Get, should not attempt to fetch with blocking even if requested. The
// assertions below about the value being the same combined with the fact the // assertions below about the value being the same combined with the fact the
@ -1531,25 +1497,25 @@ func TestCacheGet_nonBlockingType(t *testing.T) {
Timeout: 10 * time.Minute, Timeout: 10 * time.Minute,
}) })
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(42, result) require.Equal(t, 42, result)
require.True(meta.Hit) require.True(t, meta.Hit)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
// Get with a max age should fetch again // Get with a max age should fetch again
req = TestRequest(t, RequestInfo{Key: "hello", MaxAge: 5 * time.Millisecond}) req = TestRequest(t, RequestInfo{Key: "hello", MaxAge: 5 * time.Millisecond})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(43, result) require.Equal(t, 43, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Get with a must revalidate should fetch again even without a delay. // Get with a must revalidate should fetch again even without a delay.
req = TestRequest(t, RequestInfo{Key: "hello", MustRevalidate: true}) req = TestRequest(t, RequestInfo{Key: "hello", MustRevalidate: true})
result, meta, err = c.Get(context.Background(), "t", req) result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err) require.NoError(t, err)
require.Equal(43, result) require.Equal(t, 43, result)
require.False(meta.Hit) require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen // Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call // then verify that we still only got the one call

View File

@ -51,15 +51,13 @@ func TestCacheNotify(t *testing.T) {
// after cancellation as if it had timed out. // after cancellation as if it had timed out.
typ.Static(FetchResult{Value: 42, Index: 8}, nil).WaitUntil(trigger[4]) typ.Static(FetchResult{Value: 42, Index: 8}, nil).WaitUntil(trigger[4])
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
ch := make(chan UpdateEvent) ch := make(chan UpdateEvent)
err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test", ch) err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test", ch)
require.NoError(err) require.NoError(t, err)
// Should receive the error with index == 0 first. // Should receive the error with index == 0 first.
TestCacheNotifyChResult(t, ch, UpdateEvent{ TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -70,7 +68,7 @@ func TestCacheNotify(t *testing.T) {
}) })
// There should be no more updates delivered yet // There should be no more updates delivered yet
require.Len(ch, 0) require.Len(t, ch, 0)
// Trigger blocking query to return a "change" // Trigger blocking query to return a "change"
close(trigger[0]) close(trigger[0])
@ -102,7 +100,7 @@ func TestCacheNotify(t *testing.T) {
// requests to the "backend" // requests to the "backend"
// - that multiple watchers can distinguish their results using correlationID // - that multiple watchers can distinguish their results using correlationID
err = c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test2", ch) err = c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test2", ch)
require.NoError(err) require.NoError(t, err)
// Should get test2 notify immediately, and it should be a cache hit // Should get test2 notify immediately, and it should be a cache hit
TestCacheNotifyChResult(t, ch, UpdateEvent{ TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -121,7 +119,7 @@ func TestCacheNotify(t *testing.T) {
// it's only a sanity check, if we somehow _do_ get the change delivered later // it's only a sanity check, if we somehow _do_ get the change delivered later
// than 10ms the next value assertion will fail anyway. // than 10ms the next value assertion will fail anyway.
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
require.Len(ch, 0) require.Len(t, ch, 0)
// Trigger final update // Trigger final update
close(trigger[3]) close(trigger[3])
@ -183,15 +181,13 @@ func TestCacheNotifyPolling(t *testing.T) {
typ.Static(FetchResult{Value: 12, Index: 1}, nil).Once() typ.Static(FetchResult{Value: 12, Index: 1}, nil).Once()
typ.Static(FetchResult{Value: 42, Index: 1}, nil).Once() typ.Static(FetchResult{Value: 42, Index: 1}, nil).Once()
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
ch := make(chan UpdateEvent) ch := make(chan UpdateEvent)
err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test", ch) err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test", ch)
require.NoError(err) require.NoError(t, err)
// Should receive the first result pretty soon // Should receive the first result pretty soon
TestCacheNotifyChResult(t, ch, UpdateEvent{ TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -202,32 +198,32 @@ func TestCacheNotifyPolling(t *testing.T) {
}) })
// There should be no more updates delivered yet // There should be no more updates delivered yet
require.Len(ch, 0) require.Len(t, ch, 0)
// make sure the updates do not come too quickly // make sure the updates do not come too quickly
select { select {
case <-time.After(50 * time.Millisecond): case <-time.After(50 * time.Millisecond):
case <-ch: case <-ch:
require.Fail("Received update too early") require.Fail(t, "Received update too early")
} }
// make sure we get the update not too far out. // make sure we get the update not too far out.
select { select {
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
require.Fail("Didn't receive the notification") require.Fail(t, "Didn't receive the notification")
case result := <-ch: case result := <-ch:
require.Equal(result.Result, 12) require.Equal(t, result.Result, 12)
require.Equal(result.CorrelationID, "test") require.Equal(t, result.CorrelationID, "test")
require.Equal(result.Meta.Hit, false) require.Equal(t, result.Meta.Hit, false)
require.Equal(result.Meta.Index, uint64(1)) require.Equal(t, result.Meta.Index, uint64(1))
// pretty conservative check it should be even newer because without a second // pretty conservative check it should be even newer because without a second
// notifier each value returned will have been executed just then and not served // notifier each value returned will have been executed just then and not served
// from the cache. // from the cache.
require.True(result.Meta.Age < 50*time.Millisecond) require.True(t, result.Meta.Age < 50*time.Millisecond)
require.NoError(result.Err) require.NoError(t, result.Err)
} }
require.Len(ch, 0) require.Len(t, ch, 0)
// Register a second observer using same chan and request. Note that this is // Register a second observer using same chan and request. Note that this is
// testing a few things implicitly: // testing a few things implicitly:
@ -235,7 +231,7 @@ func TestCacheNotifyPolling(t *testing.T) {
// requests to the "backend" // requests to the "backend"
// - that multiple watchers can distinguish their results using correlationID // - that multiple watchers can distinguish their results using correlationID
err = c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test2", ch) err = c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test2", ch)
require.NoError(err) require.NoError(t, err)
// Should get test2 notify immediately, and it should be a cache hit // Should get test2 notify immediately, and it should be a cache hit
TestCacheNotifyChResult(t, ch, UpdateEvent{ TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -245,7 +241,7 @@ func TestCacheNotifyPolling(t *testing.T) {
Err: nil, Err: nil,
}) })
require.Len(ch, 0) require.Len(t, ch, 0)
// wait for the next batch of responses // wait for the next batch of responses
events := make([]UpdateEvent, 0) events := make([]UpdateEvent, 0)
@ -255,25 +251,25 @@ func TestCacheNotifyPolling(t *testing.T) {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
select { select {
case <-timeout: case <-timeout:
require.Fail("UpdateEvent not received in time") require.Fail(t, "UpdateEvent not received in time")
case eve := <-ch: case eve := <-ch:
events = append(events, eve) events = append(events, eve)
} }
} }
require.Equal(events[0].Result, 42) require.Equal(t, events[0].Result, 42)
require.Equal(events[0].Meta.Hit && events[1].Meta.Hit, false) require.Equal(t, events[0].Meta.Hit && events[1].Meta.Hit, false)
require.Equal(events[0].Meta.Index, uint64(1)) require.Equal(t, events[0].Meta.Index, uint64(1))
require.True(events[0].Meta.Age < 50*time.Millisecond) require.True(t, events[0].Meta.Age < 50*time.Millisecond)
require.NoError(events[0].Err) require.NoError(t, events[0].Err)
require.Equal(events[1].Result, 42) require.Equal(t, events[1].Result, 42)
// Sometimes this would be a hit and others not. It all depends on when the various getWithIndex calls got fired. // Sometimes this would be a hit and others not. It all depends on when the various getWithIndex calls got fired.
// If both are done concurrently then it will not be a cache hit but the request gets single flighted and both // If both are done concurrently then it will not be a cache hit but the request gets single flighted and both
// get notified at the same time. // get notified at the same time.
// require.Equal(events[1].Meta.Hit, true) // require.Equal(t,events[1].Meta.Hit, true)
require.Equal(events[1].Meta.Index, uint64(1)) require.Equal(t, events[1].Meta.Index, uint64(1))
require.True(events[1].Meta.Age < 100*time.Millisecond) require.True(t, events[1].Meta.Age < 100*time.Millisecond)
require.NoError(events[1].Err) require.NoError(t, events[1].Err)
} }
// Test that a refresh performs a backoff. // Test that a refresh performs a backoff.
@ -298,15 +294,13 @@ func TestCacheWatch_ErrorBackoff(t *testing.T) {
atomic.AddUint32(&retries, 1) atomic.AddUint32(&retries, 1)
}) })
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
ch := make(chan UpdateEvent) ch := make(chan UpdateEvent)
err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test", ch) err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test", ch)
require.NoError(err) require.NoError(t, err)
// Should receive the first result pretty soon // Should receive the first result pretty soon
TestCacheNotifyChResult(t, ch, UpdateEvent{ TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -331,15 +325,15 @@ OUT:
break OUT break OUT
case u := <-ch: case u := <-ch:
numErrors++ numErrors++
require.Error(u.Err) require.Error(t, u.Err)
} }
} }
// Must be fewer than 10 failures in that time // Must be fewer than 10 failures in that time
require.True(numErrors < 10, fmt.Sprintf("numErrors: %d", numErrors)) require.True(t, numErrors < 10, fmt.Sprintf("numErrors: %d", numErrors))
// Check the number of RPCs as a sanity check too // Check the number of RPCs as a sanity check too
actual := atomic.LoadUint32(&retries) actual := atomic.LoadUint32(&retries)
require.True(actual < 10, fmt.Sprintf("actual: %d", actual)) require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
} }
// Test that a refresh performs a backoff. // Test that a refresh performs a backoff.
@ -363,15 +357,13 @@ func TestCacheWatch_ErrorBackoffNonBlocking(t *testing.T) {
atomic.AddUint32(&retries, 1) atomic.AddUint32(&retries, 1)
}) })
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
ch := make(chan UpdateEvent) ch := make(chan UpdateEvent)
err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test", ch) err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test", ch)
require.NoError(err) require.NoError(t, err)
// Should receive the first result pretty soon // Should receive the first result pretty soon
TestCacheNotifyChResult(t, ch, UpdateEvent{ TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -399,13 +391,13 @@ OUT:
break OUT break OUT
case u := <-ch: case u := <-ch:
numErrors++ numErrors++
require.Error(u.Err) require.Error(t, u.Err)
} }
} }
// Must be fewer than 10 failures in that time // Must be fewer than 10 failures in that time
require.True(numErrors < 10, fmt.Sprintf("numErrors: %d", numErrors)) require.True(t, numErrors < 10, fmt.Sprintf("numErrors: %d", numErrors))
// Check the number of RPCs as a sanity check too // Check the number of RPCs as a sanity check too
actual := atomic.LoadUint32(&retries) actual := atomic.LoadUint32(&retries)
require.True(actual < 10, fmt.Sprintf("actual: %d", actual)) require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
} }

View File

@ -635,9 +635,6 @@ func TestCatalogServiceNodes(t *testing.T) {
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
assert := assert.New(t)
require := require.New(t)
// Make sure an empty list is returned, not a nil // Make sure an empty list is returned, not a nil
{ {
req, _ := http.NewRequest("GET", "/v1/catalog/service/api?tag=a", nil) req, _ := http.NewRequest("GET", "/v1/catalog/service/api?tag=a", nil)
@ -691,12 +688,12 @@ func TestCatalogServiceNodes(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/catalog/service/api?cached", nil) req, _ := http.NewRequest("GET", "/v1/catalog/service/api?cached", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.CatalogServiceNodes(resp, req) obj, err := a.srv.CatalogServiceNodes(resp, req)
require.NoError(err) require.NoError(t, err)
nodes := obj.(structs.ServiceNodes) nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
// Should be a cache miss // Should be a cache miss
assert.Equal("MISS", resp.Header().Get("X-Cache")) assert.Equal(t, "MISS", resp.Header().Get("X-Cache"))
} }
{ {
@ -704,13 +701,13 @@ func TestCatalogServiceNodes(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/catalog/service/api?cached", nil) req, _ := http.NewRequest("GET", "/v1/catalog/service/api?cached", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.CatalogServiceNodes(resp, req) obj, err := a.srv.CatalogServiceNodes(resp, req)
require.NoError(err) require.NoError(t, err)
nodes := obj.(structs.ServiceNodes) nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
// Should be a cache HIT now! // Should be a cache HIT now!
assert.Equal("HIT", resp.Header().Get("X-Cache")) assert.Equal(t, "HIT", resp.Header().Get("X-Cache"))
assert.Equal("0", resp.Header().Get("Age")) assert.Equal(t, "0", resp.Header().Get("Age"))
} }
// Ensure background refresh works // Ensure background refresh works
@ -719,7 +716,7 @@ func TestCatalogServiceNodes(t *testing.T) {
args2 := args args2 := args
args2.Node = "bar" args2.Node = "bar"
args2.Address = "127.0.0.2" args2.Address = "127.0.0.2"
require.NoError(a.RPC("Catalog.Register", args, &out)) require.NoError(t, a.RPC("Catalog.Register", args, &out))
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
// List it again // List it again
@ -1057,7 +1054,6 @@ func TestCatalogServiceNodes_ConnectProxy(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1") testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -1065,19 +1061,19 @@ func TestCatalogServiceNodes_ConnectProxy(t *testing.T) {
// Register // Register
args := structs.TestRegisterRequestProxy(t) args := structs.TestRegisterRequestProxy(t)
var out struct{} var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out)) assert.Nil(t, a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf( req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/service/%s", args.Service.Service), nil) "/v1/catalog/service/%s", args.Service.Service), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.CatalogServiceNodes(resp, req) obj, err := a.srv.CatalogServiceNodes(resp, req)
assert.Nil(err) assert.Nil(t, err)
assertIndex(t, resp) assertIndex(t, resp)
nodes := obj.(structs.ServiceNodes) nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind) assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(args.Service.Proxy, nodes[0].ServiceProxy) assert.Equal(t, args.Service.Proxy, nodes[0].ServiceProxy)
} }
// Test that the Connect-compatible endpoints can be queried for a // Test that the Connect-compatible endpoints can be queried for a
@ -1089,7 +1085,6 @@ func TestCatalogConnectServiceNodes_good(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1") testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -1098,20 +1093,20 @@ func TestCatalogConnectServiceNodes_good(t *testing.T) {
args := structs.TestRegisterRequestProxy(t) args := structs.TestRegisterRequestProxy(t)
args.Service.Address = "127.0.0.55" args.Service.Address = "127.0.0.55"
var out struct{} var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out)) assert.Nil(t, a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf( req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/connect/%s", args.Service.Proxy.DestinationServiceName), nil) "/v1/catalog/connect/%s", args.Service.Proxy.DestinationServiceName), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.CatalogConnectServiceNodes(resp, req) obj, err := a.srv.CatalogConnectServiceNodes(resp, req)
assert.Nil(err) assert.Nil(t, err)
assertIndex(t, resp) assertIndex(t, resp)
nodes := obj.(structs.ServiceNodes) nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind) assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(args.Service.Address, nodes[0].ServiceAddress) assert.Equal(t, args.Service.Address, nodes[0].ServiceAddress)
assert.Equal(args.Service.Proxy, nodes[0].ServiceProxy) assert.Equal(t, args.Service.Proxy, nodes[0].ServiceProxy)
} }
func TestCatalogConnectServiceNodes_Filter(t *testing.T) { func TestCatalogConnectServiceNodes_Filter(t *testing.T) {
@ -1307,7 +1302,6 @@ func TestCatalogNodeServices_ConnectProxy(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -1315,19 +1309,19 @@ func TestCatalogNodeServices_ConnectProxy(t *testing.T) {
// Register // Register
args := structs.TestRegisterRequestProxy(t) args := structs.TestRegisterRequestProxy(t)
var out struct{} var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out)) assert.Nil(t, a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf( req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/node/%s", args.Node), nil) "/v1/catalog/node/%s", args.Node), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodeServices(resp, req) obj, err := a.srv.CatalogNodeServices(resp, req)
assert.Nil(err) assert.Nil(t, err)
assertIndex(t, resp) assertIndex(t, resp)
ns := obj.(*structs.NodeServices) ns := obj.(*structs.NodeServices)
assert.Len(ns.Services, 1) assert.Len(t, ns.Services, 1)
v := ns.Services[args.Service.Service] v := ns.Services[args.Service.Service]
assert.Equal(structs.ServiceKindConnectProxy, v.Kind) assert.Equal(t, structs.ServiceKindConnectProxy, v.Kind)
} }
func TestCatalogNodeServices_WanTranslation(t *testing.T) { func TestCatalogNodeServices_WanTranslation(t *testing.T) {

View File

@ -88,7 +88,6 @@ enable_acl_replication = true
func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) { func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) {
t.Run("top-level fields", func(t *testing.T) { t.Run("top-level fields", func(t *testing.T) {
require := require.New(t)
opts := LoadOpts{ opts := LoadOpts{
HCL: []string{` HCL: []string{`
@ -101,21 +100,20 @@ func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) {
patchLoadOptsShims(&opts) patchLoadOptsShims(&opts)
result, err := Load(opts) result, err := Load(opts)
require.NoError(err) require.NoError(t, err)
expectWarns := []string{ expectWarns := []string{
deprecationWarning("acl_master_token", "acl.tokens.initial_management"), deprecationWarning("acl_master_token", "acl.tokens.initial_management"),
deprecationWarning("acl_agent_master_token", "acl.tokens.agent_recovery"), deprecationWarning("acl_agent_master_token", "acl.tokens.agent_recovery"),
} }
require.ElementsMatch(expectWarns, result.Warnings) require.ElementsMatch(t, expectWarns, result.Warnings)
rt := result.RuntimeConfig rt := result.RuntimeConfig
require.Equal("token1", rt.ACLInitialManagementToken) require.Equal(t, "token1", rt.ACLInitialManagementToken)
require.Equal("token2", rt.ACLTokens.ACLAgentRecoveryToken) require.Equal(t, "token2", rt.ACLTokens.ACLAgentRecoveryToken)
}) })
t.Run("embedded in tokens struct", func(t *testing.T) { t.Run("embedded in tokens struct", func(t *testing.T) {
require := require.New(t)
opts := LoadOpts{ opts := LoadOpts{
HCL: []string{` HCL: []string{`
@ -132,21 +130,20 @@ func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) {
patchLoadOptsShims(&opts) patchLoadOptsShims(&opts)
result, err := Load(opts) result, err := Load(opts)
require.NoError(err) require.NoError(t, err)
expectWarns := []string{ expectWarns := []string{
deprecationWarning("acl.tokens.master", "acl.tokens.initial_management"), deprecationWarning("acl.tokens.master", "acl.tokens.initial_management"),
deprecationWarning("acl.tokens.agent_master", "acl.tokens.agent_recovery"), deprecationWarning("acl.tokens.agent_master", "acl.tokens.agent_recovery"),
} }
require.ElementsMatch(expectWarns, result.Warnings) require.ElementsMatch(t, expectWarns, result.Warnings)
rt := result.RuntimeConfig rt := result.RuntimeConfig
require.Equal("token1", rt.ACLInitialManagementToken) require.Equal(t, "token1", rt.ACLInitialManagementToken)
require.Equal("token2", rt.ACLTokens.ACLAgentRecoveryToken) require.Equal(t, "token2", rt.ACLTokens.ACLAgentRecoveryToken)
}) })
t.Run("both", func(t *testing.T) { t.Run("both", func(t *testing.T) {
require := require.New(t)
opts := LoadOpts{ opts := LoadOpts{
HCL: []string{` HCL: []string{`
@ -166,10 +163,10 @@ func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) {
patchLoadOptsShims(&opts) patchLoadOptsShims(&opts)
result, err := Load(opts) result, err := Load(opts)
require.NoError(err) require.NoError(t, err)
rt := result.RuntimeConfig rt := result.RuntimeConfig
require.Equal("token3", rt.ACLInitialManagementToken) require.Equal(t, "token3", rt.ACLInitialManagementToken)
require.Equal("token4", rt.ACLTokens.ACLAgentRecoveryToken) require.Equal(t, "token4", rt.ACLTokens.ACLAgentRecoveryToken)
}) })
} }

View File

@ -149,7 +149,6 @@ func TestConfig_Delete(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -171,7 +170,7 @@ func TestConfig_Delete(t *testing.T) {
} }
for _, req := range reqs { for _, req := range reqs {
out := false out := false
require.NoError(a.RPC("ConfigEntry.Apply", &req, &out)) require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &out))
} }
// Delete an entry. // Delete an entry.
@ -179,7 +178,7 @@ func TestConfig_Delete(t *testing.T) {
req, _ := http.NewRequest("DELETE", "/v1/config/service-defaults/bar", nil) req, _ := http.NewRequest("DELETE", "/v1/config/service-defaults/bar", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.Config(resp, req) _, err := a.srv.Config(resp, req)
require.NoError(err) require.NoError(t, err)
} }
// Get the remaining entry. // Get the remaining entry.
{ {
@ -188,11 +187,11 @@ func TestConfig_Delete(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var out structs.IndexedConfigEntries var out structs.IndexedConfigEntries
require.NoError(a.RPC("ConfigEntry.List", &args, &out)) require.NoError(t, a.RPC("ConfigEntry.List", &args, &out))
require.Equal(structs.ServiceDefaults, out.Kind) require.Equal(t, structs.ServiceDefaults, out.Kind)
require.Len(out.Entries, 1) require.Len(t, out.Entries, 1)
entry := out.Entries[0].(*structs.ServiceConfigEntry) entry := out.Entries[0].(*structs.ServiceConfigEntry)
require.Equal(entry.Name, "foo") require.Equal(t, entry.Name, "foo")
} }
} }
@ -202,8 +201,6 @@ func TestConfig_Delete_CAS(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -214,20 +211,20 @@ func TestConfig_Delete_CAS(t *testing.T) {
Name: "foo", Name: "foo",
} }
var created bool var created bool
require.NoError(a.RPC("ConfigEntry.Apply", &structs.ConfigEntryRequest{ require.NoError(t, a.RPC("ConfigEntry.Apply", &structs.ConfigEntryRequest{
Datacenter: "dc1", Datacenter: "dc1",
Entry: entry, Entry: entry,
}, &created)) }, &created))
require.True(created) require.True(t, created)
// Read it back to get its ModifyIndex. // Read it back to get its ModifyIndex.
var out structs.ConfigEntryResponse var out structs.ConfigEntryResponse
require.NoError(a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{ require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
Datacenter: "dc1", Datacenter: "dc1",
Kind: entry.Kind, Kind: entry.Kind,
Name: entry.Name, Name: entry.Name,
}, &out)) }, &out))
require.NotNil(out.Entry) require.NotNil(t, out.Entry)
modifyIndex := out.Entry.GetRaftIndex().ModifyIndex modifyIndex := out.Entry.GetRaftIndex().ModifyIndex
@ -238,20 +235,20 @@ func TestConfig_Delete_CAS(t *testing.T) {
nil, nil,
) )
rawRsp, err := a.srv.Config(httptest.NewRecorder(), req) rawRsp, err := a.srv.Config(httptest.NewRecorder(), req)
require.NoError(err) require.NoError(t, err)
deleted, isBool := rawRsp.(bool) deleted, isBool := rawRsp.(bool)
require.True(isBool, "response should be a boolean") require.True(t, isBool, "response should be a boolean")
require.False(deleted, "entry should not have been deleted") require.False(t, deleted, "entry should not have been deleted")
// Verify it was not deleted. // Verify it was not deleted.
var out structs.ConfigEntryResponse var out structs.ConfigEntryResponse
require.NoError(a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{ require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
Datacenter: "dc1", Datacenter: "dc1",
Kind: entry.Kind, Kind: entry.Kind,
Name: entry.Name, Name: entry.Name,
}, &out)) }, &out))
require.NotNil(out.Entry) require.NotNil(t, out.Entry)
}) })
t.Run("attempt to delete with a valid index", func(t *testing.T) { t.Run("attempt to delete with a valid index", func(t *testing.T) {
@ -261,20 +258,20 @@ func TestConfig_Delete_CAS(t *testing.T) {
nil, nil,
) )
rawRsp, err := a.srv.Config(httptest.NewRecorder(), req) rawRsp, err := a.srv.Config(httptest.NewRecorder(), req)
require.NoError(err) require.NoError(t, err)
deleted, isBool := rawRsp.(bool) deleted, isBool := rawRsp.(bool)
require.True(isBool, "response should be a boolean") require.True(t, isBool, "response should be a boolean")
require.True(deleted, "entry should have been deleted") require.True(t, deleted, "entry should have been deleted")
// Verify it was deleted. // Verify it was deleted.
var out structs.ConfigEntryResponse var out structs.ConfigEntryResponse
require.NoError(a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{ require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
Datacenter: "dc1", Datacenter: "dc1",
Kind: entry.Kind, Kind: entry.Kind,
Name: entry.Name, Name: entry.Name,
}, &out)) }, &out))
require.Nil(out.Entry) require.Nil(t, out.Entry)
}) })
} }
@ -285,7 +282,6 @@ func TestConfig_Apply(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -301,7 +297,7 @@ func TestConfig_Apply(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/config", body) req, _ := http.NewRequest("PUT", "/v1/config", body)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ConfigApply(resp, req) _, err := a.srv.ConfigApply(resp, req)
require.NoError(err) require.NoError(t, err)
if resp.Code != 200 { if resp.Code != 200 {
t.Fatalf(resp.Body.String()) t.Fatalf(resp.Body.String())
} }
@ -314,10 +310,10 @@ func TestConfig_Apply(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var out structs.ConfigEntryResponse var out structs.ConfigEntryResponse
require.NoError(a.RPC("ConfigEntry.Get", &args, &out)) require.NoError(t, a.RPC("ConfigEntry.Get", &args, &out))
require.NotNil(out.Entry) require.NotNil(t, out.Entry)
entry := out.Entry.(*structs.ServiceConfigEntry) entry := out.Entry.(*structs.ServiceConfigEntry)
require.Equal(entry.Name, "foo") require.Equal(t, entry.Name, "foo")
} }
} }
@ -503,7 +499,6 @@ func TestConfig_Apply_CAS(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -519,7 +514,7 @@ func TestConfig_Apply_CAS(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/config", body) req, _ := http.NewRequest("PUT", "/v1/config", body)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ConfigApply(resp, req) _, err := a.srv.ConfigApply(resp, req)
require.NoError(err) require.NoError(t, err)
if resp.Code != 200 { if resp.Code != 200 {
t.Fatalf(resp.Body.String()) t.Fatalf(resp.Body.String())
} }
@ -532,8 +527,8 @@ func TestConfig_Apply_CAS(t *testing.T) {
} }
out := &structs.ConfigEntryResponse{} out := &structs.ConfigEntryResponse{}
require.NoError(a.RPC("ConfigEntry.Get", &args, out)) require.NoError(t, a.RPC("ConfigEntry.Get", &args, out))
require.NotNil(out.Entry) require.NotNil(t, out.Entry)
entry := out.Entry.(*structs.ServiceConfigEntry) entry := out.Entry.(*structs.ServiceConfigEntry)
body = bytes.NewBuffer([]byte(` body = bytes.NewBuffer([]byte(`
@ -546,11 +541,11 @@ func TestConfig_Apply_CAS(t *testing.T) {
req, _ = http.NewRequest("PUT", "/v1/config?cas=0", body) req, _ = http.NewRequest("PUT", "/v1/config?cas=0", body)
resp = httptest.NewRecorder() resp = httptest.NewRecorder()
writtenRaw, err := a.srv.ConfigApply(resp, req) writtenRaw, err := a.srv.ConfigApply(resp, req)
require.NoError(err) require.NoError(t, err)
written, ok := writtenRaw.(bool) written, ok := writtenRaw.(bool)
require.True(ok) require.True(t, ok)
require.False(written) require.False(t, written)
require.EqualValues(200, resp.Code, resp.Body.String()) require.EqualValues(t, 200, resp.Code, resp.Body.String())
body = bytes.NewBuffer([]byte(` body = bytes.NewBuffer([]byte(`
{ {
@ -562,11 +557,11 @@ func TestConfig_Apply_CAS(t *testing.T) {
req, _ = http.NewRequest("PUT", fmt.Sprintf("/v1/config?cas=%d", entry.GetRaftIndex().ModifyIndex), body) req, _ = http.NewRequest("PUT", fmt.Sprintf("/v1/config?cas=%d", entry.GetRaftIndex().ModifyIndex), body)
resp = httptest.NewRecorder() resp = httptest.NewRecorder()
writtenRaw, err = a.srv.ConfigApply(resp, req) writtenRaw, err = a.srv.ConfigApply(resp, req)
require.NoError(err) require.NoError(t, err)
written, ok = writtenRaw.(bool) written, ok = writtenRaw.(bool)
require.True(ok) require.True(t, ok)
require.True(written) require.True(t, written)
require.EqualValues(200, resp.Code, resp.Body.String()) require.EqualValues(t, 200, resp.Code, resp.Body.String())
// Get the entry remaining entry. // Get the entry remaining entry.
args = structs.ConfigEntryQuery{ args = structs.ConfigEntryQuery{
@ -576,10 +571,10 @@ func TestConfig_Apply_CAS(t *testing.T) {
} }
out = &structs.ConfigEntryResponse{} out = &structs.ConfigEntryResponse{}
require.NoError(a.RPC("ConfigEntry.Get", &args, out)) require.NoError(t, a.RPC("ConfigEntry.Get", &args, out))
require.NotNil(out.Entry) require.NotNil(t, out.Entry)
newEntry := out.Entry.(*structs.ServiceConfigEntry) newEntry := out.Entry.(*structs.ServiceConfigEntry)
require.NotEqual(entry.GetRaftIndex(), newEntry.GetRaftIndex()) require.NotEqual(t, entry.GetRaftIndex(), newEntry.GetRaftIndex())
} }
func TestConfig_Apply_Decoding(t *testing.T) { func TestConfig_Apply_Decoding(t *testing.T) {

View File

@ -38,7 +38,6 @@ func TestAWSBootstrapAndSignPrimary(t *testing.T) {
for _, tc := range KeyTestCases { for _, tc := range KeyTestCases {
tc := tc tc := tc
t.Run(tc.Desc, func(t *testing.T) { t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
cfg := map[string]interface{}{ cfg := map[string]interface{}{
"PrivateKeyType": tc.KeyType, "PrivateKeyType": tc.KeyType,
"PrivateKeyBits": tc.KeyBits, "PrivateKeyBits": tc.KeyBits,
@ -48,33 +47,33 @@ func TestAWSBootstrapAndSignPrimary(t *testing.T) {
defer provider.Cleanup(true, nil) defer provider.Cleanup(true, nil)
// Generate the root // Generate the root
require.NoError(provider.GenerateRoot()) require.NoError(t, provider.GenerateRoot())
// Fetch Active Root // Fetch Active Root
rootPEM, err := provider.ActiveRoot() rootPEM, err := provider.ActiveRoot()
require.NoError(err) require.NoError(t, err)
// Generate Intermediate (not actually needed for this provider for now // Generate Intermediate (not actually needed for this provider for now
// but this simulates the calls in Server.initializeRoot). // but this simulates the calls in Server.initializeRoot).
interPEM, err := provider.GenerateIntermediate() interPEM, err := provider.GenerateIntermediate()
require.NoError(err) require.NoError(t, err)
// Should be the same for now // Should be the same for now
require.Equal(rootPEM, interPEM) require.Equal(t, rootPEM, interPEM)
// Ensure they use the right key type // Ensure they use the right key type
rootCert, err := connect.ParseCert(rootPEM) rootCert, err := connect.ParseCert(rootPEM)
require.NoError(err) require.NoError(t, err)
keyType, keyBits, err := connect.KeyInfoFromCert(rootCert) keyType, keyBits, err := connect.KeyInfoFromCert(rootCert)
require.NoError(err) require.NoError(t, err)
require.Equal(tc.KeyType, keyType) require.Equal(t, tc.KeyType, keyType)
require.Equal(tc.KeyBits, keyBits) require.Equal(t, tc.KeyBits, keyBits)
// Ensure that the root cert ttl is withing the configured value // Ensure that the root cert ttl is withing the configured value
// computation is similar to how we are passing the TTL thru the aws client // computation is similar to how we are passing the TTL thru the aws client
expectedTime := time.Now().AddDate(0, 0, int(8761*60*time.Minute/day)).UTC() expectedTime := time.Now().AddDate(0, 0, int(8761*60*time.Minute/day)).UTC()
require.WithinDuration(expectedTime, rootCert.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured") require.WithinDuration(t, expectedTime, rootCert.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
// Sign a leaf with it // Sign a leaf with it
testSignAndValidate(t, provider, rootPEM, nil) testSignAndValidate(t, provider, rootPEM, nil)

View File

@ -78,26 +78,25 @@ func requireNotEncoded(t *testing.T, v []byte) {
func TestConsulCAProvider_Bootstrap(t *testing.T) { func TestConsulCAProvider_Bootstrap(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
conf := testConsulCAConfig() conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf) delegate := newMockDelegate(t, conf)
provider := TestConsulProvider(t, delegate) provider := TestConsulProvider(t, delegate)
require.NoError(provider.Configure(testProviderConfig(conf))) require.NoError(t, provider.Configure(testProviderConfig(conf)))
require.NoError(provider.GenerateRoot()) require.NoError(t, provider.GenerateRoot())
root, err := provider.ActiveRoot() root, err := provider.ActiveRoot()
require.NoError(err) require.NoError(t, err)
// Intermediate should be the same cert. // Intermediate should be the same cert.
inter, err := provider.ActiveIntermediate() inter, err := provider.ActiveIntermediate()
require.NoError(err) require.NoError(t, err)
require.Equal(root, inter) require.Equal(t, root, inter)
// Should be a valid cert // Should be a valid cert
parsed, err := connect.ParseCert(root) parsed, err := connect.ParseCert(root)
require.NoError(err) require.NoError(t, err)
require.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", conf.ClusterID)) require.Equal(t, parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", conf.ClusterID))
requireNotEncoded(t, parsed.SubjectKeyId) requireNotEncoded(t, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.AuthorityKeyId) requireNotEncoded(t, parsed.AuthorityKeyId)
@ -105,16 +104,15 @@ func TestConsulCAProvider_Bootstrap(t *testing.T) {
// notice that we allow a margin of "error" of 10 minutes between the // notice that we allow a margin of "error" of 10 minutes between the
// generateCA() creation and this check // generateCA() creation and this check
defaultRootCertTTL, err := time.ParseDuration(structs.DefaultRootCertTTL) defaultRootCertTTL, err := time.ParseDuration(structs.DefaultRootCertTTL)
require.NoError(err) require.NoError(t, err)
expectedNotAfter := time.Now().Add(defaultRootCertTTL).UTC() expectedNotAfter := time.Now().Add(defaultRootCertTTL).UTC()
require.WithinDuration(expectedNotAfter, parsed.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured") require.WithinDuration(t, expectedNotAfter, parsed.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
} }
func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) { func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) {
t.Parallel() t.Parallel()
// Make sure setting a custom private key/root cert works. // Make sure setting a custom private key/root cert works.
require := require.New(t)
rootCA := connect.TestCAWithTTL(t, nil, 5*time.Hour) rootCA := connect.TestCAWithTTL(t, nil, 5*time.Hour)
conf := testConsulCAConfig() conf := testConsulCAConfig()
conf.Config = map[string]interface{}{ conf.Config = map[string]interface{}{
@ -124,24 +122,24 @@ func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) {
delegate := newMockDelegate(t, conf) delegate := newMockDelegate(t, conf)
provider := TestConsulProvider(t, delegate) provider := TestConsulProvider(t, delegate)
require.NoError(provider.Configure(testProviderConfig(conf))) require.NoError(t, provider.Configure(testProviderConfig(conf)))
require.NoError(provider.GenerateRoot()) require.NoError(t, provider.GenerateRoot())
root, err := provider.ActiveRoot() root, err := provider.ActiveRoot()
require.NoError(err) require.NoError(t, err)
require.Equal(root, rootCA.RootCert) require.Equal(t, root, rootCA.RootCert)
// Should be a valid cert // Should be a valid cert
parsed, err := connect.ParseCert(root) parsed, err := connect.ParseCert(root)
require.NoError(err) require.NoError(t, err)
// test that the default root cert ttl was not applied to the provided cert // test that the default root cert ttl was not applied to the provided cert
defaultRootCertTTL, err := time.ParseDuration(structs.DefaultRootCertTTL) defaultRootCertTTL, err := time.ParseDuration(structs.DefaultRootCertTTL)
require.NoError(err) require.NoError(t, err)
defaultNotAfter := time.Now().Add(defaultRootCertTTL).UTC() defaultNotAfter := time.Now().Add(defaultRootCertTTL).UTC()
// we can't compare given the "delta" between the time the cert is generated // we can't compare given the "delta" between the time the cert is generated
// and when we start the test; so just look at the years for now, given different years // and when we start the test; so just look at the years for now, given different years
require.NotEqualf(defaultNotAfter.Year(), parsed.NotAfter.Year(), "parsed cert ttl expected to be different from default root cert ttl") require.NotEqualf(t, defaultNotAfter.Year(), parsed.NotAfter.Year(), "parsed cert ttl expected to be different from default root cert ttl")
} }
func TestConsulCAProvider_SignLeaf(t *testing.T) { func TestConsulCAProvider_SignLeaf(t *testing.T) {
@ -154,7 +152,6 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
for _, tc := range KeyTestCases { for _, tc := range KeyTestCases {
tc := tc tc := tc
t.Run(tc.Desc, func(t *testing.T) { t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
conf := testConsulCAConfig() conf := testConsulCAConfig()
conf.Config["LeafCertTTL"] = "1h" conf.Config["LeafCertTTL"] = "1h"
conf.Config["PrivateKeyType"] = tc.KeyType conf.Config["PrivateKeyType"] = tc.KeyType
@ -162,8 +159,8 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
delegate := newMockDelegate(t, conf) delegate := newMockDelegate(t, conf)
provider := TestConsulProvider(t, delegate) provider := TestConsulProvider(t, delegate)
require.NoError(provider.Configure(testProviderConfig(conf))) require.NoError(t, provider.Configure(testProviderConfig(conf)))
require.NoError(provider.GenerateRoot()) require.NoError(t, provider.GenerateRoot())
spiffeService := &connect.SpiffeIDService{ spiffeService := &connect.SpiffeIDService{
Host: connect.TestClusterID + ".consul", Host: connect.TestClusterID + ".consul",
@ -177,26 +174,26 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService) raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw) csr, err := connect.ParseCSR(raw)
require.NoError(err) require.NoError(t, err)
cert, err := provider.Sign(csr) cert, err := provider.Sign(csr)
require.NoError(err) require.NoError(t, err)
requireTrailingNewline(t, cert) requireTrailingNewline(t, cert)
parsed, err := connect.ParseCert(cert) parsed, err := connect.ParseCert(cert)
require.NoError(err) require.NoError(t, err)
require.Equal(spiffeService.URI(), parsed.URIs[0]) require.Equal(t, spiffeService.URI(), parsed.URIs[0])
require.Empty(parsed.Subject.CommonName) require.Empty(t, parsed.Subject.CommonName)
require.Equal(uint64(3), parsed.SerialNumber.Uint64()) require.Equal(t, uint64(3), parsed.SerialNumber.Uint64())
subjectKeyID, err := connect.KeyId(csr.PublicKey) subjectKeyID, err := connect.KeyId(csr.PublicKey)
require.NoError(err) require.NoError(t, err)
require.Equal(subjectKeyID, parsed.SubjectKeyId) require.Equal(t, subjectKeyID, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.SubjectKeyId) requireNotEncoded(t, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.AuthorityKeyId) requireNotEncoded(t, parsed.AuthorityKeyId)
// Ensure the cert is valid now and expires within the correct limit. // Ensure the cert is valid now and expires within the correct limit.
now := time.Now() now := time.Now()
require.True(parsed.NotAfter.Sub(now) < time.Hour) require.True(t, parsed.NotAfter.Sub(now) < time.Hour)
require.True(parsed.NotBefore.Before(now)) require.True(t, parsed.NotBefore.Before(now))
} }
// Generate a new cert for another service and make sure // Generate a new cert for another service and make sure
@ -206,22 +203,22 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService) raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw) csr, err := connect.ParseCSR(raw)
require.NoError(err) require.NoError(t, err)
cert, err := provider.Sign(csr) cert, err := provider.Sign(csr)
require.NoError(err) require.NoError(t, err)
parsed, err := connect.ParseCert(cert) parsed, err := connect.ParseCert(cert)
require.NoError(err) require.NoError(t, err)
require.Equal(spiffeService.URI(), parsed.URIs[0]) require.Equal(t, spiffeService.URI(), parsed.URIs[0])
require.Empty(parsed.Subject.CommonName) require.Empty(t, parsed.Subject.CommonName)
require.Equal(uint64(4), parsed.SerialNumber.Uint64()) require.Equal(t, uint64(4), parsed.SerialNumber.Uint64())
requireNotEncoded(t, parsed.SubjectKeyId) requireNotEncoded(t, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.AuthorityKeyId) requireNotEncoded(t, parsed.AuthorityKeyId)
// Ensure the cert is valid now and expires within the correct limit. // Ensure the cert is valid now and expires within the correct limit.
require.True(time.Until(parsed.NotAfter) < 3*24*time.Hour) require.True(t, time.Until(parsed.NotAfter) < 3*24*time.Hour)
require.True(parsed.NotBefore.Before(time.Now())) require.True(t, parsed.NotBefore.Before(time.Now()))
} }
spiffeAgent := &connect.SpiffeIDAgent{ spiffeAgent := &connect.SpiffeIDAgent{
@ -234,23 +231,23 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeAgent) raw, _ := connect.TestCSR(t, spiffeAgent)
csr, err := connect.ParseCSR(raw) csr, err := connect.ParseCSR(raw)
require.NoError(err) require.NoError(t, err)
cert, err := provider.Sign(csr) cert, err := provider.Sign(csr)
require.NoError(err) require.NoError(t, err)
parsed, err := connect.ParseCert(cert) parsed, err := connect.ParseCert(cert)
require.NoError(err) require.NoError(t, err)
require.Equal(spiffeAgent.URI(), parsed.URIs[0]) require.Equal(t, spiffeAgent.URI(), parsed.URIs[0])
require.Empty(parsed.Subject.CommonName) require.Empty(t, parsed.Subject.CommonName)
require.Equal(uint64(5), parsed.SerialNumber.Uint64()) require.Equal(t, uint64(5), parsed.SerialNumber.Uint64())
requireNotEncoded(t, parsed.SubjectKeyId) requireNotEncoded(t, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.AuthorityKeyId) requireNotEncoded(t, parsed.AuthorityKeyId)
// Ensure the cert is valid now and expires within the correct limit. // Ensure the cert is valid now and expires within the correct limit.
now := time.Now() now := time.Now()
require.True(parsed.NotAfter.Sub(now) < time.Hour) require.True(t, parsed.NotAfter.Sub(now) < time.Hour)
require.True(parsed.NotBefore.Before(now)) require.True(t, parsed.NotBefore.Before(now))
} }
}) })
} }
@ -268,15 +265,14 @@ func TestConsulCAProvider_CrossSignCA(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.Desc, func(t *testing.T) { t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
conf1 := testConsulCAConfig() conf1 := testConsulCAConfig()
delegate1 := newMockDelegate(t, conf1) delegate1 := newMockDelegate(t, conf1)
provider1 := TestConsulProvider(t, delegate1) provider1 := TestConsulProvider(t, delegate1)
conf1.Config["PrivateKeyType"] = tc.SigningKeyType conf1.Config["PrivateKeyType"] = tc.SigningKeyType
conf1.Config["PrivateKeyBits"] = tc.SigningKeyBits conf1.Config["PrivateKeyBits"] = tc.SigningKeyBits
require.NoError(provider1.Configure(testProviderConfig(conf1))) require.NoError(t, provider1.Configure(testProviderConfig(conf1)))
require.NoError(provider1.GenerateRoot()) require.NoError(t, provider1.GenerateRoot())
conf2 := testConsulCAConfig() conf2 := testConsulCAConfig()
conf2.CreateIndex = 10 conf2.CreateIndex = 10
@ -284,8 +280,8 @@ func TestConsulCAProvider_CrossSignCA(t *testing.T) {
provider2 := TestConsulProvider(t, delegate2) provider2 := TestConsulProvider(t, delegate2)
conf2.Config["PrivateKeyType"] = tc.CSRKeyType conf2.Config["PrivateKeyType"] = tc.CSRKeyType
conf2.Config["PrivateKeyBits"] = tc.CSRKeyBits conf2.Config["PrivateKeyBits"] = tc.CSRKeyBits
require.NoError(provider2.Configure(testProviderConfig(conf2))) require.NoError(t, provider2.Configure(testProviderConfig(conf2)))
require.NoError(provider2.GenerateRoot()) require.NoError(t, provider2.GenerateRoot())
testCrossSignProviders(t, provider1, provider2) testCrossSignProviders(t, provider1, provider2)
}) })
@ -293,52 +289,51 @@ func TestConsulCAProvider_CrossSignCA(t *testing.T) {
} }
func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) { func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
require := require.New(t)
// Get the root from the new provider to be cross-signed. // Get the root from the new provider to be cross-signed.
newRootPEM, err := provider2.ActiveRoot() newRootPEM, err := provider2.ActiveRoot()
require.NoError(err) require.NoError(t, err)
newRoot, err := connect.ParseCert(newRootPEM) newRoot, err := connect.ParseCert(newRootPEM)
require.NoError(err) require.NoError(t, err)
oldSubject := newRoot.Subject.CommonName oldSubject := newRoot.Subject.CommonName
requireNotEncoded(t, newRoot.SubjectKeyId) requireNotEncoded(t, newRoot.SubjectKeyId)
requireNotEncoded(t, newRoot.AuthorityKeyId) requireNotEncoded(t, newRoot.AuthorityKeyId)
newInterPEM, err := provider2.ActiveIntermediate() newInterPEM, err := provider2.ActiveIntermediate()
require.NoError(err) require.NoError(t, err)
newIntermediate, err := connect.ParseCert(newInterPEM) newIntermediate, err := connect.ParseCert(newInterPEM)
require.NoError(err) require.NoError(t, err)
requireNotEncoded(t, newIntermediate.SubjectKeyId) requireNotEncoded(t, newIntermediate.SubjectKeyId)
requireNotEncoded(t, newIntermediate.AuthorityKeyId) requireNotEncoded(t, newIntermediate.AuthorityKeyId)
// Have provider1 cross sign our new root cert. // Have provider1 cross sign our new root cert.
xcPEM, err := provider1.CrossSignCA(newRoot) xcPEM, err := provider1.CrossSignCA(newRoot)
require.NoError(err) require.NoError(t, err)
xc, err := connect.ParseCert(xcPEM) xc, err := connect.ParseCert(xcPEM)
require.NoError(err) require.NoError(t, err)
requireNotEncoded(t, xc.SubjectKeyId) requireNotEncoded(t, xc.SubjectKeyId)
requireNotEncoded(t, xc.AuthorityKeyId) requireNotEncoded(t, xc.AuthorityKeyId)
oldRootPEM, err := provider1.ActiveRoot() oldRootPEM, err := provider1.ActiveRoot()
require.NoError(err) require.NoError(t, err)
oldRoot, err := connect.ParseCert(oldRootPEM) oldRoot, err := connect.ParseCert(oldRootPEM)
require.NoError(err) require.NoError(t, err)
requireNotEncoded(t, oldRoot.SubjectKeyId) requireNotEncoded(t, oldRoot.SubjectKeyId)
requireNotEncoded(t, oldRoot.AuthorityKeyId) requireNotEncoded(t, oldRoot.AuthorityKeyId)
// AuthorityKeyID should now be the signing root's, SubjectKeyId should be kept. // AuthorityKeyID should now be the signing root's, SubjectKeyId should be kept.
require.Equal(oldRoot.SubjectKeyId, xc.AuthorityKeyId, require.Equal(t, oldRoot.SubjectKeyId, xc.AuthorityKeyId,
"newSKID=%x\nnewAKID=%x\noldSKID=%x\noldAKID=%x\nxcSKID=%x\nxcAKID=%x", "newSKID=%x\nnewAKID=%x\noldSKID=%x\noldAKID=%x\nxcSKID=%x\nxcAKID=%x",
newRoot.SubjectKeyId, newRoot.AuthorityKeyId, newRoot.SubjectKeyId, newRoot.AuthorityKeyId,
oldRoot.SubjectKeyId, oldRoot.AuthorityKeyId, oldRoot.SubjectKeyId, oldRoot.AuthorityKeyId,
xc.SubjectKeyId, xc.AuthorityKeyId) xc.SubjectKeyId, xc.AuthorityKeyId)
require.Equal(newRoot.SubjectKeyId, xc.SubjectKeyId) require.Equal(t, newRoot.SubjectKeyId, xc.SubjectKeyId)
// Subject name should not have changed. // Subject name should not have changed.
require.Equal(oldSubject, xc.Subject.CommonName) require.Equal(t, oldSubject, xc.Subject.CommonName)
// Issuer should be the signing root. // Issuer should be the signing root.
require.Equal(oldRoot.Issuer.CommonName, xc.Issuer.CommonName) require.Equal(t, oldRoot.Issuer.CommonName, xc.Issuer.CommonName)
// Get a leaf cert so we can verify against the cross-signed cert. // Get a leaf cert so we can verify against the cross-signed cert.
spiffeService := &connect.SpiffeIDService{ spiffeService := &connect.SpiffeIDService{
@ -350,13 +345,13 @@ func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
raw, _ := connect.TestCSR(t, spiffeService) raw, _ := connect.TestCSR(t, spiffeService)
leafCsr, err := connect.ParseCSR(raw) leafCsr, err := connect.ParseCSR(raw)
require.NoError(err) require.NoError(t, err)
leafPEM, err := provider2.Sign(leafCsr) leafPEM, err := provider2.Sign(leafCsr)
require.NoError(err) require.NoError(t, err)
cert, err := connect.ParseCert(leafPEM) cert, err := connect.ParseCert(leafPEM)
require.NoError(err) require.NoError(t, err)
requireNotEncoded(t, cert.SubjectKeyId) requireNotEncoded(t, cert.SubjectKeyId)
requireNotEncoded(t, cert.AuthorityKeyId) requireNotEncoded(t, cert.AuthorityKeyId)
@ -374,7 +369,7 @@ func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
Intermediates: intermediatePool, Intermediates: intermediatePool,
Roots: rootPool, Roots: rootPool,
}) })
require.NoError(err) require.NoError(t, err)
} }
} }
@ -390,15 +385,14 @@ func TestConsulProvider_SignIntermediate(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.Desc, func(t *testing.T) { t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
conf1 := testConsulCAConfig() conf1 := testConsulCAConfig()
delegate1 := newMockDelegate(t, conf1) delegate1 := newMockDelegate(t, conf1)
provider1 := TestConsulProvider(t, delegate1) provider1 := TestConsulProvider(t, delegate1)
conf1.Config["PrivateKeyType"] = tc.SigningKeyType conf1.Config["PrivateKeyType"] = tc.SigningKeyType
conf1.Config["PrivateKeyBits"] = tc.SigningKeyBits conf1.Config["PrivateKeyBits"] = tc.SigningKeyBits
require.NoError(provider1.Configure(testProviderConfig(conf1))) require.NoError(t, provider1.Configure(testProviderConfig(conf1)))
require.NoError(provider1.GenerateRoot()) require.NoError(t, provider1.GenerateRoot())
conf2 := testConsulCAConfig() conf2 := testConsulCAConfig()
conf2.CreateIndex = 10 conf2.CreateIndex = 10
@ -409,7 +403,7 @@ func TestConsulProvider_SignIntermediate(t *testing.T) {
cfg := testProviderConfig(conf2) cfg := testProviderConfig(conf2)
cfg.IsPrimary = false cfg.IsPrimary = false
cfg.Datacenter = "dc2" cfg.Datacenter = "dc2"
require.NoError(provider2.Configure(cfg)) require.NoError(t, provider2.Configure(cfg))
testSignIntermediateCrossDC(t, provider1, provider2) testSignIntermediateCrossDC(t, provider1, provider2)
}) })
@ -418,22 +412,21 @@ func TestConsulProvider_SignIntermediate(t *testing.T) {
} }
func testSignIntermediateCrossDC(t *testing.T, provider1, provider2 Provider) { func testSignIntermediateCrossDC(t *testing.T, provider1, provider2 Provider) {
require := require.New(t)
// Get the intermediate CSR from provider2. // Get the intermediate CSR from provider2.
csrPEM, err := provider2.GenerateIntermediateCSR() csrPEM, err := provider2.GenerateIntermediateCSR()
require.NoError(err) require.NoError(t, err)
csr, err := connect.ParseCSR(csrPEM) csr, err := connect.ParseCSR(csrPEM)
require.NoError(err) require.NoError(t, err)
// Sign the CSR with provider1. // Sign the CSR with provider1.
intermediatePEM, err := provider1.SignIntermediate(csr) intermediatePEM, err := provider1.SignIntermediate(csr)
require.NoError(err) require.NoError(t, err)
rootPEM, err := provider1.ActiveRoot() rootPEM, err := provider1.ActiveRoot()
require.NoError(err) require.NoError(t, err)
// Give the new intermediate to provider2 to use. // Give the new intermediate to provider2 to use.
require.NoError(provider2.SetIntermediate(intermediatePEM, rootPEM)) require.NoError(t, provider2.SetIntermediate(intermediatePEM, rootPEM))
// Have provider2 sign a leaf cert and make sure the chain is correct. // Have provider2 sign a leaf cert and make sure the chain is correct.
spiffeService := &connect.SpiffeIDService{ spiffeService := &connect.SpiffeIDService{
@ -445,13 +438,13 @@ func testSignIntermediateCrossDC(t *testing.T, provider1, provider2 Provider) {
raw, _ := connect.TestCSR(t, spiffeService) raw, _ := connect.TestCSR(t, spiffeService)
leafCsr, err := connect.ParseCSR(raw) leafCsr, err := connect.ParseCSR(raw)
require.NoError(err) require.NoError(t, err)
leafPEM, err := provider2.Sign(leafCsr) leafPEM, err := provider2.Sign(leafCsr)
require.NoError(err) require.NoError(t, err)
cert, err := connect.ParseCert(leafPEM) cert, err := connect.ParseCert(leafPEM)
require.NoError(err) require.NoError(t, err)
requireNotEncoded(t, cert.SubjectKeyId) requireNotEncoded(t, cert.SubjectKeyId)
requireNotEncoded(t, cert.AuthorityKeyId) requireNotEncoded(t, cert.AuthorityKeyId)
@ -466,7 +459,7 @@ func testSignIntermediateCrossDC(t *testing.T, provider1, provider2 Provider) {
Intermediates: intermediatePool, Intermediates: intermediatePool,
Roots: rootPool, Roots: rootPool,
}) })
require.NoError(err) require.NoError(t, err)
} }
func TestConsulCAProvider_MigrateOldID(t *testing.T) { func TestConsulCAProvider_MigrateOldID(t *testing.T) {

View File

@ -116,13 +116,12 @@ func TestVaultCAProvider_VaultTLSConfig(t *testing.T) {
TLSSkipVerify: true, TLSSkipVerify: true,
} }
tlsConfig := vaultTLSConfig(config) tlsConfig := vaultTLSConfig(config)
require := require.New(t) require.Equal(t, config.CAFile, tlsConfig.CACert)
require.Equal(config.CAFile, tlsConfig.CACert) require.Equal(t, config.CAPath, tlsConfig.CAPath)
require.Equal(config.CAPath, tlsConfig.CAPath) require.Equal(t, config.CertFile, tlsConfig.ClientCert)
require.Equal(config.CertFile, tlsConfig.ClientCert) require.Equal(t, config.KeyFile, tlsConfig.ClientKey)
require.Equal(config.KeyFile, tlsConfig.ClientKey) require.Equal(t, config.TLSServerName, tlsConfig.TLSServerName)
require.Equal(config.TLSServerName, tlsConfig.TLSServerName) require.Equal(t, config.TLSSkipVerify, tlsConfig.Insecure)
require.Equal(config.TLSSkipVerify, tlsConfig.Insecure)
} }
func TestVaultCAProvider_Configure(t *testing.T) { func TestVaultCAProvider_Configure(t *testing.T) {
@ -171,11 +170,10 @@ func TestVaultCAProvider_SecondaryActiveIntermediate(t *testing.T) {
provider, testVault := testVaultProviderWithConfig(t, false, nil) provider, testVault := testVaultProviderWithConfig(t, false, nil)
defer testVault.Stop() defer testVault.Stop()
require := require.New(t)
cert, err := provider.ActiveIntermediate() cert, err := provider.ActiveIntermediate()
require.Empty(cert) require.Empty(t, cert)
require.NoError(err) require.NoError(t, err)
} }
func TestVaultCAProvider_RenewToken(t *testing.T) { func TestVaultCAProvider_RenewToken(t *testing.T) {
@ -231,8 +229,6 @@ func TestVaultCAProvider_Bootstrap(t *testing.T) {
defer testvault2.Stop() defer testvault2.Stop()
client2 := testvault2.client client2 := testvault2.client
require := require.New(t)
cases := []struct { cases := []struct {
certFunc func() (string, error) certFunc func() (string, error)
backendPath string backendPath string
@ -264,28 +260,28 @@ func TestVaultCAProvider_Bootstrap(t *testing.T) {
provider := tc.provider provider := tc.provider
client := tc.client client := tc.client
cert, err := tc.certFunc() cert, err := tc.certFunc()
require.NoError(err) require.NoError(t, err)
req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem") req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem")
resp, err := client.RawRequest(req) resp, err := client.RawRequest(req)
require.NoError(err) require.NoError(t, err)
bytes, err := ioutil.ReadAll(resp.Body) bytes, err := ioutil.ReadAll(resp.Body)
require.NoError(err) require.NoError(t, err)
require.Equal(cert, string(bytes)+"\n") require.Equal(t, cert, string(bytes)+"\n")
// Should be a valid CA cert // Should be a valid CA cert
parsed, err := connect.ParseCert(cert) parsed, err := connect.ParseCert(cert)
require.NoError(err) require.NoError(t, err)
require.True(parsed.IsCA) require.True(t, parsed.IsCA)
require.Len(parsed.URIs, 1) require.Len(t, parsed.URIs, 1)
require.Equal(fmt.Sprintf("spiffe://%s.consul", provider.clusterID), parsed.URIs[0].String()) require.Equal(t, fmt.Sprintf("spiffe://%s.consul", provider.clusterID), parsed.URIs[0].String())
// test that the root cert ttl as applied // test that the root cert ttl as applied
if tc.rootCaCreation { if tc.rootCaCreation {
rootCertTTL, err := time.ParseDuration(tc.expectedRootCertTTL) rootCertTTL, err := time.ParseDuration(tc.expectedRootCertTTL)
require.NoError(err) require.NoError(t, err)
expectedNotAfter := time.Now().Add(rootCertTTL).UTC() expectedNotAfter := time.Now().Add(rootCertTTL).UTC()
require.WithinDuration(expectedNotAfter, parsed.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured") require.WithinDuration(t, expectedNotAfter, parsed.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
} }
} }
} }
@ -313,7 +309,6 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
for _, tc := range KeyTestCases { for _, tc := range KeyTestCases {
tc := tc tc := tc
t.Run(tc.Desc, func(t *testing.T) { t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
provider, testVault := testVaultProviderWithConfig(t, true, map[string]interface{}{ provider, testVault := testVaultProviderWithConfig(t, true, map[string]interface{}{
"LeafCertTTL": "1h", "LeafCertTTL": "1h",
"PrivateKeyType": tc.KeyType, "PrivateKeyType": tc.KeyType,
@ -329,11 +324,11 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
} }
rootPEM, err := provider.ActiveRoot() rootPEM, err := provider.ActiveRoot()
require.NoError(err) require.NoError(t, err)
assertCorrectKeyType(t, tc.KeyType, rootPEM) assertCorrectKeyType(t, tc.KeyType, rootPEM)
intPEM, err := provider.ActiveIntermediate() intPEM, err := provider.ActiveIntermediate()
require.NoError(err) require.NoError(t, err)
assertCorrectKeyType(t, tc.KeyType, intPEM) assertCorrectKeyType(t, tc.KeyType, intPEM)
// Generate a leaf cert for the service. // Generate a leaf cert for the service.
@ -342,23 +337,23 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService) raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw) csr, err := connect.ParseCSR(raw)
require.NoError(err) require.NoError(t, err)
cert, err := provider.Sign(csr) cert, err := provider.Sign(csr)
require.NoError(err) require.NoError(t, err)
parsed, err := connect.ParseCert(cert) parsed, err := connect.ParseCert(cert)
require.NoError(err) require.NoError(t, err)
require.Equal(parsed.URIs[0], spiffeService.URI()) require.Equal(t, parsed.URIs[0], spiffeService.URI())
firstSerial = parsed.SerialNumber.Uint64() firstSerial = parsed.SerialNumber.Uint64()
// Ensure the cert is valid now and expires within the correct limit. // Ensure the cert is valid now and expires within the correct limit.
now := time.Now() now := time.Now()
require.True(parsed.NotAfter.Sub(now) < time.Hour) require.True(t, parsed.NotAfter.Sub(now) < time.Hour)
require.True(parsed.NotBefore.Before(now)) require.True(t, parsed.NotBefore.Before(now))
// Make sure we can validate the cert as expected. // Make sure we can validate the cert as expected.
require.NoError(connect.ValidateLeaf(rootPEM, cert, []string{intPEM})) require.NoError(t, connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
requireTrailingNewline(t, cert) requireTrailingNewline(t, cert)
} }
@ -369,22 +364,22 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService) raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw) csr, err := connect.ParseCSR(raw)
require.NoError(err) require.NoError(t, err)
cert, err := provider.Sign(csr) cert, err := provider.Sign(csr)
require.NoError(err) require.NoError(t, err)
parsed, err := connect.ParseCert(cert) parsed, err := connect.ParseCert(cert)
require.NoError(err) require.NoError(t, err)
require.Equal(parsed.URIs[0], spiffeService.URI()) require.Equal(t, parsed.URIs[0], spiffeService.URI())
require.NotEqual(firstSerial, parsed.SerialNumber.Uint64()) require.NotEqual(t, firstSerial, parsed.SerialNumber.Uint64())
// Ensure the cert is valid now and expires within the correct limit. // Ensure the cert is valid now and expires within the correct limit.
require.True(time.Until(parsed.NotAfter) < time.Hour) require.True(t, time.Until(parsed.NotAfter) < time.Hour)
require.True(parsed.NotBefore.Before(time.Now())) require.True(t, parsed.NotBefore.Before(time.Now()))
// Make sure we can validate the cert as expected. // Make sure we can validate the cert as expected.
require.NoError(connect.ValidateLeaf(rootPEM, cert, []string{intPEM})) require.NoError(t, connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
} }
}) })
} }
@ -399,7 +394,6 @@ func TestVaultCAProvider_CrossSignCA(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.Desc, func(t *testing.T) { t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
if tc.SigningKeyType != tc.CSRKeyType { if tc.SigningKeyType != tc.CSRKeyType {
// See https://github.com/hashicorp/vault/issues/7709 // See https://github.com/hashicorp/vault/issues/7709
@ -414,11 +408,11 @@ func TestVaultCAProvider_CrossSignCA(t *testing.T) {
{ {
rootPEM, err := provider1.ActiveRoot() rootPEM, err := provider1.ActiveRoot()
require.NoError(err) require.NoError(t, err)
assertCorrectKeyType(t, tc.SigningKeyType, rootPEM) assertCorrectKeyType(t, tc.SigningKeyType, rootPEM)
intPEM, err := provider1.ActiveIntermediate() intPEM, err := provider1.ActiveIntermediate()
require.NoError(err) require.NoError(t, err)
assertCorrectKeyType(t, tc.SigningKeyType, intPEM) assertCorrectKeyType(t, tc.SigningKeyType, intPEM)
} }
@ -431,11 +425,11 @@ func TestVaultCAProvider_CrossSignCA(t *testing.T) {
{ {
rootPEM, err := provider2.ActiveRoot() rootPEM, err := provider2.ActiveRoot()
require.NoError(err) require.NoError(t, err)
assertCorrectKeyType(t, tc.CSRKeyType, rootPEM) assertCorrectKeyType(t, tc.CSRKeyType, rootPEM)
intPEM, err := provider2.ActiveIntermediate() intPEM, err := provider2.ActiveIntermediate()
require.NoError(err) require.NoError(t, err)
assertCorrectKeyType(t, tc.CSRKeyType, intPEM) assertCorrectKeyType(t, tc.CSRKeyType, intPEM)
} }

View File

@ -48,32 +48,30 @@ func makeConfig(kc KeyConfig) structs.CommonCAProviderConfig {
} }
func testGenerateRSAKey(t *testing.T, bits int) { func testGenerateRSAKey(t *testing.T, bits int) {
require := require.New(t)
_, rsaBlock, err := GeneratePrivateKeyWithConfig("rsa", bits) _, rsaBlock, err := GeneratePrivateKeyWithConfig("rsa", bits)
require.NoError(err) require.NoError(t, err)
require.Contains(rsaBlock, "RSA PRIVATE KEY") require.Contains(t, rsaBlock, "RSA PRIVATE KEY")
rsaBytes, _ := pem.Decode([]byte(rsaBlock)) rsaBytes, _ := pem.Decode([]byte(rsaBlock))
require.NotNil(rsaBytes) require.NotNil(t, rsaBytes)
rsaKey, err := x509.ParsePKCS1PrivateKey(rsaBytes.Bytes) rsaKey, err := x509.ParsePKCS1PrivateKey(rsaBytes.Bytes)
require.NoError(err) require.NoError(t, err)
require.NoError(rsaKey.Validate()) require.NoError(t, rsaKey.Validate())
require.Equal(bits/8, rsaKey.Size()) // note: returned size is in bytes. 2048/8==256 require.Equal(t, bits/8, rsaKey.Size()) // note: returned size is in bytes. 2048/8==256
} }
func testGenerateECDSAKey(t *testing.T, bits int) { func testGenerateECDSAKey(t *testing.T, bits int) {
require := require.New(t)
_, pemBlock, err := GeneratePrivateKeyWithConfig("ec", bits) _, pemBlock, err := GeneratePrivateKeyWithConfig("ec", bits)
require.NoError(err) require.NoError(t, err)
require.Contains(pemBlock, "EC PRIVATE KEY") require.Contains(t, pemBlock, "EC PRIVATE KEY")
block, _ := pem.Decode([]byte(pemBlock)) block, _ := pem.Decode([]byte(pemBlock))
require.NotNil(block) require.NotNil(t, block)
pk, err := x509.ParseECPrivateKey(block.Bytes) pk, err := x509.ParseECPrivateKey(block.Bytes)
require.NoError(err) require.NoError(t, err)
require.Equal(bits, pk.Curve.Params().BitSize) require.Equal(t, bits, pk.Curve.Params().BitSize)
} }
// Tests to make sure we are able to generate every type of private key supported by the x509 lib. // Tests to make sure we are able to generate every type of private key supported by the x509 lib.
@ -132,7 +130,6 @@ func TestSignatureMismatches(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
for _, p1 := range goodParams { for _, p1 := range goodParams {
for _, p2 := range goodParams { for _, p2 := range goodParams {
if p1 == p2 { if p1 == p2 {
@ -140,14 +137,14 @@ func TestSignatureMismatches(t *testing.T) {
} }
t.Run(fmt.Sprintf("TestMismatches-%s%d-%s%d", p1.keyType, p1.keyBits, p2.keyType, p2.keyBits), func(t *testing.T) { t.Run(fmt.Sprintf("TestMismatches-%s%d-%s%d", p1.keyType, p1.keyBits, p2.keyType, p2.keyBits), func(t *testing.T) {
ca := TestCAWithKeyType(t, nil, p1.keyType, p1.keyBits) ca := TestCAWithKeyType(t, nil, p1.keyType, p1.keyBits)
require.Equal(p1.keyType, ca.PrivateKeyType) require.Equal(t, p1.keyType, ca.PrivateKeyType)
require.Equal(p1.keyBits, ca.PrivateKeyBits) require.Equal(t, p1.keyBits, ca.PrivateKeyBits)
certPEM, keyPEM, err := testLeaf(t, "foobar.service.consul", "default", ca, p2.keyType, p2.keyBits) certPEM, keyPEM, err := testLeaf(t, "foobar.service.consul", "default", ca, p2.keyType, p2.keyBits)
require.NoError(err) require.NoError(t, err)
_, err = ParseCert(certPEM) _, err = ParseCert(certPEM)
require.NoError(err) require.NoError(t, err)
_, err = ParseSigner(keyPEM) _, err = ParseSigner(keyPEM)
require.NoError(err) require.NoError(t, err)
}) })
} }
} }

View File

@ -29,20 +29,18 @@ func skipIfMissingOpenSSL(t *testing.T) {
func testCAAndLeaf(t *testing.T, keyType string, keyBits int) { func testCAAndLeaf(t *testing.T, keyType string, keyBits int) {
skipIfMissingOpenSSL(t) skipIfMissingOpenSSL(t)
require := require.New(t)
// Create the certs // Create the certs
ca := TestCAWithKeyType(t, nil, keyType, keyBits) ca := TestCAWithKeyType(t, nil, keyType, keyBits)
leaf, _ := TestLeaf(t, "web", ca) leaf, _ := TestLeaf(t, "web", ca)
// Create a temporary directory for storing the certs // Create a temporary directory for storing the certs
td, err := ioutil.TempDir("", "consul") td, err := ioutil.TempDir("", "consul")
require.NoError(err) require.NoError(t, err)
defer os.RemoveAll(td) defer os.RemoveAll(td)
// Write the cert // Write the cert
require.NoError(ioutil.WriteFile(filepath.Join(td, "ca.pem"), []byte(ca.RootCert), 0644)) require.NoError(t, ioutil.WriteFile(filepath.Join(td, "ca.pem"), []byte(ca.RootCert), 0644))
require.NoError(ioutil.WriteFile(filepath.Join(td, "leaf.pem"), []byte(leaf[:]), 0644)) require.NoError(t, ioutil.WriteFile(filepath.Join(td, "leaf.pem"), []byte(leaf[:]), 0644))
// Use OpenSSL to verify so we have an external, known-working process // Use OpenSSL to verify so we have an external, known-working process
// that can verify this outside of our own implementations. // that can verify this outside of our own implementations.
@ -54,15 +52,13 @@ func testCAAndLeaf(t *testing.T, keyType string, keyBits int) {
if ee, ok := err.(*exec.ExitError); ok { if ee, ok := err.(*exec.ExitError); ok {
t.Log("STDERR:", string(ee.Stderr)) t.Log("STDERR:", string(ee.Stderr))
} }
require.NoError(err) require.NoError(t, err)
} }
// Test cross-signing. // Test cross-signing.
func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) { func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) {
skipIfMissingOpenSSL(t) skipIfMissingOpenSSL(t)
assert := assert.New(t)
// Create the certs // Create the certs
ca1 := TestCAWithKeyType(t, nil, keyType, keyBits) ca1 := TestCAWithKeyType(t, nil, keyType, keyBits)
ca2 := TestCAWithKeyType(t, ca1, keyType, keyBits) ca2 := TestCAWithKeyType(t, ca1, keyType, keyBits)
@ -71,16 +67,16 @@ func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) {
// Create a temporary directory for storing the certs // Create a temporary directory for storing the certs
td, err := ioutil.TempDir("", "consul") td, err := ioutil.TempDir("", "consul")
assert.Nil(err) assert.Nil(t, err)
defer os.RemoveAll(td) defer os.RemoveAll(td)
// Write the cert // Write the cert
xcbundle := []byte(ca1.RootCert) xcbundle := []byte(ca1.RootCert)
xcbundle = append(xcbundle, '\n') xcbundle = append(xcbundle, '\n')
xcbundle = append(xcbundle, []byte(ca2.SigningCert)...) xcbundle = append(xcbundle, []byte(ca2.SigningCert)...)
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), xcbundle, 0644)) assert.Nil(t, ioutil.WriteFile(filepath.Join(td, "ca.pem"), xcbundle, 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf1.pem"), []byte(leaf1), 0644)) assert.Nil(t, ioutil.WriteFile(filepath.Join(td, "leaf1.pem"), []byte(leaf1), 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf2.pem"), []byte(leaf2), 0644)) assert.Nil(t, ioutil.WriteFile(filepath.Join(td, "leaf2.pem"), []byte(leaf2), 0644))
// OpenSSL verify the cross-signed leaf (leaf2) // OpenSSL verify the cross-signed leaf (leaf2)
{ {
@ -89,7 +85,7 @@ func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) {
cmd.Dir = td cmd.Dir = td
output, err := cmd.Output() output, err := cmd.Output()
t.Log(string(output)) t.Log(string(output))
assert.Nil(err) assert.Nil(t, err)
} }
// OpenSSL verify the old leaf (leaf1) // OpenSSL verify the old leaf (leaf1)
@ -99,7 +95,7 @@ func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) {
cmd.Dir = td cmd.Dir = td
output, err := cmd.Output() output, err := cmd.Output()
t.Log(string(output)) t.Log(string(output))
assert.Nil(err) assert.Nil(t, err)
} }
} }

View File

@ -43,7 +43,6 @@ func TestConnectCARoots_list(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -56,16 +55,16 @@ func TestConnectCARoots_list(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil) req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCARoots(resp, req) obj, err := a.srv.ConnectCARoots(resp, req)
assert.NoError(err) assert.NoError(t, err)
value := obj.(structs.IndexedCARoots) value := obj.(structs.IndexedCARoots)
assert.Equal(value.ActiveRootID, ca2.ID) assert.Equal(t, value.ActiveRootID, ca2.ID)
assert.Len(value.Roots, 2) assert.Len(t, value.Roots, 2)
// We should never have the secret information // We should never have the secret information
for _, r := range value.Roots { for _, r := range value.Roots {
assert.Equal("", r.SigningCert) assert.Equal(t, "", r.SigningCert)
assert.Equal("", r.SigningKey) assert.Equal(t, "", r.SigningKey)
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -77,7 +77,6 @@ func TestAutopilot_CleanupDeadServer(t *testing.T) {
retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s, 5)) }) retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s, 5)) })
} }
require := require.New(t)
testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc1")
leaderIndex := -1 leaderIndex := -1
for i, s := range servers { for i, s := range servers {
@ -86,7 +85,7 @@ func TestAutopilot_CleanupDeadServer(t *testing.T) {
break break
} }
} }
require.NotEqual(leaderIndex, -1) require.NotEqual(t, leaderIndex, -1)
// Shutdown two non-leader servers // Shutdown two non-leader servers
killed := make(map[string]struct{}) killed := make(map[string]struct{})

View File

@ -388,7 +388,6 @@ func TestCatalog_Register_ConnectProxy(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -399,7 +398,7 @@ func TestCatalog_Register_ConnectProxy(t *testing.T) {
// Register // Register
var out struct{} var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// List // List
req := structs.ServiceSpecificRequest{ req := structs.ServiceSpecificRequest{
@ -407,11 +406,11 @@ func TestCatalog_Register_ConnectProxy(t *testing.T) {
ServiceName: args.Service.Service, ServiceName: args.Service.Service,
} }
var resp structs.IndexedServiceNodes var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1) assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0] v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind) assert.Equal(t, structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName) assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
} }
// Test an invalid ConnectProxy. We don't need to exhaustively test because // Test an invalid ConnectProxy. We don't need to exhaustively test because
@ -423,7 +422,6 @@ func TestCatalog_Register_ConnectProxy_invalid(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -436,8 +434,8 @@ func TestCatalog_Register_ConnectProxy_invalid(t *testing.T) {
// Register // Register
var out struct{} var out struct{}
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out) err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.NotNil(err) assert.NotNil(t, err)
assert.Contains(err.Error(), "DestinationServiceName") assert.Contains(t, err.Error(), "DestinationServiceName")
} }
// Test that write is required for the proxy destination to register a proxy. // Test that write is required for the proxy destination to register a proxy.
@ -448,7 +446,6 @@ func TestCatalog_Register_ConnectProxy_ACLDestinationServiceName(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -479,7 +476,7 @@ node "foo" {
args.WriteRequest.Token = token args.WriteRequest.Token = token
var out struct{} var out struct{}
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out) err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.True(acl.IsErrPermissionDenied(err)) assert.True(t, acl.IsErrPermissionDenied(err))
// Register should fail with the right destination but wrong name // Register should fail with the right destination but wrong name
args = structs.TestRegisterRequestProxy(t) args = structs.TestRegisterRequestProxy(t)
@ -487,14 +484,14 @@ node "foo" {
args.Service.Proxy.DestinationServiceName = "foo" args.Service.Proxy.DestinationServiceName = "foo"
args.WriteRequest.Token = token args.WriteRequest.Token = token
err = msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out) err = msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.True(acl.IsErrPermissionDenied(err)) assert.True(t, acl.IsErrPermissionDenied(err))
// Register should work with the right destination // Register should work with the right destination
args = structs.TestRegisterRequestProxy(t) args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo" args.Service.Service = "foo"
args.Service.Proxy.DestinationServiceName = "foo" args.Service.Proxy.DestinationServiceName = "foo"
args.WriteRequest.Token = token args.WriteRequest.Token = token
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
} }
func TestCatalog_Register_ConnectNative(t *testing.T) { func TestCatalog_Register_ConnectNative(t *testing.T) {
@ -504,7 +501,6 @@ func TestCatalog_Register_ConnectNative(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -516,7 +512,7 @@ func TestCatalog_Register_ConnectNative(t *testing.T) {
// Register // Register
var out struct{} var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// List // List
req := structs.ServiceSpecificRequest{ req := structs.ServiceSpecificRequest{
@ -524,11 +520,11 @@ func TestCatalog_Register_ConnectNative(t *testing.T) {
ServiceName: args.Service.Service, ServiceName: args.Service.Service,
} }
var resp structs.IndexedServiceNodes var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1) assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0] v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindTypical, v.ServiceKind) assert.Equal(t, structs.ServiceKindTypical, v.ServiceKind)
assert.True(v.ServiceConnect.Native) assert.True(t, v.ServiceConnect.Native)
} }
func TestCatalog_Deregister(t *testing.T) { func TestCatalog_Deregister(t *testing.T) {
@ -2149,7 +2145,6 @@ func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -2161,7 +2156,7 @@ func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
// Register the service // Register the service
args := structs.TestRegisterRequestProxy(t) args := structs.TestRegisterRequestProxy(t)
var out struct{} var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List // List
req := structs.ServiceSpecificRequest{ req := structs.ServiceSpecificRequest{
@ -2170,11 +2165,11 @@ func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
TagFilter: false, TagFilter: false,
} }
var resp structs.IndexedServiceNodes var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1) assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0] v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind) assert.Equal(t, structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName) assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
} }
func TestCatalog_ServiceNodes_Gateway(t *testing.T) { func TestCatalog_ServiceNodes_Gateway(t *testing.T) {
@ -2304,7 +2299,6 @@ func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -2316,7 +2310,7 @@ func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
// Register the proxy service // Register the proxy service
args := structs.TestRegisterRequestProxy(t) args := structs.TestRegisterRequestProxy(t)
var out struct{} var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// Register the service // Register the service
{ {
@ -2324,7 +2318,7 @@ func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
args := structs.TestRegisterRequest(t) args := structs.TestRegisterRequest(t)
args.Service.Service = dst args.Service.Service = dst
var out struct{} var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
} }
// List // List
@ -2334,22 +2328,22 @@ func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
ServiceName: args.Service.Proxy.DestinationServiceName, ServiceName: args.Service.Proxy.DestinationServiceName,
} }
var resp structs.IndexedServiceNodes var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1) assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0] v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind) assert.Equal(t, structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName) assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
// List by non-Connect // List by non-Connect
req = structs.ServiceSpecificRequest{ req = structs.ServiceSpecificRequest{
Datacenter: "dc1", Datacenter: "dc1",
ServiceName: args.Service.Proxy.DestinationServiceName, ServiceName: args.Service.Proxy.DestinationServiceName,
} }
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1) assert.Len(t, resp.ServiceNodes, 1)
v = resp.ServiceNodes[0] v = resp.ServiceNodes[0]
assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceName) assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.ServiceName)
assert.Equal("", v.ServiceProxy.DestinationServiceName) assert.Equal(t, "", v.ServiceProxy.DestinationServiceName)
} }
// Test that calling ServiceNodes with Connect: true will return // Test that calling ServiceNodes with Connect: true will return
@ -2361,7 +2355,6 @@ func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -2374,7 +2367,7 @@ func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
args := structs.TestRegisterRequest(t) args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true args.Service.Connect.Native = true
var out struct{} var out struct{}
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List // List
req := structs.ServiceSpecificRequest{ req := structs.ServiceSpecificRequest{
@ -2383,20 +2376,20 @@ func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
ServiceName: args.Service.Service, ServiceName: args.Service.Service,
} }
var resp structs.IndexedServiceNodes var resp structs.IndexedServiceNodes
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(resp.ServiceNodes, 1) require.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0] v := resp.ServiceNodes[0]
require.Equal(args.Service.Service, v.ServiceName) require.Equal(t, args.Service.Service, v.ServiceName)
// List by non-Connect // List by non-Connect
req = structs.ServiceSpecificRequest{ req = structs.ServiceSpecificRequest{
Datacenter: "dc1", Datacenter: "dc1",
ServiceName: args.Service.Service, ServiceName: args.Service.Service,
} }
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(resp.ServiceNodes, 1) require.Len(t, resp.ServiceNodes, 1)
v = resp.ServiceNodes[0] v = resp.ServiceNodes[0]
require.Equal(args.Service.Service, v.ServiceName) require.Equal(t, args.Service.Service, v.ServiceName)
} }
func TestCatalog_ListServiceNodes_ConnectProxy_ACL(t *testing.T) { func TestCatalog_ListServiceNodes_ConnectProxy_ACL(t *testing.T) {
@ -2491,7 +2484,6 @@ func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -2504,7 +2496,7 @@ func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
args := structs.TestRegisterRequest(t) args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true args.Service.Connect.Native = true
var out struct{} var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List // List
req := structs.ServiceSpecificRequest{ req := structs.ServiceSpecificRequest{
@ -2513,10 +2505,10 @@ func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
TagFilter: false, TagFilter: false,
} }
var resp structs.IndexedServiceNodes var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1) assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0] v := resp.ServiceNodes[0]
assert.Equal(args.Service.Connect.Native, v.ServiceConnect.Native) assert.Equal(t, args.Service.Connect.Native, v.ServiceConnect.Native)
} }
func TestCatalog_NodeServices(t *testing.T) { func TestCatalog_NodeServices(t *testing.T) {
@ -2581,7 +2573,6 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -2593,7 +2584,7 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
// Register the service // Register the service
args := structs.TestRegisterRequestProxy(t) args := structs.TestRegisterRequestProxy(t)
var out struct{} var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List // List
req := structs.NodeSpecificRequest{ req := structs.NodeSpecificRequest{
@ -2601,12 +2592,12 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
Node: args.Node, Node: args.Node,
} }
var resp structs.IndexedNodeServices var resp structs.IndexedNodeServices
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Len(resp.NodeServices.Services, 1) assert.Len(t, resp.NodeServices.Services, 1)
v := resp.NodeServices.Services[args.Service.Service] v := resp.NodeServices.Services[args.Service.Service]
assert.Equal(structs.ServiceKindConnectProxy, v.Kind) assert.Equal(t, structs.ServiceKindConnectProxy, v.Kind)
assert.Equal(args.Service.Proxy.DestinationServiceName, v.Proxy.DestinationServiceName) assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.Proxy.DestinationServiceName)
} }
func TestCatalog_NodeServices_ConnectNative(t *testing.T) { func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
@ -2616,7 +2607,6 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -2628,7 +2618,7 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
// Register the service // Register the service
args := structs.TestRegisterRequest(t) args := structs.TestRegisterRequest(t)
var out struct{} var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List // List
req := structs.NodeSpecificRequest{ req := structs.NodeSpecificRequest{
@ -2636,11 +2626,11 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
Node: args.Node, Node: args.Node,
} }
var resp structs.IndexedNodeServices var resp structs.IndexedNodeServices
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Len(resp.NodeServices.Services, 1) assert.Len(t, resp.NodeServices.Services, 1)
v := resp.NodeServices.Services[args.Service.Service] v := resp.NodeServices.Services[args.Service.Service]
assert.Equal(args.Service.Connect.Native, v.Connect.Native) assert.Equal(t, args.Service.Connect.Native, v.Connect.Native)
} }
// Used to check for a regression against a known bug // Used to check for a regression against a known bug
@ -2883,27 +2873,25 @@ func TestCatalog_NodeServices_ACL(t *testing.T) {
} }
t.Run("deny", func(t *testing.T) { t.Run("deny", func(t *testing.T) {
require := require.New(t)
args.Token = token("deny") args.Token = token("deny")
var reply structs.IndexedNodeServices var reply structs.IndexedNodeServices
err := msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &args, &reply) err := msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &args, &reply)
require.NoError(err) require.NoError(t, err)
require.Nil(reply.NodeServices) require.Nil(t, reply.NodeServices)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
}) })
t.Run("allow", func(t *testing.T) { t.Run("allow", func(t *testing.T) {
require := require.New(t)
args.Token = token("read") args.Token = token("read")
var reply structs.IndexedNodeServices var reply structs.IndexedNodeServices
err := msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &args, &reply) err := msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &args, &reply)
require.NoError(err) require.NoError(t, err)
require.NotNil(reply.NodeServices) require.NotNil(t, reply.NodeServices)
require.False(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") require.False(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
}) })
} }

View File

@ -150,8 +150,6 @@ func TestConfigEntry_Apply_ACLDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -191,16 +189,16 @@ operator = "write"
Name: "foo", Name: "foo",
} }
err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &args, &out) err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &args, &out)
require.NoError(err) require.NoError(t, err)
state := s1.fsm.State() state := s1.fsm.State()
_, entry, err := state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil) _, entry, err := state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err) require.NoError(t, err)
serviceConf, ok := entry.(*structs.ServiceConfigEntry) serviceConf, ok := entry.(*structs.ServiceConfigEntry)
require.True(ok) require.True(t, ok)
require.Equal("foo", serviceConf.Name) require.Equal(t, "foo", serviceConf.Name)
require.Equal(structs.ServiceDefaults, serviceConf.Kind) require.Equal(t, structs.ServiceDefaults, serviceConf.Kind)
// Try to update the global proxy args with the anonymous token - this should fail. // Try to update the global proxy args with the anonymous token - this should fail.
proxyArgs := structs.ConfigEntryRequest{ proxyArgs := structs.ConfigEntryRequest{
@ -219,7 +217,7 @@ operator = "write"
// Now with the privileged token. // Now with the privileged token.
proxyArgs.WriteRequest.Token = id proxyArgs.WriteRequest.Token = id
err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &proxyArgs, &out) err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &proxyArgs, &out)
require.NoError(err) require.NoError(t, err)
} }
func TestConfigEntry_Get(t *testing.T) { func TestConfigEntry_Get(t *testing.T) {
@ -229,8 +227,6 @@ func TestConfigEntry_Get(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -243,7 +239,7 @@ func TestConfigEntry_Get(t *testing.T) {
Name: "foo", Name: "foo",
} }
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, entry)) require.NoError(t, state.EnsureConfigEntry(1, entry))
args := structs.ConfigEntryQuery{ args := structs.ConfigEntryQuery{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
@ -251,12 +247,12 @@ func TestConfigEntry_Get(t *testing.T) {
Datacenter: s1.config.Datacenter, Datacenter: s1.config.Datacenter,
} }
var out structs.ConfigEntryResponse var out structs.ConfigEntryResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out))
serviceConf, ok := out.Entry.(*structs.ServiceConfigEntry) serviceConf, ok := out.Entry.(*structs.ServiceConfigEntry)
require.True(ok) require.True(t, ok)
require.Equal("foo", serviceConf.Name) require.Equal(t, "foo", serviceConf.Name)
require.Equal(structs.ServiceDefaults, serviceConf.Kind) require.Equal(t, structs.ServiceDefaults, serviceConf.Kind)
} }
func TestConfigEntry_Get_ACLDeny(t *testing.T) { func TestConfigEntry_Get_ACLDeny(t *testing.T) {
@ -266,8 +262,6 @@ func TestConfigEntry_Get_ACLDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -290,11 +284,11 @@ operator = "read"
// Create some dummy service/proxy configs to be looked up. // Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
})) }))
@ -314,12 +308,12 @@ operator = "read"
// The "foo" service should work. // The "foo" service should work.
args.Name = "foo" args.Name = "foo"
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out))
serviceConf, ok := out.Entry.(*structs.ServiceConfigEntry) serviceConf, ok := out.Entry.(*structs.ServiceConfigEntry)
require.True(ok) require.True(t, ok)
require.Equal("foo", serviceConf.Name) require.Equal(t, "foo", serviceConf.Name)
require.Equal(structs.ServiceDefaults, serviceConf.Kind) require.Equal(t, structs.ServiceDefaults, serviceConf.Kind)
} }
func TestConfigEntry_List(t *testing.T) { func TestConfigEntry_List(t *testing.T) {
@ -329,8 +323,6 @@ func TestConfigEntry_List(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -351,19 +343,19 @@ func TestConfigEntry_List(t *testing.T) {
}, },
}, },
} }
require.NoError(state.EnsureConfigEntry(1, expected.Entries[0])) require.NoError(t, state.EnsureConfigEntry(1, expected.Entries[0]))
require.NoError(state.EnsureConfigEntry(2, expected.Entries[1])) require.NoError(t, state.EnsureConfigEntry(2, expected.Entries[1]))
args := structs.ConfigEntryQuery{ args := structs.ConfigEntryQuery{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Datacenter: "dc1", Datacenter: "dc1",
} }
var out structs.IndexedConfigEntries var out structs.IndexedConfigEntries
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out))
expected.Kind = structs.ServiceDefaults expected.Kind = structs.ServiceDefaults
expected.QueryMeta = out.QueryMeta expected.QueryMeta = out.QueryMeta
require.Equal(expected, out) require.Equal(t, expected, out)
} }
func TestConfigEntry_ListAll(t *testing.T) { func TestConfigEntry_ListAll(t *testing.T) {
@ -466,8 +458,6 @@ func TestConfigEntry_List_ACLDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -490,15 +480,15 @@ operator = "read"
// Create some dummy service/proxy configs to be looked up. // Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
})) }))
require.NoError(state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "db", Name: "db",
})) }))
@ -511,26 +501,26 @@ operator = "read"
} }
var out structs.IndexedConfigEntries var out structs.IndexedConfigEntries
err := msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out) err := msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out)
require.NoError(err) require.NoError(t, err)
serviceConf, ok := out.Entries[0].(*structs.ServiceConfigEntry) serviceConf, ok := out.Entries[0].(*structs.ServiceConfigEntry)
require.Len(out.Entries, 1) require.Len(t, out.Entries, 1)
require.True(ok) require.True(t, ok)
require.Equal("foo", serviceConf.Name) require.Equal(t, "foo", serviceConf.Name)
require.Equal(structs.ServiceDefaults, serviceConf.Kind) require.Equal(t, structs.ServiceDefaults, serviceConf.Kind)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// Get the global proxy config. // Get the global proxy config.
args.Kind = structs.ProxyDefaults args.Kind = structs.ProxyDefaults
err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out) err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out)
require.NoError(err) require.NoError(t, err)
proxyConf, ok := out.Entries[0].(*structs.ProxyConfigEntry) proxyConf, ok := out.Entries[0].(*structs.ProxyConfigEntry)
require.Len(out.Entries, 1) require.Len(t, out.Entries, 1)
require.True(ok) require.True(t, ok)
require.Equal(structs.ProxyConfigGlobal, proxyConf.Name) require.Equal(t, structs.ProxyConfigGlobal, proxyConf.Name)
require.Equal(structs.ProxyDefaults, proxyConf.Kind) require.Equal(t, structs.ProxyDefaults, proxyConf.Kind)
require.False(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") require.False(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
} }
func TestConfigEntry_ListAll_ACLDeny(t *testing.T) { func TestConfigEntry_ListAll_ACLDeny(t *testing.T) {
@ -540,8 +530,6 @@ func TestConfigEntry_ListAll_ACLDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -564,15 +552,15 @@ operator = "read"
// Create some dummy service/proxy configs to be looked up. // Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
})) }))
require.NoError(state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "db", Name: "db",
})) }))
@ -585,8 +573,8 @@ operator = "read"
} }
var out structs.IndexedGenericConfigEntries var out structs.IndexedGenericConfigEntries
err := msgpackrpc.CallWithCodec(codec, "ConfigEntry.ListAll", &args, &out) err := msgpackrpc.CallWithCodec(codec, "ConfigEntry.ListAll", &args, &out)
require.NoError(err) require.NoError(t, err)
require.Len(out.Entries, 2) require.Len(t, out.Entries, 2)
svcIndex := 0 svcIndex := 0
proxyIndex := 1 proxyIndex := 1
if out.Entries[0].GetKind() == structs.ProxyDefaults { if out.Entries[0].GetKind() == structs.ProxyDefaults {
@ -595,15 +583,15 @@ operator = "read"
} }
svcConf, ok := out.Entries[svcIndex].(*structs.ServiceConfigEntry) svcConf, ok := out.Entries[svcIndex].(*structs.ServiceConfigEntry)
require.True(ok) require.True(t, ok)
proxyConf, ok := out.Entries[proxyIndex].(*structs.ProxyConfigEntry) proxyConf, ok := out.Entries[proxyIndex].(*structs.ProxyConfigEntry)
require.True(ok) require.True(t, ok)
require.Equal("foo", svcConf.Name) require.Equal(t, "foo", svcConf.Name)
require.Equal(structs.ServiceDefaults, svcConf.Kind) require.Equal(t, structs.ServiceDefaults, svcConf.Kind)
require.Equal(structs.ProxyConfigGlobal, proxyConf.Name) require.Equal(t, structs.ProxyConfigGlobal, proxyConf.Name)
require.Equal(structs.ProxyDefaults, proxyConf.Kind) require.Equal(t, structs.ProxyDefaults, proxyConf.Kind)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
} }
func TestConfigEntry_Delete(t *testing.T) { func TestConfigEntry_Delete(t *testing.T) {
@ -686,8 +674,6 @@ func TestConfigEntry_DeleteCAS(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
dir, s := testServer(t) dir, s := testServer(t)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
defer s.Shutdown() defer s.Shutdown()
@ -703,11 +689,11 @@ func TestConfigEntry_DeleteCAS(t *testing.T) {
Name: "foo", Name: "foo",
} }
state := s.fsm.State() state := s.fsm.State()
require.NoError(state.EnsureConfigEntry(1, entry)) require.NoError(t, state.EnsureConfigEntry(1, entry))
// Verify it's there. // Verify it's there.
_, existing, err := state.ConfigEntry(nil, entry.Kind, entry.Name, nil) _, existing, err := state.ConfigEntry(nil, entry.Kind, entry.Name, nil)
require.NoError(err) require.NoError(t, err)
// Send a delete CAS request with an invalid index. // Send a delete CAS request with an invalid index.
args := structs.ConfigEntryRequest{ args := structs.ConfigEntryRequest{
@ -718,24 +704,24 @@ func TestConfigEntry_DeleteCAS(t *testing.T) {
args.Entry.GetRaftIndex().ModifyIndex = existing.GetRaftIndex().ModifyIndex - 1 args.Entry.GetRaftIndex().ModifyIndex = existing.GetRaftIndex().ModifyIndex - 1
var rsp structs.ConfigEntryDeleteResponse var rsp structs.ConfigEntryDeleteResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &rsp)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &rsp))
require.False(rsp.Deleted) require.False(t, rsp.Deleted)
// Verify the entry was not deleted. // Verify the entry was not deleted.
_, existing, err = s.fsm.State().ConfigEntry(nil, structs.ServiceDefaults, "foo", nil) _, existing, err = s.fsm.State().ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err) require.NoError(t, err)
require.NotNil(existing) require.NotNil(t, existing)
// Restore the valid index and try again. // Restore the valid index and try again.
args.Entry.GetRaftIndex().ModifyIndex = existing.GetRaftIndex().ModifyIndex args.Entry.GetRaftIndex().ModifyIndex = existing.GetRaftIndex().ModifyIndex
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &rsp)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &rsp))
require.True(rsp.Deleted) require.True(t, rsp.Deleted)
// Verify the entry was deleted. // Verify the entry was deleted.
_, existing, err = s.fsm.State().ConfigEntry(nil, structs.ServiceDefaults, "foo", nil) _, existing, err = s.fsm.State().ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err) require.NoError(t, err)
require.Nil(existing) require.Nil(t, existing)
} }
func TestConfigEntry_Delete_ACLDeny(t *testing.T) { func TestConfigEntry_Delete_ACLDeny(t *testing.T) {
@ -745,8 +731,6 @@ func TestConfigEntry_Delete_ACLDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -769,11 +753,11 @@ operator = "write"
// Create some dummy service/proxy configs to be looked up. // Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
})) }))
@ -796,12 +780,12 @@ operator = "write"
args.Entry = &structs.ServiceConfigEntry{ args.Entry = &structs.ServiceConfigEntry{
Name: "foo", Name: "foo",
} }
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &out))
// Verify the entry was deleted. // Verify the entry was deleted.
_, existing, err := state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil) _, existing, err := state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err) require.NoError(t, err)
require.Nil(existing) require.Nil(t, existing)
// Try to delete the global proxy config without a token. // Try to delete the global proxy config without a token.
args = structs.ConfigEntryRequest{ args = structs.ConfigEntryRequest{
@ -817,11 +801,11 @@ operator = "write"
// Now delete with a valid token. // Now delete with a valid token.
args.WriteRequest.Token = id args.WriteRequest.Token = id
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &out))
_, existing, err = state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil) _, existing, err = state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err) require.NoError(t, err)
require.Nil(existing) require.Nil(t, existing)
} }
func TestConfigEntry_ResolveServiceConfig(t *testing.T) { func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
@ -831,8 +815,6 @@ func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -841,19 +823,19 @@ func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
// Create a dummy proxy/service config in the state store to look up. // Create a dummy proxy/service config in the state store to look up.
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
Config: map[string]interface{}{ Config: map[string]interface{}{
"foo": 1, "foo": 1,
}, },
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
Protocol: "http", Protocol: "http",
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "bar", Name: "bar",
Protocol: "grpc", Protocol: "grpc",
@ -865,7 +847,7 @@ func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
Upstreams: []string{"bar", "baz"}, Upstreams: []string{"bar", "baz"},
} }
var out structs.ServiceConfigResponse var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
expected := structs.ServiceConfigResponse{ expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{ ProxyConfig: map[string]interface{}{
@ -880,14 +862,14 @@ func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
// Don't know what this is deterministically // Don't know what this is deterministically
QueryMeta: out.QueryMeta, QueryMeta: out.QueryMeta,
} }
require.Equal(expected, out) require.Equal(t, expected, out)
_, entry, err := s1.fsm.State().ConfigEntry(nil, structs.ProxyDefaults, structs.ProxyConfigGlobal, nil) _, entry, err := s1.fsm.State().ConfigEntry(nil, structs.ProxyDefaults, structs.ProxyConfigGlobal, nil)
require.NoError(err) require.NoError(t, err)
require.NotNil(entry) require.NotNil(t, entry)
proxyConf, ok := entry.(*structs.ProxyConfigEntry) proxyConf, ok := entry.(*structs.ProxyConfigEntry)
require.True(ok) require.True(t, ok)
require.Equal(map[string]interface{}{"foo": 1}, proxyConf.Config) require.Equal(t, map[string]interface{}{"foo": 1}, proxyConf.Config)
} }
func TestConfigEntry_ResolveServiceConfig_TransparentProxy(t *testing.T) { func TestConfigEntry_ResolveServiceConfig_TransparentProxy(t *testing.T) {
@ -1426,8 +1408,6 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -1443,19 +1423,19 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
// TestConfigEntry_ResolveServiceConfig_Upstreams_Blocking // TestConfigEntry_ResolveServiceConfig_Upstreams_Blocking
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
Config: map[string]interface{}{ Config: map[string]interface{}{
"global": 1, "global": 1,
}, },
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
Protocol: "grpc", Protocol: "grpc",
})) }))
require.NoError(state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "bar", Name: "bar",
Protocol: "http", Protocol: "http",
@ -1465,7 +1445,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
{ // Verify that we get the results of proxy-defaults and service-defaults for 'foo'. { // Verify that we get the results of proxy-defaults and service-defaults for 'foo'.
var out structs.ServiceConfigResponse var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
&structs.ServiceConfigRequest{ &structs.ServiceConfigRequest{
Name: "foo", Name: "foo",
Datacenter: "dc1", Datacenter: "dc1",
@ -1480,7 +1460,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
}, },
QueryMeta: out.QueryMeta, QueryMeta: out.QueryMeta,
} }
require.Equal(expected, out) require.Equal(t, expected, out)
index = out.Index index = out.Index
} }
@ -1490,7 +1470,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
start := time.Now() start := time.Now()
go func() { go func() {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
require.NoError(state.DeleteConfigEntry(index+1, require.NoError(t, state.DeleteConfigEntry(index+1,
structs.ServiceDefaults, structs.ServiceDefaults,
"foo", "foo",
nil, nil,
@ -1499,7 +1479,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
// Re-run the query // Re-run the query
var out structs.ServiceConfigResponse var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
&structs.ServiceConfigRequest{ &structs.ServiceConfigRequest{
Name: "foo", Name: "foo",
Datacenter: "dc1", Datacenter: "dc1",
@ -1512,10 +1492,10 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
)) ))
// Should block at least 100ms // Should block at least 100ms
require.True(time.Since(start) >= 100*time.Millisecond, "too fast") require.True(t, time.Since(start) >= 100*time.Millisecond, "too fast")
// Check the indexes // Check the indexes
require.Equal(out.Index, index+1) require.Equal(t, out.Index, index+1)
expected := structs.ServiceConfigResponse{ expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{ ProxyConfig: map[string]interface{}{
@ -1523,14 +1503,14 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
}, },
QueryMeta: out.QueryMeta, QueryMeta: out.QueryMeta,
} }
require.Equal(expected, out) require.Equal(t, expected, out)
index = out.Index index = out.Index
} }
{ // Verify that we get the results of proxy-defaults and service-defaults for 'bar'. { // Verify that we get the results of proxy-defaults and service-defaults for 'bar'.
var out structs.ServiceConfigResponse var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
&structs.ServiceConfigRequest{ &structs.ServiceConfigRequest{
Name: "bar", Name: "bar",
Datacenter: "dc1", Datacenter: "dc1",
@ -1545,7 +1525,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
}, },
QueryMeta: out.QueryMeta, QueryMeta: out.QueryMeta,
} }
require.Equal(expected, out) require.Equal(t, expected, out)
index = out.Index index = out.Index
} }
@ -1555,7 +1535,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
start := time.Now() start := time.Now()
go func() { go func() {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
require.NoError(state.DeleteConfigEntry(index+1, require.NoError(t, state.DeleteConfigEntry(index+1,
structs.ProxyDefaults, structs.ProxyDefaults,
structs.ProxyConfigGlobal, structs.ProxyConfigGlobal,
nil, nil,
@ -1564,7 +1544,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
// Re-run the query // Re-run the query
var out structs.ServiceConfigResponse var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
&structs.ServiceConfigRequest{ &structs.ServiceConfigRequest{
Name: "bar", Name: "bar",
Datacenter: "dc1", Datacenter: "dc1",
@ -1577,10 +1557,10 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
)) ))
// Should block at least 100ms // Should block at least 100ms
require.True(time.Since(start) >= 100*time.Millisecond, "too fast") require.True(t, time.Since(start) >= 100*time.Millisecond, "too fast")
// Check the indexes // Check the indexes
require.Equal(out.Index, index+1) require.Equal(t, out.Index, index+1)
expected := structs.ServiceConfigResponse{ expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{ ProxyConfig: map[string]interface{}{
@ -1588,7 +1568,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
}, },
QueryMeta: out.QueryMeta, QueryMeta: out.QueryMeta,
} }
require.Equal(expected, out) require.Equal(t, expected, out)
} }
} }
@ -1798,8 +1778,6 @@ func TestConfigEntry_ResolveServiceConfig_UpstreamProxyDefaultsProtocol(t *testi
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -1808,26 +1786,26 @@ func TestConfigEntry_ResolveServiceConfig_UpstreamProxyDefaultsProtocol(t *testi
// Create a dummy proxy/service config in the state store to look up. // Create a dummy proxy/service config in the state store to look up.
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
Config: map[string]interface{}{ Config: map[string]interface{}{
"protocol": "http", "protocol": "http",
}, },
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "bar", Name: "bar",
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "other", Name: "other",
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "alreadyprotocol", Name: "alreadyprotocol",
Protocol: "grpc", Protocol: "grpc",
@ -1839,7 +1817,7 @@ func TestConfigEntry_ResolveServiceConfig_UpstreamProxyDefaultsProtocol(t *testi
Upstreams: []string{"bar", "other", "alreadyprotocol", "dne"}, Upstreams: []string{"bar", "other", "alreadyprotocol", "dne"},
} }
var out structs.ServiceConfigResponse var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
expected := structs.ServiceConfigResponse{ expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{ ProxyConfig: map[string]interface{}{
@ -1862,7 +1840,7 @@ func TestConfigEntry_ResolveServiceConfig_UpstreamProxyDefaultsProtocol(t *testi
// Don't know what this is deterministically // Don't know what this is deterministically
QueryMeta: out.QueryMeta, QueryMeta: out.QueryMeta,
} }
require.Equal(expected, out) require.Equal(t, expected, out)
} }
func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstreams(t *testing.T) { func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstreams(t *testing.T) {
@ -1872,8 +1850,6 @@ func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstre
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -1882,7 +1858,7 @@ func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstre
// Create a dummy proxy/service config in the state store to look up. // Create a dummy proxy/service config in the state store to look up.
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
Config: map[string]interface{}{ Config: map[string]interface{}{
@ -1896,7 +1872,7 @@ func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstre
Upstreams: []string{"bar"}, Upstreams: []string{"bar"},
} }
var out structs.ServiceConfigResponse var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
expected := structs.ServiceConfigResponse{ expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{ ProxyConfig: map[string]interface{}{
@ -1910,7 +1886,7 @@ func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstre
// Don't know what this is deterministically // Don't know what this is deterministically
QueryMeta: out.QueryMeta, QueryMeta: out.QueryMeta,
} }
require.Equal(expected, out) require.Equal(t, expected, out)
} }
func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) { func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) {
@ -1920,8 +1896,6 @@ func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -1936,7 +1910,7 @@ func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) {
Upstreams: []string{"bar", "baz"}, Upstreams: []string{"bar", "baz"},
} }
var out structs.ServiceConfigResponse var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
expected := structs.ServiceConfigResponse{ expected := structs.ServiceConfigResponse{
ProxyConfig: nil, ProxyConfig: nil,
@ -1944,7 +1918,7 @@ func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) {
// Don't know what this is deterministically // Don't know what this is deterministically
QueryMeta: out.QueryMeta, QueryMeta: out.QueryMeta,
} }
require.Equal(expected, out) require.Equal(t, expected, out)
} }
func TestConfigEntry_ResolveServiceConfig_ACLDeny(t *testing.T) { func TestConfigEntry_ResolveServiceConfig_ACLDeny(t *testing.T) {
@ -1954,8 +1928,6 @@ func TestConfigEntry_ResolveServiceConfig_ACLDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -1978,15 +1950,15 @@ operator = "write"
// Create some dummy service/proxy configs to be looked up. // Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State() state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{ require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults, Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal, Name: structs.ProxyConfigGlobal,
})) }))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
})) }))
require.NoError(state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{ require.NoError(t, state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "db", Name: "db",
})) }))
@ -2005,7 +1977,7 @@ operator = "write"
// The "foo" service should work. // The "foo" service should work.
args.Name = "foo" args.Name = "foo"
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
} }

View File

@ -38,8 +38,6 @@ func TestConnectCARoots(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -54,29 +52,29 @@ func TestConnectCARoots(t *testing.T) {
ca2 := connect.TestCA(t, nil) ca2 := connect.TestCA(t, nil)
ca2.Active = false ca2.Active = false
idx, _, err := state.CARoots(nil) idx, _, err := state.CARoots(nil)
require.NoError(err) require.NoError(t, err)
ok, err := state.CARootSetCAS(idx, idx, []*structs.CARoot{ca1, ca2}) ok, err := state.CARootSetCAS(idx, idx, []*structs.CARoot{ca1, ca2})
assert.True(ok) assert.True(t, ok)
require.NoError(err) require.NoError(t, err)
_, caCfg, err := state.CAConfig(nil) _, caCfg, err := state.CAConfig(nil)
require.NoError(err) require.NoError(t, err)
// Request // Request
args := &structs.DCSpecificRequest{ args := &structs.DCSpecificRequest{
Datacenter: "dc1", Datacenter: "dc1",
} }
var reply structs.IndexedCARoots var reply structs.IndexedCARoots
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
// Verify // Verify
assert.Equal(ca1.ID, reply.ActiveRootID) assert.Equal(t, ca1.ID, reply.ActiveRootID)
assert.Len(reply.Roots, 2) assert.Len(t, reply.Roots, 2)
for _, r := range reply.Roots { for _, r := range reply.Roots {
// These must never be set, for security // These must never be set, for security
assert.Equal("", r.SigningCert) assert.Equal(t, "", r.SigningCert)
assert.Equal("", r.SigningKey) assert.Equal(t, "", r.SigningKey)
} }
assert.Equal(fmt.Sprintf("%s.consul", caCfg.ClusterID), reply.TrustDomain) assert.Equal(t, fmt.Sprintf("%s.consul", caCfg.ClusterID), reply.TrustDomain)
} }
func TestConnectCAConfig_GetSet(t *testing.T) { func TestConnectCAConfig_GetSet(t *testing.T) {
@ -86,7 +84,6 @@ func TestConnectCAConfig_GetSet(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -101,14 +98,14 @@ func TestConnectCAConfig_GetSet(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var reply structs.CAConfiguration var reply structs.CAConfiguration
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply)) assert.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config) actual, err := ca.ParseConsulCAConfig(reply.Config)
assert.NoError(err) assert.NoError(t, err)
expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config) expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(reply.Provider, s1.config.CAConfig.Provider) assert.Equal(t, reply.Provider, s1.config.CAConfig.Provider)
assert.Equal(actual, expected) assert.Equal(t, actual, expected)
} }
testState := map[string]string{"foo": "bar"} testState := map[string]string{"foo": "bar"}
@ -141,15 +138,15 @@ func TestConnectCAConfig_GetSet(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var reply structs.CAConfiguration var reply structs.CAConfiguration
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply)) assert.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config) actual, err := ca.ParseConsulCAConfig(reply.Config)
assert.NoError(err) assert.NoError(t, err)
expected, err := ca.ParseConsulCAConfig(newConfig.Config) expected, err := ca.ParseConsulCAConfig(newConfig.Config)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(reply.Provider, newConfig.Provider) assert.Equal(t, reply.Provider, newConfig.Provider)
assert.Equal(actual, expected) assert.Equal(t, actual, expected)
assert.Equal(testState, reply.State) assert.Equal(t, testState, reply.State)
} }
} }
@ -254,7 +251,6 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
// Setup a server with a built-in CA that as artificially disabled cross // Setup a server with a built-in CA that as artificially disabled cross
// signing. This is simpler than running tests with external CA dependencies. // signing. This is simpler than running tests with external CA dependencies.
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
@ -272,8 +268,8 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var rootList structs.IndexedCARoots var rootList structs.IndexedCARoots
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1) require.Len(t, rootList.Roots, 1)
oldRoot := rootList.Roots[0] oldRoot := rootList.Roots[0]
// Get the starting config // Get the starting config
@ -282,20 +278,20 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var reply structs.CAConfiguration var reply structs.CAConfiguration
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config) actual, err := ca.ParseConsulCAConfig(reply.Config)
require.NoError(err) require.NoError(t, err)
expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config) expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config)
require.NoError(err) require.NoError(t, err)
require.Equal(reply.Provider, s1.config.CAConfig.Provider) require.Equal(t, reply.Provider, s1.config.CAConfig.Provider)
require.Equal(actual, expected) require.Equal(t, actual, expected)
} }
// Update to a new CA with different key. This should fail since the existing // Update to a new CA with different key. This should fail since the existing
// CA doesn't support cross signing so can't rotate safely. // CA doesn't support cross signing so can't rotate safely.
_, newKey, err := connect.GeneratePrivateKey() _, newKey, err := connect.GeneratePrivateKey()
require.NoError(err) require.NoError(t, err)
newConfig := &structs.CAConfiguration{ newConfig := &structs.CAConfiguration{
Provider: "consul", Provider: "consul",
Config: map[string]interface{}{ Config: map[string]interface{}{
@ -309,7 +305,7 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
} }
var reply interface{} var reply interface{}
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply) err := msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)
require.EqualError(err, "The current CA Provider does not support cross-signing. "+ require.EqualError(t, err, "The current CA Provider does not support cross-signing. "+
"You can try again with ForceWithoutCrossSigningSet but this may cause disruption"+ "You can try again with ForceWithoutCrossSigningSet but this may cause disruption"+
" - see documentation for more.") " - see documentation for more.")
} }
@ -323,7 +319,7 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
} }
var reply interface{} var reply interface{}
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply) err := msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)
require.NoError(err) require.NoError(t, err)
} }
// Make sure the new root has been added but with no cross-signed intermediate // Make sure the new root has been added but with no cross-signed intermediate
@ -332,23 +328,23 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var reply structs.IndexedCARoots var reply structs.IndexedCARoots
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
require.Len(reply.Roots, 2) require.Len(t, reply.Roots, 2)
for _, r := range reply.Roots { for _, r := range reply.Roots {
if r.ID == oldRoot.ID { if r.ID == oldRoot.ID {
// The old root should no longer be marked as the active root, // The old root should no longer be marked as the active root,
// and none of its other fields should have changed. // and none of its other fields should have changed.
require.False(r.Active) require.False(t, r.Active)
require.Equal(r.Name, oldRoot.Name) require.Equal(t, r.Name, oldRoot.Name)
require.Equal(r.RootCert, oldRoot.RootCert) require.Equal(t, r.RootCert, oldRoot.RootCert)
require.Equal(r.SigningCert, oldRoot.SigningCert) require.Equal(t, r.SigningCert, oldRoot.SigningCert)
require.Equal(r.IntermediateCerts, oldRoot.IntermediateCerts) require.Equal(t, r.IntermediateCerts, oldRoot.IntermediateCerts)
} else { } else {
// The new root should NOT have a valid cross-signed cert from the old // The new root should NOT have a valid cross-signed cert from the old
// root as an intermediate. // root as an intermediate.
require.True(r.Active) require.True(t, r.Active)
require.Empty(r.IntermediateCerts) require.Empty(t, r.IntermediateCerts)
} }
} }
} }
@ -664,9 +660,6 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
require := require.New(t)
// Initialize primary as the primary DC // Initialize primary as the primary DC
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "primary" c.Datacenter = "primary"
@ -693,8 +686,8 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
// Capture the current root // Capture the current root
rootList, activeRoot, err := getTestRoots(s1, "primary") rootList, activeRoot, err := getTestRoots(s1, "primary")
require.NoError(err) require.NoError(t, err)
require.Len(rootList.Roots, 1) require.Len(t, rootList.Roots, 1)
rootCert := activeRoot rootCert := activeRoot
testrpc.WaitForActiveCARoot(t, s1.RPC, "primary", rootCert) testrpc.WaitForActiveCARoot(t, s1.RPC, "primary", rootCert)
@ -702,15 +695,15 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
// Capture the current intermediate // Capture the current intermediate
rootList, activeRoot, err = getTestRoots(s2, "secondary") rootList, activeRoot, err = getTestRoots(s2, "secondary")
require.NoError(err) require.NoError(t, err)
require.Len(rootList.Roots, 1) require.Len(t, rootList.Roots, 1)
require.Len(activeRoot.IntermediateCerts, 1) require.Len(t, activeRoot.IntermediateCerts, 1)
oldIntermediatePEM := activeRoot.IntermediateCerts[0] oldIntermediatePEM := activeRoot.IntermediateCerts[0]
// Update the secondary CA config to use a new private key, which should // Update the secondary CA config to use a new private key, which should
// cause a re-signing with a new intermediate. // cause a re-signing with a new intermediate.
_, newKey, err := connect.GeneratePrivateKey() _, newKey, err := connect.GeneratePrivateKey()
assert.NoError(err) assert.NoError(t, err)
newConfig := &structs.CAConfiguration{ newConfig := &structs.CAConfiguration{
Provider: "consul", Provider: "consul",
Config: map[string]interface{}{ Config: map[string]interface{}{
@ -725,7 +718,7 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
} }
var reply interface{} var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
} }
// Make sure the new intermediate has replaced the old one in the active root, // Make sure the new intermediate has replaced the old one in the active root,
@ -736,12 +729,12 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
Datacenter: "secondary", Datacenter: "secondary",
} }
var reply structs.IndexedCARoots var reply structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
require.Len(reply.Roots, 1) require.Len(t, reply.Roots, 1)
require.Len(reply.Roots[0].IntermediateCerts, 1) require.Len(t, reply.Roots[0].IntermediateCerts, 1)
newIntermediatePEM = reply.Roots[0].IntermediateCerts[0] newIntermediatePEM = reply.Roots[0].IntermediateCerts[0]
require.NotEqual(oldIntermediatePEM, newIntermediatePEM) require.NotEqual(t, oldIntermediatePEM, newIntermediatePEM)
require.Equal(reply.Roots[0].RootCert, rootCert.RootCert) require.Equal(t, reply.Roots[0].RootCert, rootCert.RootCert)
} }
// Verify the new config was set. // Verify the new config was set.
@ -750,14 +743,14 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
Datacenter: "secondary", Datacenter: "secondary",
} }
var reply structs.CAConfiguration var reply structs.CAConfiguration
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config) actual, err := ca.ParseConsulCAConfig(reply.Config)
require.NoError(err) require.NoError(t, err)
expected, err := ca.ParseConsulCAConfig(newConfig.Config) expected, err := ca.ParseConsulCAConfig(newConfig.Config)
require.NoError(err) require.NoError(t, err)
assert.Equal(reply.Provider, newConfig.Provider) assert.Equal(t, reply.Provider, newConfig.Provider)
assert.Equal(actual, expected) assert.Equal(t, actual, expected)
} }
// Verify that new leaf certs get the new intermediate bundled // Verify that new leaf certs get the new intermediate bundled
@ -770,28 +763,28 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
CSR: csr, CSR: csr,
} }
var reply structs.IssuedCert var reply structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
// Verify the leaf cert has the new intermediate. // Verify the leaf cert has the new intermediate.
{ {
roots := x509.NewCertPool() roots := x509.NewCertPool()
assert.True(roots.AppendCertsFromPEM([]byte(rootCert.RootCert))) assert.True(t, roots.AppendCertsFromPEM([]byte(rootCert.RootCert)))
leaf, err := connect.ParseCert(reply.CertPEM) leaf, err := connect.ParseCert(reply.CertPEM)
require.NoError(err) require.NoError(t, err)
intermediates := x509.NewCertPool() intermediates := x509.NewCertPool()
require.True(intermediates.AppendCertsFromPEM([]byte(newIntermediatePEM))) require.True(t, intermediates.AppendCertsFromPEM([]byte(newIntermediatePEM)))
_, err = leaf.Verify(x509.VerifyOptions{ _, err = leaf.Verify(x509.VerifyOptions{
Roots: roots, Roots: roots,
Intermediates: intermediates, Intermediates: intermediates,
}) })
require.NoError(err) require.NoError(t, err)
} }
// Verify other fields // Verify other fields
assert.Equal("web", reply.Service) assert.Equal(t, "web", reply.Service)
assert.Equal(spiffeId.URI().String(), reply.ServiceURI) assert.Equal(t, spiffeId.URI().String(), reply.ServiceURI)
} }
// Update a minor field in the config that doesn't trigger an intermediate refresh. // Update a minor field in the config that doesn't trigger an intermediate refresh.
@ -810,7 +803,7 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
} }
var reply interface{} var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
} }
} }
} }
@ -840,8 +833,6 @@ func TestConnectCASign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(fmt.Sprintf("%s-%d", tt.caKeyType, tt.caKeyBits), func(t *testing.T) { t.Run(fmt.Sprintf("%s-%d", tt.caKeyType, tt.caKeyBits), func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(cfg *Config) { dir1, s1 := testServerWithConfig(t, func(cfg *Config) {
cfg.PrimaryDatacenter = "dc1" cfg.PrimaryDatacenter = "dc1"
cfg.CAConfig.Config["PrivateKeyType"] = tt.caKeyType cfg.CAConfig.Config["PrivateKeyType"] = tt.caKeyType
@ -864,7 +855,7 @@ func TestConnectCASign(t *testing.T) {
CSR: csr, CSR: csr,
} }
var reply structs.IssuedCert var reply structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
// Generate a second CSR and request signing // Generate a second CSR and request signing
spiffeId2 := connect.TestSpiffeIDService(t, "web2") spiffeId2 := connect.TestSpiffeIDService(t, "web2")
@ -875,20 +866,20 @@ func TestConnectCASign(t *testing.T) {
} }
var reply2 structs.IssuedCert var reply2 structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply2)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply2))
require.True(reply2.ModifyIndex > reply.ModifyIndex) require.True(t, reply2.ModifyIndex > reply.ModifyIndex)
// Get the current CA // Get the current CA
state := s1.fsm.State() state := s1.fsm.State()
_, ca, err := state.CARootActive(nil) _, ca, err := state.CARootActive(nil)
require.NoError(err) require.NoError(t, err)
// Verify that the cert is signed by the CA // Verify that the cert is signed by the CA
require.NoError(connect.ValidateLeaf(ca.RootCert, reply.CertPEM, nil)) require.NoError(t, connect.ValidateLeaf(ca.RootCert, reply.CertPEM, nil))
// Verify other fields // Verify other fields
assert.Equal("web", reply.Service) assert.Equal(t, "web", reply.Service)
assert.Equal(spiffeId.URI().String(), reply.ServiceURI) assert.Equal(t, spiffeId.URI().String(), reply.ServiceURI)
}) })
} }
} }
@ -931,7 +922,6 @@ func TestConnectCASign_rateLimit(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc1" c.Datacenter = "dc1"
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
@ -976,7 +966,7 @@ func TestConnectCASign_rateLimit(t *testing.T) {
} else if err.Error() == ErrRateLimited.Error() { } else if err.Error() == ErrRateLimited.Error() {
limitedCount++ limitedCount++
} else { } else {
require.NoError(err) require.NoError(t, err)
} }
} }
// I've only ever seen this as 1/9 however if the test runs slowly on an // I've only ever seen this as 1/9 however if the test runs slowly on an
@ -986,8 +976,8 @@ func TestConnectCASign_rateLimit(t *testing.T) {
// check that some limiting is being applied. Note that we can't just measure // check that some limiting is being applied. Note that we can't just measure
// the time it took to send them all and infer how many should have succeeded // the time it took to send them all and infer how many should have succeeded
// without some complex modeling of the token bucket algorithm. // without some complex modeling of the token bucket algorithm.
require.Truef(successCount >= 1, "at least 1 CSRs should have succeeded, got %d", successCount) require.Truef(t, successCount >= 1, "at least 1 CSRs should have succeeded, got %d", successCount)
require.Truef(limitedCount >= 7, "at least 7 CSRs should have been rate limited, got %d", limitedCount) require.Truef(t, limitedCount >= 7, "at least 7 CSRs should have been rate limited, got %d", limitedCount)
} }
func TestConnectCASign_concurrencyLimit(t *testing.T) { func TestConnectCASign_concurrencyLimit(t *testing.T) {
@ -997,7 +987,6 @@ func TestConnectCASign_concurrencyLimit(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc1" c.Datacenter = "dc1"
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
@ -1057,7 +1046,7 @@ func TestConnectCASign_concurrencyLimit(t *testing.T) {
} else if err.Error() == ErrRateLimited.Error() { } else if err.Error() == ErrRateLimited.Error() {
limitedCount++ limitedCount++
} else { } else {
require.NoError(err) require.NoError(t, err)
} }
} }
@ -1096,7 +1085,7 @@ func TestConnectCASign_concurrencyLimit(t *testing.T) {
// requests were serialized. // requests were serialized.
t.Logf("min=%s, max=%s", minTime, maxTime) t.Logf("min=%s, max=%s", minTime, maxTime)
//t.Fail() // Uncomment to see the time spread logged //t.Fail() // Uncomment to see the time spread logged
require.Truef(successCount >= 1, "at least 1 CSRs should have succeeded, got %d", successCount) require.Truef(t, successCount >= 1, "at least 1 CSRs should have succeeded, got %d", successCount)
} }
func TestConnectCASignValidation(t *testing.T) { func TestConnectCASignValidation(t *testing.T) {

View File

@ -1101,10 +1101,9 @@ func TestFSM_Autopilot(t *testing.T) {
func TestFSM_Intention_CRUD(t *testing.T) { func TestFSM_Intention_CRUD(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
fsm, err := New(nil, logger) fsm, err := New(nil, logger)
assert.Nil(err) assert.Nil(t, err)
// Create a new intention. // Create a new intention.
ixn := structs.IntentionRequest{ ixn := structs.IntentionRequest{
@ -1118,19 +1117,19 @@ func TestFSM_Intention_CRUD(t *testing.T) {
{ {
buf, err := structs.Encode(structs.IntentionRequestType, ixn) buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err) assert.Nil(t, err)
assert.Nil(fsm.Apply(makeLog(buf))) assert.Nil(t, fsm.Apply(makeLog(buf)))
} }
// Verify it's in the state store. // Verify it's in the state store.
{ {
_, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID) _, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err) assert.Nil(t, err)
actual.CreateIndex, actual.ModifyIndex = 0, 0 actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt actual.CreatedAt = ixn.Intention.CreatedAt
actual.UpdatedAt = ixn.Intention.UpdatedAt actual.UpdatedAt = ixn.Intention.UpdatedAt
assert.Equal(ixn.Intention, actual) assert.Equal(t, ixn.Intention, actual)
} }
// Make an update // Make an update
@ -1138,44 +1137,43 @@ func TestFSM_Intention_CRUD(t *testing.T) {
ixn.Intention.SourceName = "api" ixn.Intention.SourceName = "api"
{ {
buf, err := structs.Encode(structs.IntentionRequestType, ixn) buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err) assert.Nil(t, err)
assert.Nil(fsm.Apply(makeLog(buf))) assert.Nil(t, fsm.Apply(makeLog(buf)))
} }
// Verify the update. // Verify the update.
{ {
_, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID) _, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err) assert.Nil(t, err)
actual.CreateIndex, actual.ModifyIndex = 0, 0 actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt actual.CreatedAt = ixn.Intention.CreatedAt
actual.UpdatedAt = ixn.Intention.UpdatedAt actual.UpdatedAt = ixn.Intention.UpdatedAt
assert.Equal(ixn.Intention, actual) assert.Equal(t, ixn.Intention, actual)
} }
// Delete // Delete
ixn.Op = structs.IntentionOpDelete ixn.Op = structs.IntentionOpDelete
{ {
buf, err := structs.Encode(structs.IntentionRequestType, ixn) buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err) assert.Nil(t, err)
assert.Nil(fsm.Apply(makeLog(buf))) assert.Nil(t, fsm.Apply(makeLog(buf)))
} }
// Make sure it's gone. // Make sure it's gone.
{ {
_, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID) _, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err) assert.Nil(t, err)
assert.Nil(actual) assert.Nil(t, actual)
} }
} }
func TestFSM_CAConfig(t *testing.T) { func TestFSM_CAConfig(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
fsm, err := New(nil, logger) fsm, err := New(nil, logger)
assert.Nil(err) assert.Nil(t, err)
// Set the autopilot config using a request. // Set the autopilot config using a request.
req := structs.CARequest{ req := structs.CARequest{
@ -1190,7 +1188,7 @@ func TestFSM_CAConfig(t *testing.T) {
}, },
} }
buf, err := structs.Encode(structs.ConnectCARequestType, req) buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err) assert.Nil(t, err)
resp := fsm.Apply(makeLog(buf)) resp := fsm.Apply(makeLog(buf))
if _, ok := resp.(error); ok { if _, ok := resp.(error); ok {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
@ -1231,7 +1229,7 @@ func TestFSM_CAConfig(t *testing.T) {
} }
_, config, err = fsm.state.CAConfig(nil) _, config, err = fsm.state.CAConfig(nil)
assert.Nil(err) assert.Nil(t, err)
if config.Provider != "static" { if config.Provider != "static" {
t.Fatalf("bad: %v", config.Provider) t.Fatalf("bad: %v", config.Provider)
} }
@ -1240,10 +1238,9 @@ func TestFSM_CAConfig(t *testing.T) {
func TestFSM_CARoots(t *testing.T) { func TestFSM_CARoots(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
fsm, err := New(nil, logger) fsm, err := New(nil, logger)
assert.Nil(err) assert.Nil(t, err)
// Roots // Roots
ca1 := connect.TestCA(t, nil) ca1 := connect.TestCA(t, nil)
@ -1258,25 +1255,24 @@ func TestFSM_CARoots(t *testing.T) {
{ {
buf, err := structs.Encode(structs.ConnectCARequestType, req) buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err) assert.Nil(t, err)
assert.True(fsm.Apply(makeLog(buf)).(bool)) assert.True(t, fsm.Apply(makeLog(buf)).(bool))
} }
// Verify it's in the state store. // Verify it's in the state store.
{ {
_, roots, err := fsm.state.CARoots(nil) _, roots, err := fsm.state.CARoots(nil)
assert.Nil(err) assert.Nil(t, err)
assert.Len(roots, 2) assert.Len(t, roots, 2)
} }
} }
func TestFSM_CABuiltinProvider(t *testing.T) { func TestFSM_CABuiltinProvider(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
fsm, err := New(nil, logger) fsm, err := New(nil, logger)
assert.Nil(err) assert.Nil(t, err)
// Provider state. // Provider state.
expected := &structs.CAConsulProviderState{ expected := &structs.CAConsulProviderState{
@ -1297,25 +1293,24 @@ func TestFSM_CABuiltinProvider(t *testing.T) {
{ {
buf, err := structs.Encode(structs.ConnectCARequestType, req) buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err) assert.Nil(t, err)
assert.True(fsm.Apply(makeLog(buf)).(bool)) assert.True(t, fsm.Apply(makeLog(buf)).(bool))
} }
// Verify it's in the state store. // Verify it's in the state store.
{ {
_, state, err := fsm.state.CAProviderState("foo") _, state, err := fsm.state.CAProviderState("foo")
assert.Nil(err) assert.Nil(t, err)
assert.Equal(expected, state) assert.Equal(t, expected, state)
} }
} }
func TestFSM_ConfigEntry(t *testing.T) { func TestFSM_ConfigEntry(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
fsm, err := New(nil, logger) fsm, err := New(nil, logger)
require.NoError(err) require.NoError(t, err)
// Create a simple config entry // Create a simple config entry
entry := &structs.ProxyConfigEntry{ entry := &structs.ProxyConfigEntry{
@ -1335,7 +1330,7 @@ func TestFSM_ConfigEntry(t *testing.T) {
{ {
buf, err := structs.Encode(structs.ConfigEntryRequestType, req) buf, err := structs.Encode(structs.ConfigEntryRequestType, req)
require.NoError(err) require.NoError(t, err)
resp := fsm.Apply(makeLog(buf)) resp := fsm.Apply(makeLog(buf))
if _, ok := resp.(error); ok { if _, ok := resp.(error); ok {
t.Fatalf("bad: %v", resp) t.Fatalf("bad: %v", resp)
@ -1345,33 +1340,31 @@ func TestFSM_ConfigEntry(t *testing.T) {
// Verify it's in the state store. // Verify it's in the state store.
{ {
_, config, err := fsm.state.ConfigEntry(nil, structs.ProxyDefaults, "global", nil) _, config, err := fsm.state.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err) require.NoError(t, err)
entry.RaftIndex.CreateIndex = 1 entry.RaftIndex.CreateIndex = 1
entry.RaftIndex.ModifyIndex = 1 entry.RaftIndex.ModifyIndex = 1
require.Equal(entry, config) require.Equal(t, entry, config)
} }
} }
func TestFSM_ConfigEntry_DeleteCAS(t *testing.T) { func TestFSM_ConfigEntry_DeleteCAS(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
fsm, err := New(nil, logger) fsm, err := New(nil, logger)
require.NoError(err) require.NoError(t, err)
// Create a simple config entry and write it to the state store. // Create a simple config entry and write it to the state store.
entry := &structs.ServiceConfigEntry{ entry := &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "global", Name: "global",
} }
require.NoError(fsm.state.EnsureConfigEntry(1, entry)) require.NoError(t, fsm.state.EnsureConfigEntry(1, entry))
// Raft index is populated by EnsureConfigEntry, hold on to it so that we can // Raft index is populated by EnsureConfigEntry, hold on to it so that we can
// restore it later. // restore it later.
raftIndex := entry.RaftIndex raftIndex := entry.RaftIndex
require.NotZero(raftIndex.ModifyIndex) require.NotZero(t, raftIndex.ModifyIndex)
// Attempt a CAS delete with an invalid index. // Attempt a CAS delete with an invalid index.
entry = entry.Clone() entry = entry.Clone()
@ -1383,24 +1376,24 @@ func TestFSM_ConfigEntry_DeleteCAS(t *testing.T) {
Entry: entry, Entry: entry,
} }
buf, err := structs.Encode(structs.ConfigEntryRequestType, req) buf, err := structs.Encode(structs.ConfigEntryRequestType, req)
require.NoError(err) require.NoError(t, err)
// Expect to get boolean false back. // Expect to get boolean false back.
rsp := fsm.Apply(makeLog(buf)) rsp := fsm.Apply(makeLog(buf))
didDelete, isBool := rsp.(bool) didDelete, isBool := rsp.(bool)
require.True(isBool) require.True(t, isBool)
require.False(didDelete) require.False(t, didDelete)
// Attempt a CAS delete with a valid index. // Attempt a CAS delete with a valid index.
entry.RaftIndex = raftIndex entry.RaftIndex = raftIndex
buf, err = structs.Encode(structs.ConfigEntryRequestType, req) buf, err = structs.Encode(structs.ConfigEntryRequestType, req)
require.NoError(err) require.NoError(t, err)
// Expect to get boolean true back. // Expect to get boolean true back.
rsp = fsm.Apply(makeLog(buf)) rsp = fsm.Apply(makeLog(buf))
didDelete, isBool = rsp.(bool) didDelete, isBool = rsp.(bool)
require.True(isBool) require.True(t, isBool)
require.True(didDelete) require.True(t, didDelete)
} }
// This adapts another test by chunking the encoded data and then performing // This adapts another test by chunking the encoded data and then performing
@ -1413,12 +1406,10 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
assert := assert.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
fsm, err := New(nil, logger) fsm, err := New(nil, logger)
require.NoError(err) require.NoError(t, err)
var logOfLogs [][]*raft.Log var logOfLogs [][]*raft.Log
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -1442,7 +1433,7 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
} }
buf, err := structs.Encode(structs.RegisterRequestType, req) buf, err := structs.Encode(structs.RegisterRequestType, req)
require.NoError(err) require.NoError(t, err)
var logs []*raft.Log var logs []*raft.Log
@ -1453,7 +1444,7 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
NumChunks: uint32(len(buf)), NumChunks: uint32(len(buf)),
} }
chunkBytes, err := proto.Marshal(chunkInfo) chunkBytes, err := proto.Marshal(chunkInfo)
require.NoError(err) require.NoError(t, err)
logs = append(logs, &raft.Log{ logs = append(logs, &raft.Log{
Data: []byte{b}, Data: []byte{b},
@ -1468,41 +1459,41 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
// the full set, and out of order. // the full set, and out of order.
for _, logs := range logOfLogs { for _, logs := range logOfLogs {
resp := fsm.chunker.Apply(logs[8]) resp := fsm.chunker.Apply(logs[8])
assert.Nil(resp) assert.Nil(t, resp)
resp = fsm.chunker.Apply(logs[0]) resp = fsm.chunker.Apply(logs[0])
assert.Nil(resp) assert.Nil(t, resp)
resp = fsm.chunker.Apply(logs[3]) resp = fsm.chunker.Apply(logs[3])
assert.Nil(resp) assert.Nil(t, resp)
} }
// Verify we are not registered // Verify we are not registered
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, node, err := fsm.state.GetNode(fmt.Sprintf("foo%d", i), nil) _, node, err := fsm.state.GetNode(fmt.Sprintf("foo%d", i), nil)
require.NoError(err) require.NoError(t, err)
assert.Nil(node) assert.Nil(t, node)
} }
// Snapshot, restore elsewhere, apply the rest of the logs, make sure it // Snapshot, restore elsewhere, apply the rest of the logs, make sure it
// looks right // looks right
snap, err := fsm.Snapshot() snap, err := fsm.Snapshot()
require.NoError(err) require.NoError(t, err)
defer snap.Release() defer snap.Release()
sinkBuf := bytes.NewBuffer(nil) sinkBuf := bytes.NewBuffer(nil)
sink := &MockSink{sinkBuf, false} sink := &MockSink{sinkBuf, false}
err = snap.Persist(sink) err = snap.Persist(sink)
require.NoError(err) require.NoError(t, err)
fsm2, err := New(nil, logger) fsm2, err := New(nil, logger)
require.NoError(err) require.NoError(t, err)
err = fsm2.Restore(sink) err = fsm2.Restore(sink)
require.NoError(err) require.NoError(t, err)
// Verify we are still not registered // Verify we are still not registered
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil) _, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil)
require.NoError(err) require.NoError(t, err)
assert.Nil(node) assert.Nil(t, node)
} }
// Apply the rest of the logs // Apply the rest of the logs
@ -1514,43 +1505,41 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
default: default:
resp = fsm2.chunker.Apply(log) resp = fsm2.chunker.Apply(log)
if i != len(logs)-1 { if i != len(logs)-1 {
assert.Nil(resp) assert.Nil(t, resp)
} }
} }
} }
_, ok := resp.(raftchunking.ChunkingSuccess) _, ok := resp.(raftchunking.ChunkingSuccess)
assert.True(ok) assert.True(t, ok)
} }
// Verify we are registered // Verify we are registered
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil) _, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil)
require.NoError(err) require.NoError(t, err)
assert.NotNil(node) assert.NotNil(t, node)
// Verify service registered // Verify service registered
_, services, err := fsm2.state.NodeServices(nil, fmt.Sprintf("foo%d", i), structs.DefaultEnterpriseMetaInDefaultPartition()) _, services, err := fsm2.state.NodeServices(nil, fmt.Sprintf("foo%d", i), structs.DefaultEnterpriseMetaInDefaultPartition())
require.NoError(err) require.NoError(t, err)
require.NotNil(services) require.NotNil(t, services)
_, ok := services.Services["db"] _, ok := services.Services["db"]
assert.True(ok) assert.True(t, ok)
// Verify check // Verify check
_, checks, err := fsm2.state.NodeChecks(nil, fmt.Sprintf("foo%d", i), nil) _, checks, err := fsm2.state.NodeChecks(nil, fmt.Sprintf("foo%d", i), nil)
require.NoError(err) require.NoError(t, err)
require.NotNil(checks) require.NotNil(t, checks)
assert.Equal(string(checks[0].CheckID), "db") assert.Equal(t, string(checks[0].CheckID), "db")
} }
} }
func TestFSM_Chunking_TermChange(t *testing.T) { func TestFSM_Chunking_TermChange(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
require := require.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
fsm, err := New(nil, logger) fsm, err := New(nil, logger)
require.NoError(err) require.NoError(t, err)
req := structs.RegisterRequest{ req := structs.RegisterRequest{
Datacenter: "dc1", Datacenter: "dc1",
@ -1571,7 +1560,7 @@ func TestFSM_Chunking_TermChange(t *testing.T) {
}, },
} }
buf, err := structs.Encode(structs.RegisterRequestType, req) buf, err := structs.Encode(structs.RegisterRequestType, req)
require.NoError(err) require.NoError(t, err)
// Only need two chunks to test this // Only need two chunks to test this
chunks := [][]byte{ chunks := [][]byte{
@ -1599,7 +1588,7 @@ func TestFSM_Chunking_TermChange(t *testing.T) {
// We should see nil for both // We should see nil for both
for _, log := range logs { for _, log := range logs {
resp := fsm.chunker.Apply(log) resp := fsm.chunker.Apply(log)
assert.Nil(resp) assert.Nil(t, resp)
} }
// Now verify the other baseline, that when the term doesn't change we see // Now verify the other baseline, that when the term doesn't change we see
@ -1616,10 +1605,10 @@ func TestFSM_Chunking_TermChange(t *testing.T) {
for i, log := range logs { for i, log := range logs {
resp := fsm.chunker.Apply(log) resp := fsm.chunker.Apply(log)
if i == 0 { if i == 0 {
assert.Nil(resp) assert.Nil(t, resp)
} }
if i == 1 { if i == 1 {
assert.NotNil(resp) assert.NotNil(t, resp)
} }
} }
} }

View File

@ -979,7 +979,6 @@ func TestHealth_ServiceNodes_ConnectProxy_ACL(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -1020,7 +1019,7 @@ node "foo" {
Status: api.HealthPassing, Status: api.HealthPassing,
ServiceID: args.Service.ID, ServiceID: args.Service.ID,
} }
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a service // Register a service
args = structs.TestRegisterRequestProxy(t) args = structs.TestRegisterRequestProxy(t)
@ -1032,7 +1031,7 @@ node "foo" {
Status: api.HealthPassing, Status: api.HealthPassing,
ServiceID: args.Service.Service, ServiceID: args.Service.Service,
} }
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a service // Register a service
args = structs.TestRegisterRequestProxy(t) args = structs.TestRegisterRequestProxy(t)
@ -1044,7 +1043,7 @@ node "foo" {
Status: api.HealthPassing, Status: api.HealthPassing,
ServiceID: args.Service.Service, ServiceID: args.Service.Service,
} }
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
} }
// List w/ token. This should disallow because we don't have permission // List w/ token. This should disallow because we don't have permission
@ -1056,8 +1055,8 @@ node "foo" {
QueryOptions: structs.QueryOptions{Token: token}, QueryOptions: structs.QueryOptions{Token: token},
} }
var resp structs.IndexedCheckServiceNodes var resp structs.IndexedCheckServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(resp.Nodes, 0) assert.Len(t, resp.Nodes, 0)
// List w/ token. This should work since we're requesting "foo", but should // List w/ token. This should work since we're requesting "foo", but should
// also only contain the proxies with names that adhere to our ACL. // also only contain the proxies with names that adhere to our ACL.
@ -1067,8 +1066,8 @@ node "foo" {
ServiceName: "foo", ServiceName: "foo",
QueryOptions: structs.QueryOptions{Token: token}, QueryOptions: structs.QueryOptions{Token: token},
} }
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp)) assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(resp.Nodes, 1) assert.Len(t, resp.Nodes, 1)
} }
func TestHealth_ServiceNodes_Gateway(t *testing.T) { func TestHealth_ServiceNodes_Gateway(t *testing.T) {
@ -1432,8 +1431,6 @@ func TestHealth_NodeChecks_FilterACL(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir, token, srv, codec := testACLFilterServer(t) dir, token, srv, codec := testACLFilterServer(t)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
defer srv.Shutdown() defer srv.Shutdown()
@ -1446,7 +1443,7 @@ func TestHealth_NodeChecks_FilterACL(t *testing.T) {
} }
reply := structs.IndexedHealthChecks{} reply := structs.IndexedHealthChecks{}
err := msgpackrpc.CallWithCodec(codec, "Health.NodeChecks", &opt, &reply) err := msgpackrpc.CallWithCodec(codec, "Health.NodeChecks", &opt, &reply)
require.NoError(err) require.NoError(t, err)
found := false found := false
for _, chk := range reply.HealthChecks { for _, chk := range reply.HealthChecks {
@ -1457,8 +1454,8 @@ func TestHealth_NodeChecks_FilterACL(t *testing.T) {
t.Fatalf("bad: %#v", reply.HealthChecks) t.Fatalf("bad: %#v", reply.HealthChecks)
} }
} }
require.True(found, "bad: %#v", reply.HealthChecks) require.True(t, found, "bad: %#v", reply.HealthChecks)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// We've already proven that we call the ACL filtering function so we // We've already proven that we call the ACL filtering function so we
// test node filtering down in acl.go for node cases. This also proves // test node filtering down in acl.go for node cases. This also proves
@ -1474,8 +1471,6 @@ func TestHealth_ServiceChecks_FilterACL(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir, token, srv, codec := testACLFilterServer(t) dir, token, srv, codec := testACLFilterServer(t)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
defer srv.Shutdown() defer srv.Shutdown()
@ -1488,7 +1483,7 @@ func TestHealth_ServiceChecks_FilterACL(t *testing.T) {
} }
reply := structs.IndexedHealthChecks{} reply := structs.IndexedHealthChecks{}
err := msgpackrpc.CallWithCodec(codec, "Health.ServiceChecks", &opt, &reply) err := msgpackrpc.CallWithCodec(codec, "Health.ServiceChecks", &opt, &reply)
require.NoError(err) require.NoError(t, err)
found := false found := false
for _, chk := range reply.HealthChecks { for _, chk := range reply.HealthChecks {
@ -1497,14 +1492,14 @@ func TestHealth_ServiceChecks_FilterACL(t *testing.T) {
break break
} }
} }
require.True(found, "bad: %#v", reply.HealthChecks) require.True(t, found, "bad: %#v", reply.HealthChecks)
opt.ServiceName = "bar" opt.ServiceName = "bar"
reply = structs.IndexedHealthChecks{} reply = structs.IndexedHealthChecks{}
err = msgpackrpc.CallWithCodec(codec, "Health.ServiceChecks", &opt, &reply) err = msgpackrpc.CallWithCodec(codec, "Health.ServiceChecks", &opt, &reply)
require.NoError(err) require.NoError(t, err)
require.Empty(reply.HealthChecks) require.Empty(t, reply.HealthChecks)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// We've already proven that we call the ACL filtering function so we // We've already proven that we call the ACL filtering function so we
// test node filtering down in acl.go for node cases. This also proves // test node filtering down in acl.go for node cases. This also proves
@ -1520,8 +1515,6 @@ func TestHealth_ServiceNodes_FilterACL(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir, token, srv, codec := testACLFilterServer(t) dir, token, srv, codec := testACLFilterServer(t)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
defer srv.Shutdown() defer srv.Shutdown()
@ -1534,15 +1527,15 @@ func TestHealth_ServiceNodes_FilterACL(t *testing.T) {
} }
reply := structs.IndexedCheckServiceNodes{} reply := structs.IndexedCheckServiceNodes{}
err := msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &opt, &reply) err := msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &opt, &reply)
require.NoError(err) require.NoError(t, err)
require.Len(reply.Nodes, 1) require.Len(t, reply.Nodes, 1)
opt.ServiceName = "bar" opt.ServiceName = "bar"
reply = structs.IndexedCheckServiceNodes{} reply = structs.IndexedCheckServiceNodes{}
err = msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &opt, &reply) err = msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &opt, &reply)
require.NoError(err) require.NoError(t, err)
require.Empty(reply.Nodes) require.Empty(t, reply.Nodes)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// We've already proven that we call the ACL filtering function so we // We've already proven that we call the ACL filtering function so we
// test node filtering down in acl.go for node cases. This also proves // test node filtering down in acl.go for node cases. This also proves
@ -1558,8 +1551,6 @@ func TestHealth_ChecksInState_FilterACL(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir, token, srv, codec := testACLFilterServer(t) dir, token, srv, codec := testACLFilterServer(t)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
defer srv.Shutdown() defer srv.Shutdown()
@ -1572,7 +1563,7 @@ func TestHealth_ChecksInState_FilterACL(t *testing.T) {
} }
reply := structs.IndexedHealthChecks{} reply := structs.IndexedHealthChecks{}
err := msgpackrpc.CallWithCodec(codec, "Health.ChecksInState", &opt, &reply) err := msgpackrpc.CallWithCodec(codec, "Health.ChecksInState", &opt, &reply)
require.NoError(err) require.NoError(t, err)
found := false found := false
for _, chk := range reply.HealthChecks { for _, chk := range reply.HealthChecks {
@ -1583,8 +1574,8 @@ func TestHealth_ChecksInState_FilterACL(t *testing.T) {
t.Fatalf("bad service 'bar': %#v", reply.HealthChecks) t.Fatalf("bad service 'bar': %#v", reply.HealthChecks)
} }
} }
require.True(found, "missing service 'foo': %#v", reply.HealthChecks) require.True(t, found, "missing service 'foo': %#v", reply.HealthChecks)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// We've already proven that we call the ACL filtering function so we // We've already proven that we call the ACL filtering function so we
// test node filtering down in acl.go for node cases. This also proves // test node filtering down in acl.go for node cases. This also proves

View File

@ -111,7 +111,6 @@ func TestIntentionApply_defaultSourceType(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -135,8 +134,8 @@ func TestIntentionApply_defaultSourceType(t *testing.T) {
var reply string var reply string
// Create // Create
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.NotEmpty(reply) require.NotEmpty(t, reply)
// Read // Read
ixn.Intention.ID = reply ixn.Intention.ID = reply
@ -146,10 +145,10 @@ func TestIntentionApply_defaultSourceType(t *testing.T) {
IntentionID: ixn.Intention.ID, IntentionID: ixn.Intention.ID,
} }
var resp structs.IndexedIntentions var resp structs.IndexedIntentions
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp))
require.Len(resp.Intentions, 1) require.Len(t, resp.Intentions, 1)
actual := resp.Intentions[0] actual := resp.Intentions[0]
require.Equal(structs.IntentionSourceConsul, actual.SourceType) require.Equal(t, structs.IntentionSourceConsul, actual.SourceType)
} }
} }
@ -161,7 +160,6 @@ func TestIntentionApply_createWithID(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -184,8 +182,8 @@ func TestIntentionApply_createWithID(t *testing.T) {
// Create // Create
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply) err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.NotNil(err) require.NotNil(t, err)
require.Contains(err, "ID must be empty") require.Contains(t, err, "ID must be empty")
} }
// Test basic updating // Test basic updating
@ -282,7 +280,6 @@ func TestIntentionApply_updateNonExist(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -304,8 +301,8 @@ func TestIntentionApply_updateNonExist(t *testing.T) {
// Create // Create
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply) err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.NotNil(err) require.NotNil(t, err)
require.Contains(err, "Cannot modify non-existent intention") require.Contains(t, err, "Cannot modify non-existent intention")
} }
// Test basic deleting // Test basic deleting
@ -316,7 +313,6 @@ func TestIntentionApply_deleteGood(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -346,13 +342,13 @@ func TestIntentionApply_deleteGood(t *testing.T) {
}, &reply), "Cannot delete non-existent intention") }, &reply), "Cannot delete non-existent intention")
// Create // Create
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.NotEmpty(reply) require.NotEmpty(t, reply)
// Delete // Delete
ixn.Op = structs.IntentionOpDelete ixn.Op = structs.IntentionOpDelete
ixn.Intention.ID = reply ixn.Intention.ID = reply
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Read // Read
ixn.Intention.ID = reply ixn.Intention.ID = reply
@ -363,8 +359,8 @@ func TestIntentionApply_deleteGood(t *testing.T) {
} }
var resp structs.IndexedIntentions var resp structs.IndexedIntentions
err := msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp) err := msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp)
require.NotNil(err) require.NotNil(t, err)
require.Contains(err, ErrIntentionNotFound.Error()) require.Contains(t, err, ErrIntentionNotFound.Error())
} }
} }
@ -863,7 +859,6 @@ func TestIntentionApply_aclDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -895,11 +890,11 @@ service "foobar" {
// Create without a token should error since default deny // Create without a token should error since default deny
var reply string var reply string
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply) err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.True(acl.IsErrPermissionDenied(err)) require.True(t, acl.IsErrPermissionDenied(err))
// Now add the token and try again. // Now add the token and try again.
ixn.WriteRequest.Token = token ixn.WriteRequest.Token = token
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Read // Read
ixn.Intention.ID = reply ixn.Intention.ID = reply
@ -910,10 +905,10 @@ service "foobar" {
QueryOptions: structs.QueryOptions{Token: "root"}, QueryOptions: structs.QueryOptions{Token: "root"},
} }
var resp structs.IndexedIntentions var resp structs.IndexedIntentions
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp))
require.Len(resp.Intentions, 1) require.Len(t, resp.Intentions, 1)
actual := resp.Intentions[0] actual := resp.Intentions[0]
require.Equal(resp.Index, actual.ModifyIndex) require.Equal(t, resp.Index, actual.ModifyIndex)
actual.CreateIndex, actual.ModifyIndex = 0, 0 actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt actual.CreatedAt = ixn.Intention.CreatedAt
@ -921,7 +916,7 @@ service "foobar" {
actual.Hash = ixn.Intention.Hash actual.Hash = ixn.Intention.Hash
//nolint:staticcheck //nolint:staticcheck
ixn.Intention.UpdatePrecedence() ixn.Intention.UpdatePrecedence()
require.Equal(ixn.Intention, actual) require.Equal(t, ixn.Intention, actual)
} }
} }
@ -1253,7 +1248,6 @@ func TestIntentionApply_aclDelete(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -1285,18 +1279,18 @@ service "foobar" {
// Create // Create
var reply string var reply string
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Try to do a delete with no token; this should get rejected. // Try to do a delete with no token; this should get rejected.
ixn.Op = structs.IntentionOpDelete ixn.Op = structs.IntentionOpDelete
ixn.Intention.ID = reply ixn.Intention.ID = reply
ixn.WriteRequest.Token = "" ixn.WriteRequest.Token = ""
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply) err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.True(acl.IsErrPermissionDenied(err)) require.True(t, acl.IsErrPermissionDenied(err))
// Try again with the original token. This should go through. // Try again with the original token. This should go through.
ixn.WriteRequest.Token = token ixn.WriteRequest.Token = token
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Verify it is gone // Verify it is gone
{ {
@ -1306,8 +1300,8 @@ service "foobar" {
} }
var resp structs.IndexedIntentions var resp structs.IndexedIntentions
err := msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp) err := msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp)
require.NotNil(err) require.NotNil(t, err)
require.Contains(err.Error(), ErrIntentionNotFound.Error()) require.Contains(t, err.Error(), ErrIntentionNotFound.Error())
} }
} }
@ -1319,7 +1313,6 @@ func TestIntentionApply_aclUpdate(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -1351,18 +1344,18 @@ service "foobar" {
// Create // Create
var reply string var reply string
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Try to do an update without a token; this should get rejected. // Try to do an update without a token; this should get rejected.
ixn.Op = structs.IntentionOpUpdate ixn.Op = structs.IntentionOpUpdate
ixn.Intention.ID = reply ixn.Intention.ID = reply
ixn.WriteRequest.Token = "" ixn.WriteRequest.Token = ""
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply) err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.True(acl.IsErrPermissionDenied(err)) require.True(t, acl.IsErrPermissionDenied(err))
// Try again with the original token; this should go through. // Try again with the original token; this should go through.
ixn.WriteRequest.Token = token ixn.WriteRequest.Token = token
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
} }
// Test apply with a management token // Test apply with a management token
@ -1373,7 +1366,6 @@ func TestIntentionApply_aclManagement(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -1398,16 +1390,16 @@ func TestIntentionApply_aclManagement(t *testing.T) {
// Create // Create
var reply string var reply string
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
ixn.Intention.ID = reply ixn.Intention.ID = reply
// Update // Update
ixn.Op = structs.IntentionOpUpdate ixn.Op = structs.IntentionOpUpdate
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Delete // Delete
ixn.Op = structs.IntentionOpDelete ixn.Op = structs.IntentionOpDelete
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
} }
// Test update changing the name where an ACL won't allow it // Test update changing the name where an ACL won't allow it
@ -1418,7 +1410,6 @@ func TestIntentionApply_aclUpdateChange(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -1450,7 +1441,7 @@ service "foobar" {
// Create // Create
var reply string var reply string
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Try to do an update without a token; this should get rejected. // Try to do an update without a token; this should get rejected.
ixn.Op = structs.IntentionOpUpdate ixn.Op = structs.IntentionOpUpdate
@ -1458,7 +1449,7 @@ service "foobar" {
ixn.Intention.DestinationName = "foo" ixn.Intention.DestinationName = "foo"
ixn.WriteRequest.Token = token ixn.WriteRequest.Token = token
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply) err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.True(acl.IsErrPermissionDenied(err)) require.True(t, acl.IsErrPermissionDenied(err))
} }
// Test reading with ACLs // Test reading with ACLs
@ -1570,7 +1561,6 @@ func TestIntentionList(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -1585,9 +1575,9 @@ func TestIntentionList(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var resp structs.IndexedIntentions var resp structs.IndexedIntentions
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.List", req, &resp)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.List", req, &resp))
require.NotNil(resp.Intentions) require.NotNil(t, resp.Intentions)
require.Len(resp.Intentions, 0) require.Len(t, resp.Intentions, 0)
} }
} }

View File

@ -853,7 +853,6 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
} }
t.Run("can read all", func(t *testing.T) { t.Run("can read all", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, ` token := tokenWithRules(t, `
node_prefix "" { node_prefix "" {
@ -870,14 +869,13 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
} }
var out structs.IndexedNodesWithGateways var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out) err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err) require.NoError(t, err)
require.NotEmpty(out.Nodes) require.NotEmpty(t, out.Nodes)
require.NotEmpty(out.Gateways) require.NotEmpty(t, out.Gateways)
require.False(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") require.False(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
}) })
t.Run("cannot read service node", func(t *testing.T) { t.Run("cannot read service node", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, ` token := tokenWithRules(t, `
node "node1" { node "node1" {
@ -894,13 +892,12 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
} }
var out structs.IndexedNodesWithGateways var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out) err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err) require.NoError(t, err)
require.Empty(out.Nodes) require.Empty(t, out.Nodes)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
}) })
t.Run("cannot read service", func(t *testing.T) { t.Run("cannot read service", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, ` token := tokenWithRules(t, `
node "node1" { node "node1" {
@ -917,13 +914,12 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
} }
var out structs.IndexedNodesWithGateways var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out) err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err) require.NoError(t, err)
require.Empty(out.Nodes) require.Empty(t, out.Nodes)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
}) })
t.Run("cannot read gateway node", func(t *testing.T) { t.Run("cannot read gateway node", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, ` token := tokenWithRules(t, `
node "node2" { node "node2" {
@ -940,13 +936,12 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
} }
var out structs.IndexedNodesWithGateways var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out) err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err) require.NoError(t, err)
require.Empty(out.Gateways) require.Empty(t, out.Gateways)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
}) })
t.Run("cannot read gateway", func(t *testing.T) { t.Run("cannot read gateway", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, ` token := tokenWithRules(t, `
node "node2" { node "node2" {
@ -963,9 +958,9 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
} }
var out structs.IndexedNodesWithGateways var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out) err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err) require.NoError(t, err)
require.Empty(out.Gateways) require.Empty(t, out.Gateways)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
}) })
} }

View File

@ -327,7 +327,6 @@ func TestCAManager_RenewIntermediate_Vault_Primary(t *testing.T) {
// no parallel execution because we change globals // no parallel execution because we change globals
patchIntermediateCertRenewInterval(t) patchIntermediateCertRenewInterval(t)
require := require.New(t)
testVault := ca.NewTestVaultServer(t) testVault := ca.NewTestVaultServer(t)
@ -354,15 +353,15 @@ func TestCAManager_RenewIntermediate_Vault_Primary(t *testing.T) {
store := s1.caManager.delegate.State() store := s1.caManager.delegate.State()
_, activeRoot, err := store.CARootActive(nil) _, activeRoot, err := store.CARootActive(nil)
require.NoError(err) require.NoError(t, err)
t.Log("original SigningKeyID", activeRoot.SigningKeyID) t.Log("original SigningKeyID", activeRoot.SigningKeyID)
intermediatePEM := s1.caManager.getLeafSigningCertFromRoot(activeRoot) intermediatePEM := s1.caManager.getLeafSigningCertFromRoot(activeRoot)
intermediateCert, err := connect.ParseCert(intermediatePEM) intermediateCert, err := connect.ParseCert(intermediatePEM)
require.NoError(err) require.NoError(t, err)
require.Equal(connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID) require.Equal(t, connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(intermediatePEM, s1.caManager.getLeafSigningCertFromRoot(activeRoot)) require.Equal(t, intermediatePEM, s1.caManager.getLeafSigningCertFromRoot(activeRoot))
// Wait for dc1's intermediate to be refreshed. // Wait for dc1's intermediate to be refreshed.
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
@ -382,12 +381,12 @@ func TestCAManager_RenewIntermediate_Vault_Primary(t *testing.T) {
codec := rpcClient(t, s1) codec := rpcClient(t, s1)
roots := structs.IndexedCARoots{} roots := structs.IndexedCARoots{}
err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", &structs.DCSpecificRequest{}, &roots) err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", &structs.DCSpecificRequest{}, &roots)
require.NoError(err) require.NoError(t, err)
require.Len(roots.Roots, 1) require.Len(t, roots.Roots, 1)
activeRoot = roots.Active() activeRoot = roots.Active()
require.Equal(connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID) require.Equal(t, connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(intermediatePEM, s1.caManager.getLeafSigningCertFromRoot(activeRoot)) require.Equal(t, intermediatePEM, s1.caManager.getLeafSigningCertFromRoot(activeRoot))
// Have the new intermediate sign a leaf cert and make sure the chain is correct. // Have the new intermediate sign a leaf cert and make sure the chain is correct.
spiffeService := &connect.SpiffeIDService{ spiffeService := &connect.SpiffeIDService{
@ -401,7 +400,7 @@ func TestCAManager_RenewIntermediate_Vault_Primary(t *testing.T) {
req := structs.CASignRequest{CSR: csr} req := structs.CASignRequest{CSR: csr}
cert := structs.IssuedCert{} cert := structs.IssuedCert{}
err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", &req, &cert) err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", &req, &cert)
require.NoError(err) require.NoError(t, err)
verifyLeafCert(t, activeRoot, cert.CertPEM) verifyLeafCert(t, activeRoot, cert.CertPEM)
} }
@ -425,7 +424,6 @@ func TestCAManager_RenewIntermediate_Secondary(t *testing.T) {
// no parallel execution because we change globals // no parallel execution because we change globals
patchIntermediateCertRenewInterval(t) patchIntermediateCertRenewInterval(t)
require := require.New(t)
_, s1 := testServerWithConfig(t, func(c *Config) { _, s1 := testServerWithConfig(t, func(c *Config) {
c.Build = "1.6.0" c.Build = "1.6.0"
@ -469,15 +467,15 @@ func TestCAManager_RenewIntermediate_Secondary(t *testing.T) {
store := s2.fsm.State() store := s2.fsm.State()
_, activeRoot, err := store.CARootActive(nil) _, activeRoot, err := store.CARootActive(nil)
require.NoError(err) require.NoError(t, err)
t.Log("original SigningKeyID", activeRoot.SigningKeyID) t.Log("original SigningKeyID", activeRoot.SigningKeyID)
intermediatePEM := s2.caManager.getLeafSigningCertFromRoot(activeRoot) intermediatePEM := s2.caManager.getLeafSigningCertFromRoot(activeRoot)
intermediateCert, err := connect.ParseCert(intermediatePEM) intermediateCert, err := connect.ParseCert(intermediatePEM)
require.NoError(err) require.NoError(t, err)
require.Equal(intermediatePEM, s2.caManager.getLeafSigningCertFromRoot(activeRoot)) require.Equal(t, intermediatePEM, s2.caManager.getLeafSigningCertFromRoot(activeRoot))
require.Equal(connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID) require.Equal(t, connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
// Wait for dc2's intermediate to be refreshed. // Wait for dc2's intermediate to be refreshed.
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
@ -497,13 +495,13 @@ func TestCAManager_RenewIntermediate_Secondary(t *testing.T) {
codec := rpcClient(t, s2) codec := rpcClient(t, s2)
roots := structs.IndexedCARoots{} roots := structs.IndexedCARoots{}
err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", &structs.DCSpecificRequest{}, &roots) err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", &structs.DCSpecificRequest{}, &roots)
require.NoError(err) require.NoError(t, err)
require.Len(roots.Roots, 1) require.Len(t, roots.Roots, 1)
_, activeRoot, err = store.CARootActive(nil) _, activeRoot, err = store.CARootActive(nil)
require.NoError(err) require.NoError(t, err)
require.Equal(connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID) require.Equal(t, connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(intermediatePEM, s2.caManager.getLeafSigningCertFromRoot(activeRoot)) require.Equal(t, intermediatePEM, s2.caManager.getLeafSigningCertFromRoot(activeRoot))
// Have dc2 sign a leaf cert and make sure the chain is correct. // Have dc2 sign a leaf cert and make sure the chain is correct.
spiffeService := &connect.SpiffeIDService{ spiffeService := &connect.SpiffeIDService{
@ -517,7 +515,7 @@ func TestCAManager_RenewIntermediate_Secondary(t *testing.T) {
req := structs.CASignRequest{CSR: csr} req := structs.CASignRequest{CSR: csr}
cert := structs.IssuedCert{} cert := structs.IssuedCert{}
err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", &req, &cert) err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", &req, &cert)
require.NoError(err) require.NoError(t, err)
verifyLeafCert(t, activeRoot, cert.CertPEM) verifyLeafCert(t, activeRoot, cert.CertPEM)
} }
@ -528,8 +526,6 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.Build = "1.6.0" c.Build = "1.6.0"
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
@ -555,15 +551,15 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
// Get the original intermediate // Get the original intermediate
secondaryProvider, _ := getCAProviderWithLock(s2) secondaryProvider, _ := getCAProviderWithLock(s2)
oldIntermediatePEM, err := secondaryProvider.ActiveIntermediate() oldIntermediatePEM, err := secondaryProvider.ActiveIntermediate()
require.NoError(err) require.NoError(t, err)
require.NotEmpty(oldIntermediatePEM) require.NotEmpty(t, oldIntermediatePEM)
// Capture the current root // Capture the current root
var originalRoot *structs.CARoot var originalRoot *structs.CARoot
{ {
rootList, activeRoot, err := getTestRoots(s1, "dc1") rootList, activeRoot, err := getTestRoots(s1, "dc1")
require.NoError(err) require.NoError(t, err)
require.Len(rootList.Roots, 1) require.Len(t, rootList.Roots, 1)
originalRoot = activeRoot originalRoot = activeRoot
} }
@ -574,7 +570,7 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
// Update the provider config to use a new private key, which should // Update the provider config to use a new private key, which should
// cause a rotation. // cause a rotation.
_, newKey, err := connect.GeneratePrivateKey() _, newKey, err := connect.GeneratePrivateKey()
require.NoError(err) require.NoError(t, err)
newConfig := &structs.CAConfiguration{ newConfig := &structs.CAConfiguration{
Provider: "consul", Provider: "consul",
Config: map[string]interface{}{ Config: map[string]interface{}{
@ -590,14 +586,14 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
} }
var reply interface{} var reply interface{}
require.NoError(s1.RPC("ConnectCA.ConfigurationSet", args, &reply)) require.NoError(t, s1.RPC("ConnectCA.ConfigurationSet", args, &reply))
} }
var updatedRoot *structs.CARoot var updatedRoot *structs.CARoot
{ {
rootList, activeRoot, err := getTestRoots(s1, "dc1") rootList, activeRoot, err := getTestRoots(s1, "dc1")
require.NoError(err) require.NoError(t, err)
require.Len(rootList.Roots, 2) require.Len(t, rootList.Roots, 2)
updatedRoot = activeRoot updatedRoot = activeRoot
} }
@ -613,17 +609,17 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
r.Fatal("not a new intermediate") r.Fatal("not a new intermediate")
} }
}) })
require.NoError(err) require.NoError(t, err)
// Verify the root lists have been rotated in each DC's state store. // Verify the root lists have been rotated in each DC's state store.
state1 := s1.fsm.State() state1 := s1.fsm.State()
_, primaryRoot, err := state1.CARootActive(nil) _, primaryRoot, err := state1.CARootActive(nil)
require.NoError(err) require.NoError(t, err)
state2 := s2.fsm.State() state2 := s2.fsm.State()
_, roots2, err := state2.CARoots(nil) _, roots2, err := state2.CARoots(nil)
require.NoError(err) require.NoError(t, err)
require.Equal(2, len(roots2)) require.Equal(t, 2, len(roots2))
newRoot := roots2[0] newRoot := roots2[0]
oldRoot := roots2[1] oldRoot := roots2[1]
@ -631,10 +627,10 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
newRoot = roots2[1] newRoot = roots2[1]
oldRoot = roots2[0] oldRoot = roots2[0]
} }
require.False(oldRoot.Active) require.False(t, oldRoot.Active)
require.True(newRoot.Active) require.True(t, newRoot.Active)
require.Equal(primaryRoot.ID, newRoot.ID) require.Equal(t, primaryRoot.ID, newRoot.ID)
require.Equal(primaryRoot.RootCert, newRoot.RootCert) require.Equal(t, primaryRoot.RootCert, newRoot.RootCert)
// Get the new root from dc1 and validate a chain of: // Get the new root from dc1 and validate a chain of:
// dc2 leaf -> dc2 intermediate -> dc1 root // dc2 leaf -> dc2 intermediate -> dc1 root
@ -650,13 +646,13 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService) raw, _ := connect.TestCSR(t, spiffeService)
leafCsr, err := connect.ParseCSR(raw) leafCsr, err := connect.ParseCSR(raw)
require.NoError(err) require.NoError(t, err)
leafPEM, err := secondaryProvider.Sign(leafCsr) leafPEM, err := secondaryProvider.Sign(leafCsr)
require.NoError(err) require.NoError(t, err)
cert, err := connect.ParseCert(leafPEM) cert, err := connect.ParseCert(leafPEM)
require.NoError(err) require.NoError(t, err)
// Check that the leaf signed by the new intermediate can be verified using the // Check that the leaf signed by the new intermediate can be verified using the
// returned cert chain (signed intermediate + remote root). // returned cert chain (signed intermediate + remote root).
@ -669,7 +665,7 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
Intermediates: intermediatePool, Intermediates: intermediatePool,
Roots: rootPool, Roots: rootPool,
}) })
require.NoError(err) require.NoError(t, err)
} }
func TestCAManager_Initialize_Vault_FixesSigningKeyID_Primary(t *testing.T) { func TestCAManager_Initialize_Vault_FixesSigningKeyID_Primary(t *testing.T) {
@ -1113,7 +1109,6 @@ func TestLeader_CARootPruning(t *testing.T) {
caRootPruneInterval = origPruneInterval caRootPruneInterval = origPruneInterval
}) })
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -1127,14 +1122,14 @@ func TestLeader_CARootPruning(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var rootList structs.IndexedCARoots var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1) require.Len(t, rootList.Roots, 1)
oldRoot := rootList.Roots[0] oldRoot := rootList.Roots[0]
// Update the provider config to use a new private key, which should // Update the provider config to use a new private key, which should
// cause a rotation. // cause a rotation.
_, newKey, err := connect.GeneratePrivateKey() _, newKey, err := connect.GeneratePrivateKey()
require.NoError(err) require.NoError(t, err)
newConfig := &structs.CAConfiguration{ newConfig := &structs.CAConfiguration{
Provider: "consul", Provider: "consul",
Config: map[string]interface{}{ Config: map[string]interface{}{
@ -1151,22 +1146,22 @@ func TestLeader_CARootPruning(t *testing.T) {
} }
var reply interface{} var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
} }
// Should have 2 roots now. // Should have 2 roots now.
_, roots, err := s1.fsm.State().CARoots(nil) _, roots, err := s1.fsm.State().CARoots(nil)
require.NoError(err) require.NoError(t, err)
require.Len(roots, 2) require.Len(t, roots, 2)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
// Now the old root should be pruned. // Now the old root should be pruned.
_, roots, err = s1.fsm.State().CARoots(nil) _, roots, err = s1.fsm.State().CARoots(nil)
require.NoError(err) require.NoError(t, err)
require.Len(roots, 1) require.Len(t, roots, 1)
require.True(roots[0].Active) require.True(t, roots[0].Active)
require.NotEqual(roots[0].ID, oldRoot.ID) require.NotEqual(t, roots[0].ID, oldRoot.ID)
} }
func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) { func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) {
@ -1176,7 +1171,6 @@ func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -1201,13 +1195,13 @@ func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var rootList structs.IndexedCARoots var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1) require.Len(t, rootList.Roots, 1)
// Update the provider config to use a new private key, which should // Update the provider config to use a new private key, which should
// cause a rotation. // cause a rotation.
_, newKey, err := connect.GeneratePrivateKey() _, newKey, err := connect.GeneratePrivateKey()
require.NoError(err) require.NoError(t, err)
newConfig := &structs.CAConfiguration{ newConfig := &structs.CAConfiguration{
Provider: "consul", Provider: "consul",
Config: map[string]interface{}{ Config: map[string]interface{}{
@ -1222,12 +1216,12 @@ func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) {
} }
var reply interface{} var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
} }
// Get the active root before leader change. // Get the active root before leader change.
_, root := getCAProviderWithLock(s1) _, root := getCAProviderWithLock(s1)
require.Len(root.IntermediateCerts, 1) require.Len(t, root.IntermediateCerts, 1)
// Force a leader change and make sure the root CA values are preserved. // Force a leader change and make sure the root CA values are preserved.
s1.Leave() s1.Leave()
@ -1310,17 +1304,16 @@ func TestParseCARoot(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
root, err := parseCARoot(tt.pem, "consul", "cluster") root, err := parseCARoot(tt.pem, "consul", "cluster")
if tt.wantErr { if tt.wantErr {
require.Error(err) require.Error(t, err)
return return
} }
require.NoError(err) require.NoError(t, err)
require.Equal(tt.wantSerial, root.SerialNumber) require.Equal(t, tt.wantSerial, root.SerialNumber)
require.Equal(strings.ToLower(tt.wantSigningKeyID), root.SigningKeyID) require.Equal(t, strings.ToLower(tt.wantSigningKeyID), root.SigningKeyID)
require.Equal(tt.wantKeyType, root.PrivateKeyType) require.Equal(t, tt.wantKeyType, root.PrivateKeyType)
require.Equal(tt.wantKeyBits, root.PrivateKeyBits) require.Equal(t, tt.wantKeyBits, root.PrivateKeyBits)
}) })
} }
} }
@ -1491,7 +1484,6 @@ func TestCAManager_Initialize_BadCAConfigDoesNotPreventLeaderEstablishment(t *te
} }
func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) { func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) {
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -1505,14 +1497,14 @@ func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var rootList structs.IndexedCARoots var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1) require.Len(t, rootList.Roots, 1)
oldRoot := rootList.Roots[0] oldRoot := rootList.Roots[0]
// Update the provider config to use a new private key, which should // Update the provider config to use a new private key, which should
// cause a rotation. // cause a rotation.
_, newKey, err := connect.GeneratePrivateKey() _, newKey, err := connect.GeneratePrivateKey()
require.NoError(err) require.NoError(t, err)
newConfig := &structs.CAConfiguration{ newConfig := &structs.CAConfiguration{
Provider: "consul", Provider: "consul",
Config: map[string]interface{}{ Config: map[string]interface{}{
@ -1530,18 +1522,18 @@ func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) {
} }
var reply interface{} var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
} }
// Old root should no longer be active. // Old root should no longer be active.
_, roots, err := s1.fsm.State().CARoots(nil) _, roots, err := s1.fsm.State().CARoots(nil)
require.NoError(err) require.NoError(t, err)
require.Len(roots, 2) require.Len(t, roots, 2)
for _, r := range roots { for _, r := range roots {
if r.ID == oldRoot.ID { if r.ID == oldRoot.ID {
require.False(r.Active) require.False(t, r.Active)
} else { } else {
require.True(r.Active) require.True(t, r.Active)
} }
} }
} }
@ -1549,7 +1541,6 @@ func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) {
func TestConnectCA_ConfigurationSet_Vault_ForceWithoutCrossSigning(t *testing.T) { func TestConnectCA_ConfigurationSet_Vault_ForceWithoutCrossSigning(t *testing.T) {
ca.SkipIfVaultNotPresent(t) ca.SkipIfVaultNotPresent(t)
require := require.New(t)
testVault := ca.NewTestVaultServer(t) testVault := ca.NewTestVaultServer(t)
defer testVault.Stop() defer testVault.Stop()
@ -1577,8 +1568,8 @@ func TestConnectCA_ConfigurationSet_Vault_ForceWithoutCrossSigning(t *testing.T)
Datacenter: "dc1", Datacenter: "dc1",
} }
var rootList structs.IndexedCARoots var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList)) require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1) require.Len(t, rootList.Roots, 1)
oldRoot := rootList.Roots[0] oldRoot := rootList.Roots[0]
// Update the provider config to use a new PKI path, which should // Update the provider config to use a new PKI path, which should
@ -1600,18 +1591,18 @@ func TestConnectCA_ConfigurationSet_Vault_ForceWithoutCrossSigning(t *testing.T)
} }
var reply interface{} var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
} }
// Old root should no longer be active. // Old root should no longer be active.
_, roots, err := s1.fsm.State().CARoots(nil) _, roots, err := s1.fsm.State().CARoots(nil)
require.NoError(err) require.NoError(t, err)
require.Len(roots, 2) require.Len(t, roots, 2)
for _, r := range roots { for _, r := range roots {
if r.ID == oldRoot.ID { if r.ID == oldRoot.ID {
require.False(r.Active) require.False(t, r.Active)
} else { } else {
require.True(r.Active) require.True(t, r.Active)
} }
} }
} }

View File

@ -217,7 +217,6 @@ func TestLeader_ReplicateIntentions(t *testing.T) {
func TestLeader_batchLegacyIntentionUpdates(t *testing.T) { func TestLeader_batchLegacyIntentionUpdates(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
ixn1 := structs.TestIntention(t) ixn1 := structs.TestIntention(t)
ixn1.ID = "ixn1" ixn1.ID = "ixn1"
ixn2 := structs.TestIntention(t) ixn2 := structs.TestIntention(t)
@ -356,7 +355,7 @@ func TestLeader_batchLegacyIntentionUpdates(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
actual := batchLegacyIntentionUpdates(tc.deletes, tc.updates) actual := batchLegacyIntentionUpdates(tc.deletes, tc.updates)
assert.Equal(tc.expected, actual) assert.Equal(t, tc.expected, actual)
} }
} }

View File

@ -10,15 +10,14 @@ import (
func TestLoggerStore_Named(t *testing.T) { func TestLoggerStore_Named(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
store := newLoggerStore(logger) store := newLoggerStore(logger)
require.NotNil(store) require.NotNil(t, store)
l1 := store.Named("test1") l1 := store.Named("test1")
l2 := store.Named("test2") l2 := store.Named("test2")
require.Truef(l1 != l2, require.Truef(t, l1 != l2,
"expected %p and %p to have a different memory address", "expected %p and %p to have a different memory address",
l1, l1,
l2, l2,
@ -27,15 +26,14 @@ func TestLoggerStore_Named(t *testing.T) {
func TestLoggerStore_NamedCache(t *testing.T) { func TestLoggerStore_NamedCache(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
store := newLoggerStore(logger) store := newLoggerStore(logger)
require.NotNil(store) require.NotNil(t, store)
l1 := store.Named("test") l1 := store.Named("test")
l2 := store.Named("test") l2 := store.Named("test")
require.Truef(l1 == l2, require.Truef(t, l1 == l2,
"expected %p and %p to have the same memory address", "expected %p and %p to have the same memory address",
l1, l1,
l2, l2,

View File

@ -2448,7 +2448,6 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -2484,7 +2483,7 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
} }
var reply struct{} var reply struct{}
require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply))
} }
// The query, start with connect disabled // The query, start with connect disabled
@ -2501,7 +2500,7 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
}, },
}, },
} }
require.NoError(msgpackrpc.CallWithCodec( require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID)) codec, "PreparedQuery.Apply", &query, &query.Query.ID))
// In the future we'll run updates // In the future we'll run updates
@ -2515,15 +2514,15 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
} }
var reply structs.PreparedQueryExecuteResponse var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec( require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply)) codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because it omits the proxy whose name // Result should have two because it omits the proxy whose name
// doesn't match the query. // doesn't match the query.
require.Len(reply.Nodes, 2) require.Len(t, reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service) require.Equal(t, query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS) require.Equal(t, query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader") require.True(t, reply.QueryMeta.KnownLeader, "queried leader")
} }
// Run with the Connect setting specified on the request // Run with the Connect setting specified on the request
@ -2535,31 +2534,31 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
} }
var reply structs.PreparedQueryExecuteResponse var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec( require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply)) codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because we should get the native AND // Result should have two because we should get the native AND
// the proxy (since the destination matches our service name). // the proxy (since the destination matches our service name).
require.Len(reply.Nodes, 2) require.Len(t, reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service) require.Equal(t, query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS) require.Equal(t, query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader") require.True(t, reply.QueryMeta.KnownLeader, "queried leader")
// Make sure the native is the first one // Make sure the native is the first one
if !reply.Nodes[0].Service.Connect.Native { if !reply.Nodes[0].Service.Connect.Native {
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0] reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
} }
require.True(reply.Nodes[0].Service.Connect.Native, "native") require.True(t, reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(reply.Service, reply.Nodes[0].Service.Service) require.Equal(t, reply.Service, reply.Nodes[0].Service.Service)
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind) require.Equal(t, structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(reply.Service, reply.Nodes[1].Service.Proxy.DestinationServiceName) require.Equal(t, reply.Service, reply.Nodes[1].Service.Proxy.DestinationServiceName)
} }
// Update the query // Update the query
query.Query.Service.Connect = true query.Query.Service.Connect = true
require.NoError(msgpackrpc.CallWithCodec( require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID)) codec, "PreparedQuery.Apply", &query, &query.Query.ID))
// Run the registered query. // Run the registered query.
@ -2570,31 +2569,31 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
} }
var reply structs.PreparedQueryExecuteResponse var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec( require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply)) codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because we should get the native AND // Result should have two because we should get the native AND
// the proxy (since the destination matches our service name). // the proxy (since the destination matches our service name).
require.Len(reply.Nodes, 2) require.Len(t, reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service) require.Equal(t, query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS) require.Equal(t, query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader") require.True(t, reply.QueryMeta.KnownLeader, "queried leader")
// Make sure the native is the first one // Make sure the native is the first one
if !reply.Nodes[0].Service.Connect.Native { if !reply.Nodes[0].Service.Connect.Native {
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0] reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
} }
require.True(reply.Nodes[0].Service.Connect.Native, "native") require.True(t, reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(reply.Service, reply.Nodes[0].Service.Service) require.Equal(t, reply.Service, reply.Nodes[0].Service.Service)
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind) require.Equal(t, structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(reply.Service, reply.Nodes[1].Service.Proxy.DestinationServiceName) require.Equal(t, reply.Service, reply.Nodes[1].Service.Proxy.DestinationServiceName)
} }
// Unset the query // Unset the query
query.Query.Service.Connect = false query.Query.Service.Connect = false
require.NoError(msgpackrpc.CallWithCodec( require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID)) codec, "PreparedQuery.Apply", &query, &query.Query.ID))
} }

View File

@ -233,9 +233,6 @@ func TestRPC_blockingQuery(t *testing.T) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
defer s.Shutdown() defer s.Shutdown()
require := require.New(t)
assert := assert.New(t)
// Perform a non-blocking query. Note that it's significant that the meta has // Perform a non-blocking query. Note that it's significant that the meta has
// a zero index in response - the implied opts.MinQueryIndex is also zero but // a zero index in response - the implied opts.MinQueryIndex is also zero but
// this should not block still. // this should not block still.
@ -311,9 +308,9 @@ func TestRPC_blockingQuery(t *testing.T) {
calls++ calls++
return nil return nil
} }
require.NoError(s.blockingQuery(&opts, &meta, fn)) require.NoError(t, s.blockingQuery(&opts, &meta, fn))
assert.Equal(1, calls) assert.Equal(t, 1, calls)
assert.Equal(uint64(1), meta.Index, assert.Equal(t, uint64(1), meta.Index,
"expect fake index of 1 to force client to block on next update") "expect fake index of 1 to force client to block on next update")
// Simulate client making next request // Simulate client making next request
@ -322,12 +319,12 @@ func TestRPC_blockingQuery(t *testing.T) {
// This time we should block even though the func returns index 0 still // This time we should block even though the func returns index 0 still
t0 := time.Now() t0 := time.Now()
require.NoError(s.blockingQuery(&opts, &meta, fn)) require.NoError(t, s.blockingQuery(&opts, &meta, fn))
t1 := time.Now() t1 := time.Now()
assert.Equal(2, calls) assert.Equal(t, 2, calls)
assert.Equal(uint64(1), meta.Index, assert.Equal(t, uint64(1), meta.Index,
"expect fake index of 1 to force client to block on next update") "expect fake index of 1 to force client to block on next update")
assert.True(t1.Sub(t0) > 20*time.Millisecond, assert.True(t, t1.Sub(t0) > 20*time.Millisecond,
"should have actually blocked waiting for timeout") "should have actually blocked waiting for timeout")
} }
@ -382,13 +379,13 @@ func TestRPC_blockingQuery(t *testing.T) {
} }
err := s.blockingQuery(&opts, &meta, fn) err := s.blockingQuery(&opts, &meta, fn)
require.NoError(err) require.NoError(t, err)
require.False(meta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be reset for unauthenticated calls") require.False(t, meta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be reset for unauthenticated calls")
}) })
t.Run("ResultsFilteredByACLs is honored for authenticated calls", func(t *testing.T) { t.Run("ResultsFilteredByACLs is honored for authenticated calls", func(t *testing.T) {
token, err := lib.GenerateUUID(nil) token, err := lib.GenerateUUID(nil)
require.NoError(err) require.NoError(t, err)
opts := structs.QueryOptions{ opts := structs.QueryOptions{
Token: token, Token: token,
@ -400,8 +397,8 @@ func TestRPC_blockingQuery(t *testing.T) {
} }
err = s.blockingQuery(&opts, &meta, fn) err = s.blockingQuery(&opts, &meta, fn)
require.NoError(err) require.NoError(t, err)
require.True(meta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be honored for authenticated calls") require.True(t, meta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be honored for authenticated calls")
}) })
} }

View File

@ -420,7 +420,6 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Run("Get", func(t *testing.T) { t.Run("Get", func(t *testing.T) {
require := require.New(t)
req := &structs.SessionSpecificRequest{ req := &structs.SessionSpecificRequest{
Datacenter: "dc1", Datacenter: "dc1",
@ -432,30 +431,29 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) {
var sessions structs.IndexedSessions var sessions structs.IndexedSessions
err := msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions) err := msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions)
require.NoError(err) require.NoError(t, err)
require.Empty(sessions.Sessions) require.Empty(t, sessions.Sessions)
require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// ACL-restricted results included. // ACL-restricted results included.
req.Token = allowedToken req.Token = allowedToken
err = msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions) err = msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions)
require.NoError(err) require.NoError(t, err)
require.Len(sessions.Sessions, 1) require.Len(t, sessions.Sessions, 1)
require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") require.False(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
// Try to get a session that doesn't exist to make sure that's handled // Try to get a session that doesn't exist to make sure that's handled
// correctly by the filter (it will get passed a nil slice). // correctly by the filter (it will get passed a nil slice).
req.SessionID = "adf4238a-882b-9ddc-4a9d-5b6758e4159e" req.SessionID = "adf4238a-882b-9ddc-4a9d-5b6758e4159e"
err = msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions) err = msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions)
require.NoError(err) require.NoError(t, err)
require.Empty(sessions.Sessions) require.Empty(t, sessions.Sessions)
require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") require.False(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
}) })
t.Run("List", func(t *testing.T) { t.Run("List", func(t *testing.T) {
require := require.New(t)
req := &structs.DCSpecificRequest{ req := &structs.DCSpecificRequest{
Datacenter: "dc1", Datacenter: "dc1",
@ -466,21 +464,20 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) {
var sessions structs.IndexedSessions var sessions structs.IndexedSessions
err := msgpackrpc.CallWithCodec(codec, "Session.List", req, &sessions) err := msgpackrpc.CallWithCodec(codec, "Session.List", req, &sessions)
require.NoError(err) require.NoError(t, err)
require.Empty(sessions.Sessions) require.Empty(t, sessions.Sessions)
require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// ACL-restricted results included. // ACL-restricted results included.
req.Token = allowedToken req.Token = allowedToken
err = msgpackrpc.CallWithCodec(codec, "Session.List", req, &sessions) err = msgpackrpc.CallWithCodec(codec, "Session.List", req, &sessions)
require.NoError(err) require.NoError(t, err)
require.Len(sessions.Sessions, 1) require.Len(t, sessions.Sessions, 1)
require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") require.False(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
}) })
t.Run("NodeSessions", func(t *testing.T) { t.Run("NodeSessions", func(t *testing.T) {
require := require.New(t)
req := &structs.NodeSpecificRequest{ req := &structs.NodeSpecificRequest{
Datacenter: "dc1", Datacenter: "dc1",
@ -492,17 +489,17 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) {
var sessions structs.IndexedSessions var sessions structs.IndexedSessions
err := msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", req, &sessions) err := msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", req, &sessions)
require.NoError(err) require.NoError(t, err)
require.Empty(sessions.Sessions) require.Empty(t, sessions.Sessions)
require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// ACL-restricted results included. // ACL-restricted results included.
req.Token = allowedToken req.Token = allowedToken
err = msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", req, &sessions) err = msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", req, &sessions)
require.NoError(err) require.NoError(t, err)
require.Len(sessions.Sessions, 1) require.Len(t, sessions.Sessions, 1)
require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false") require.False(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
}) })
} }

View File

@ -1515,7 +1515,6 @@ func TestStateStore_EnsureService(t *testing.T) {
} }
func TestStateStore_EnsureService_connectProxy(t *testing.T) { func TestStateStore_EnsureService_connectProxy(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Create the service registration. // Create the service registration.
@ -1535,21 +1534,20 @@ func TestStateStore_EnsureService_connectProxy(t *testing.T) {
// Service successfully registers into the state store. // Service successfully registers into the state store.
testRegisterNode(t, s, 0, "node1") testRegisterNode(t, s, 0, "node1")
assert.Nil(s.EnsureService(10, "node1", ns1)) assert.Nil(t, s.EnsureService(10, "node1", ns1))
// Retrieve and verify // Retrieve and verify
_, out, err := s.NodeServices(nil, "node1", nil) _, out, err := s.NodeServices(nil, "node1", nil)
assert.Nil(err) assert.Nil(t, err)
assert.NotNil(out) assert.NotNil(t, out)
assert.Len(out.Services, 1) assert.Len(t, out.Services, 1)
expect1 := *ns1 expect1 := *ns1
expect1.CreateIndex, expect1.ModifyIndex = 10, 10 expect1.CreateIndex, expect1.ModifyIndex = 10, 10
assert.Equal(&expect1, out.Services["connect-proxy"]) assert.Equal(t, &expect1, out.Services["connect-proxy"])
} }
func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) { func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
setVirtualIPFlags(t, s) setVirtualIPFlags(t, s)
@ -1575,17 +1573,17 @@ func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
// Make sure there's a virtual IP for the foo service. // Make sure there's a virtual IP for the foo service.
vip, err := s.VirtualIPForService(structs.ServiceName{Name: "foo"}) vip, err := s.VirtualIPForService(structs.ServiceName{Name: "foo"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("240.0.0.1", vip) assert.Equal(t, "240.0.0.1", vip)
// Retrieve and verify // Retrieve and verify
_, out, err := s.NodeServices(nil, "node1", nil) _, out, err := s.NodeServices(nil, "node1", nil)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(out) assert.NotNil(t, out)
assert.Len(out.Services, 1) assert.Len(t, out.Services, 1)
taggedAddress := out.Services["foo"].TaggedAddresses[structs.TaggedAddressVirtualIP] taggedAddress := out.Services["foo"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address) assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(ns1.Port, taggedAddress.Port) assert.Equal(t, ns1.Port, taggedAddress.Port)
// Create the service registration. // Create the service registration.
ns2 := &structs.NodeService{ ns2 := &structs.NodeService{
@ -1606,23 +1604,23 @@ func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
// Make sure the virtual IP has been incremented for the redis service. // Make sure the virtual IP has been incremented for the redis service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"}) vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("240.0.0.2", vip) assert.Equal(t, "240.0.0.2", vip)
// Retrieve and verify // Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil) _, out, err = s.NodeServices(nil, "node1", nil)
assert.Nil(err) assert.Nil(t, err)
assert.NotNil(out) assert.NotNil(t, out)
assert.Len(out.Services, 2) assert.Len(t, out.Services, 2)
taggedAddress = out.Services["redis-proxy"].TaggedAddresses[structs.TaggedAddressVirtualIP] taggedAddress = out.Services["redis-proxy"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address) assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(ns2.Port, taggedAddress.Port) assert.Equal(t, ns2.Port, taggedAddress.Port)
// Delete the first service and make sure it no longer has a virtual IP assigned. // Delete the first service and make sure it no longer has a virtual IP assigned.
require.NoError(t, s.DeleteService(12, "node1", "foo", entMeta)) require.NoError(t, s.DeleteService(12, "node1", "foo", entMeta))
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "connect-proxy"}) vip, err = s.VirtualIPForService(structs.ServiceName{Name: "connect-proxy"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("", vip) assert.Equal(t, "", vip)
// Register another instance of redis-proxy and make sure the virtual IP is unchanged. // Register another instance of redis-proxy and make sure the virtual IP is unchanged.
ns3 := &structs.NodeService{ ns3 := &structs.NodeService{
@ -1643,14 +1641,14 @@ func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
// Make sure the virtual IP is unchanged for the redis service. // Make sure the virtual IP is unchanged for the redis service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"}) vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("240.0.0.2", vip) assert.Equal(t, "240.0.0.2", vip)
// Make sure the new instance has the same virtual IP. // Make sure the new instance has the same virtual IP.
_, out, err = s.NodeServices(nil, "node1", nil) _, out, err = s.NodeServices(nil, "node1", nil)
require.NoError(t, err) require.NoError(t, err)
taggedAddress = out.Services["redis-proxy2"].TaggedAddresses[structs.TaggedAddressVirtualIP] taggedAddress = out.Services["redis-proxy2"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address) assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(ns3.Port, taggedAddress.Port) assert.Equal(t, ns3.Port, taggedAddress.Port)
// Register another service to take its virtual IP. // Register another service to take its virtual IP.
ns4 := &structs.NodeService{ ns4 := &structs.NodeService{
@ -1671,18 +1669,17 @@ func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
// Make sure the virtual IP has allocated from the previously freed service. // Make sure the virtual IP has allocated from the previously freed service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "web"}) vip, err = s.VirtualIPForService(structs.ServiceName{Name: "web"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("240.0.0.1", vip) assert.Equal(t, "240.0.0.1", vip)
// Retrieve and verify // Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil) _, out, err = s.NodeServices(nil, "node1", nil)
require.NoError(t, err) require.NoError(t, err)
taggedAddress = out.Services["web-proxy"].TaggedAddresses[structs.TaggedAddressVirtualIP] taggedAddress = out.Services["web-proxy"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address) assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(ns4.Port, taggedAddress.Port) assert.Equal(t, ns4.Port, taggedAddress.Port)
} }
func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) { func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
setVirtualIPFlags(t, s) setVirtualIPFlags(t, s)
@ -1708,16 +1705,16 @@ func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
// Make sure there's a virtual IP for the foo service. // Make sure there's a virtual IP for the foo service.
vip, err := s.VirtualIPForService(structs.ServiceName{Name: "foo"}) vip, err := s.VirtualIPForService(structs.ServiceName{Name: "foo"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("240.0.0.1", vip) assert.Equal(t, "240.0.0.1", vip)
// Retrieve and verify // Retrieve and verify
_, out, err := s.NodeServices(nil, "node1", nil) _, out, err := s.NodeServices(nil, "node1", nil)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(out) assert.NotNil(t, out)
taggedAddress := out.Services["foo"].TaggedAddresses[structs.TaggedAddressVirtualIP] taggedAddress := out.Services["foo"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address) assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(ns1.Port, taggedAddress.Port) assert.Equal(t, ns1.Port, taggedAddress.Port)
// Create the service registration. // Create the service registration.
ns2 := &structs.NodeService{ ns2 := &structs.NodeService{
@ -1738,22 +1735,22 @@ func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
// Make sure the virtual IP has been incremented for the redis service. // Make sure the virtual IP has been incremented for the redis service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"}) vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("240.0.0.2", vip) assert.Equal(t, "240.0.0.2", vip)
// Retrieve and verify // Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil) _, out, err = s.NodeServices(nil, "node1", nil)
assert.Nil(err) assert.Nil(t, err)
assert.NotNil(out) assert.NotNil(t, out)
taggedAddress = out.Services["redis"].TaggedAddresses[structs.TaggedAddressVirtualIP] taggedAddress = out.Services["redis"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address) assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(ns2.Port, taggedAddress.Port) assert.Equal(t, ns2.Port, taggedAddress.Port)
// Delete the last service and make sure it no longer has a virtual IP assigned. // Delete the last service and make sure it no longer has a virtual IP assigned.
require.NoError(t, s.DeleteService(12, "node1", "redis", entMeta)) require.NoError(t, s.DeleteService(12, "node1", "redis", entMeta))
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"}) vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("", vip) assert.Equal(t, "", vip)
// Register a new service, should end up with the freed 240.0.0.2 address. // Register a new service, should end up with the freed 240.0.0.2 address.
ns3 := &structs.NodeService{ ns3 := &structs.NodeService{
@ -1773,16 +1770,16 @@ func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "backend"}) vip, err = s.VirtualIPForService(structs.ServiceName{Name: "backend"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("240.0.0.2", vip) assert.Equal(t, "240.0.0.2", vip)
// Retrieve and verify // Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil) _, out, err = s.NodeServices(nil, "node1", nil)
assert.Nil(err) assert.Nil(t, err)
assert.NotNil(out) assert.NotNil(t, out)
taggedAddress = out.Services["backend"].TaggedAddresses[structs.TaggedAddressVirtualIP] taggedAddress = out.Services["backend"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address) assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(ns3.Port, taggedAddress.Port) assert.Equal(t, ns3.Port, taggedAddress.Port)
// Create a new service, no more freed VIPs so it should go back to using the counter. // Create a new service, no more freed VIPs so it should go back to using the counter.
ns4 := &structs.NodeService{ ns4 := &structs.NodeService{
@ -1803,16 +1800,16 @@ func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
// Make sure the virtual IP has been incremented for the frontend service. // Make sure the virtual IP has been incremented for the frontend service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "frontend"}) vip, err = s.VirtualIPForService(structs.ServiceName{Name: "frontend"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal("240.0.0.3", vip) assert.Equal(t, "240.0.0.3", vip)
// Retrieve and verify // Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil) _, out, err = s.NodeServices(nil, "node1", nil)
assert.Nil(err) assert.Nil(t, err)
assert.NotNil(out) assert.NotNil(t, out)
taggedAddress = out.Services["frontend"].TaggedAddresses[structs.TaggedAddressVirtualIP] taggedAddress = out.Services["frontend"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address) assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(ns4.Port, taggedAddress.Port) assert.Equal(t, ns4.Port, taggedAddress.Port)
} }
func TestStateStore_Services(t *testing.T) { func TestStateStore_Services(t *testing.T) {
@ -2360,82 +2357,80 @@ func TestStateStore_DeleteService(t *testing.T) {
} }
func TestStateStore_ConnectServiceNodes(t *testing.T) { func TestStateStore_ConnectServiceNodes(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Listing with no results returns an empty list. // Listing with no results returns an empty list.
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, nodes, err := s.ConnectServiceNodes(ws, "db", nil) idx, nodes, err := s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(0)) assert.Equal(t, idx, uint64(0))
assert.Len(nodes, 0) assert.Len(t, nodes, 0)
// Create some nodes and services. // Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"})) assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"})) assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000})) assert.Nil(t, s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000})) assert.Nil(t, s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000})) assert.Nil(t, s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000})) assert.Nil(t, s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "native-db", Service: "db", Connect: structs.ServiceConnect{Native: true}})) assert.Nil(t, s.EnsureService(16, "bar", &structs.NodeService{ID: "native-db", Service: "db", Connect: structs.ServiceConnect{Native: true}}))
assert.Nil(s.EnsureService(17, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001})) assert.Nil(t, s.EnsureService(17, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
// Read everything back. // Read everything back.
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(17)) assert.Equal(t, idx, uint64(17))
assert.Len(nodes, 3) assert.Len(t, nodes, 3)
for _, n := range nodes { for _, n := range nodes {
assert.True(n.ServiceKind == structs.ServiceKindConnectProxy || assert.True(t, n.ServiceKind == structs.ServiceKindConnectProxy ||
n.ServiceConnect.Native, n.ServiceConnect.Native,
"either proxy or connect native") "either proxy or connect native")
} }
// Registering some unrelated node should not fire the watch. // Registering some unrelated node should not fire the watch.
testRegisterNode(t, s, 17, "nope") testRegisterNode(t, s, 17, "nope")
assert.False(watchFired(ws)) assert.False(t, watchFired(ws))
// But removing a node with the "db" service should fire the watch. // But removing a node with the "db" service should fire the watch.
assert.Nil(s.DeleteNode(18, "bar", nil)) assert.Nil(t, s.DeleteNode(18, "bar", nil))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
} }
func TestStateStore_ConnectServiceNodes_Gateways(t *testing.T) { func TestStateStore_ConnectServiceNodes_Gateways(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Listing with no results returns an empty list. // Listing with no results returns an empty list.
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, nodes, err := s.ConnectServiceNodes(ws, "db", nil) idx, nodes, err := s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(0)) assert.Equal(t, idx, uint64(0))
assert.Len(nodes, 0) assert.Len(t, nodes, 0)
// Create some nodes and services. // Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"})) assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"})) assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
// Typical services // Typical services
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000})) assert.Nil(t, s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000})) assert.Nil(t, s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001})) assert.Nil(t, s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.False(watchFired(ws)) assert.False(t, watchFired(ws))
// Register a sidecar for db // Register a sidecar for db
assert.Nil(s.EnsureService(15, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000})) assert.Nil(t, s.EnsureService(15, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
// Reset WatchSet to ensure watch fires when associating db with gateway // Reset WatchSet to ensure watch fires when associating db with gateway
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
_, _, err = s.ConnectServiceNodes(ws, "db", nil) _, _, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
// Associate gateway with db // Associate gateway with db
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443})) assert.Nil(t, s.EnsureService(16, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443}))
assert.Nil(s.EnsureConfigEntry(17, &structs.TerminatingGatewayConfigEntry{ assert.Nil(t, s.EnsureConfigEntry(17, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway", Kind: "terminating-gateway",
Name: "gateway", Name: "gateway",
Services: []structs.LinkedService{ Services: []structs.LinkedService{
@ -2444,71 +2439,71 @@ func TestStateStore_ConnectServiceNodes_Gateways(t *testing.T) {
}, },
}, },
})) }))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
// Read everything back. // Read everything back.
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(17)) assert.Equal(t, idx, uint64(17))
assert.Len(nodes, 2) assert.Len(t, nodes, 2)
// Check sidecar // Check sidecar
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind) assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal("foo", nodes[0].Node) assert.Equal(t, "foo", nodes[0].Node)
assert.Equal("proxy", nodes[0].ServiceName) assert.Equal(t, "proxy", nodes[0].ServiceName)
assert.Equal("proxy", nodes[0].ServiceID) assert.Equal(t, "proxy", nodes[0].ServiceID)
assert.Equal("db", nodes[0].ServiceProxy.DestinationServiceName) assert.Equal(t, "db", nodes[0].ServiceProxy.DestinationServiceName)
assert.Equal(8000, nodes[0].ServicePort) assert.Equal(t, 8000, nodes[0].ServicePort)
// Check gateway // Check gateway
assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind) assert.Equal(t, structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind)
assert.Equal("bar", nodes[1].Node) assert.Equal(t, "bar", nodes[1].Node)
assert.Equal("gateway", nodes[1].ServiceName) assert.Equal(t, "gateway", nodes[1].ServiceName)
assert.Equal("gateway", nodes[1].ServiceID) assert.Equal(t, "gateway", nodes[1].ServiceID)
assert.Equal(443, nodes[1].ServicePort) assert.Equal(t, 443, nodes[1].ServicePort)
// Watch should fire when another gateway instance is registered // Watch should fire when another gateway instance is registered
assert.Nil(s.EnsureService(18, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443})) assert.Nil(t, s.EnsureService(18, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443}))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
// Reset WatchSet to ensure watch fires when deregistering gateway // Reset WatchSet to ensure watch fires when deregistering gateway
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
_, _, err = s.ConnectServiceNodes(ws, "db", nil) _, _, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
// Watch should fire when a gateway instance is deregistered // Watch should fire when a gateway instance is deregistered
assert.Nil(s.DeleteService(19, "bar", "gateway", nil)) assert.Nil(t, s.DeleteService(19, "bar", "gateway", nil))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(19)) assert.Equal(t, idx, uint64(19))
assert.Len(nodes, 2) assert.Len(t, nodes, 2)
// Check the new gateway // Check the new gateway
assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind) assert.Equal(t, structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind)
assert.Equal("foo", nodes[1].Node) assert.Equal(t, "foo", nodes[1].Node)
assert.Equal("gateway", nodes[1].ServiceName) assert.Equal(t, "gateway", nodes[1].ServiceName)
assert.Equal("gateway-2", nodes[1].ServiceID) assert.Equal(t, "gateway-2", nodes[1].ServiceID)
assert.Equal(443, nodes[1].ServicePort) assert.Equal(t, 443, nodes[1].ServicePort)
// Index should not slide back after deleting all instances of the gateway // Index should not slide back after deleting all instances of the gateway
assert.Nil(s.DeleteService(20, "foo", "gateway-2", nil)) assert.Nil(t, s.DeleteService(20, "foo", "gateway-2", nil))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(20)) assert.Equal(t, idx, uint64(20))
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
// Ensure that remaining node is the proxy and not a gateway // Ensure that remaining node is the proxy and not a gateway
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind) assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal("foo", nodes[0].Node) assert.Equal(t, "foo", nodes[0].Node)
assert.Equal("proxy", nodes[0].ServiceName) assert.Equal(t, "proxy", nodes[0].ServiceName)
assert.Equal("proxy", nodes[0].ServiceID) assert.Equal(t, "proxy", nodes[0].ServiceID)
assert.Equal(8000, nodes[0].ServicePort) assert.Equal(t, 8000, nodes[0].ServicePort)
} }
func TestStateStore_Service_Snapshot(t *testing.T) { func TestStateStore_Service_Snapshot(t *testing.T) {
@ -3679,14 +3674,12 @@ func TestStateStore_ConnectQueryBlocking(t *testing.T) {
tt.setupFn(s) tt.setupFn(s)
} }
require := require.New(t)
// Run the query // Run the query
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
_, res, err := s.CheckConnectServiceNodes(ws, tt.svc, nil) _, res, err := s.CheckConnectServiceNodes(ws, tt.svc, nil)
require.NoError(err) require.NoError(t, err)
require.Len(res, tt.wantBeforeResLen) require.Len(t, res, tt.wantBeforeResLen)
require.Len(ws, tt.wantBeforeWatchSetSize) require.Len(t, ws, tt.wantBeforeWatchSetSize)
// Mutate the state store // Mutate the state store
if tt.updateFn != nil { if tt.updateFn != nil {
@ -3695,18 +3688,18 @@ func TestStateStore_ConnectQueryBlocking(t *testing.T) {
fired := watchFired(ws) fired := watchFired(ws)
if tt.shouldFire { if tt.shouldFire {
require.True(fired, "WatchSet should have fired") require.True(t, fired, "WatchSet should have fired")
} else { } else {
require.False(fired, "WatchSet should not have fired") require.False(t, fired, "WatchSet should not have fired")
} }
// Re-query the same result. Should return the desired index and len // Re-query the same result. Should return the desired index and len
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, res, err := s.CheckConnectServiceNodes(ws, tt.svc, nil) idx, res, err := s.CheckConnectServiceNodes(ws, tt.svc, nil)
require.NoError(err) require.NoError(t, err)
require.Len(res, tt.wantAfterResLen) require.Len(t, res, tt.wantAfterResLen)
require.Equal(tt.wantAfterIndex, idx) require.Equal(t, tt.wantAfterIndex, idx)
require.Len(ws, tt.wantAfterWatchSetSize) require.Len(t, ws, tt.wantAfterWatchSetSize)
}) })
} }
} }
@ -3828,25 +3821,24 @@ func TestStateStore_CheckServiceNodes(t *testing.T) {
} }
func TestStateStore_CheckConnectServiceNodes(t *testing.T) { func TestStateStore_CheckConnectServiceNodes(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Listing with no results returns an empty list. // Listing with no results returns an empty list.
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, nodes, err := s.CheckConnectServiceNodes(ws, "db", nil) idx, nodes, err := s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(0)) assert.Equal(t, idx, uint64(0))
assert.Len(nodes, 0) assert.Len(t, nodes, 0)
// Create some nodes and services. // Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"})) assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"})) assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000})) assert.Nil(t, s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000})) assert.Nil(t, s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000})) assert.Nil(t, s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000})) assert.Nil(t, s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001})) assert.Nil(t, s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
// Register node checks // Register node checks
testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing) testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing)
@ -3859,13 +3851,13 @@ func TestStateStore_CheckConnectServiceNodes(t *testing.T) {
// Read everything back. // Read everything back.
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(20)) assert.Equal(t, idx, uint64(20))
assert.Len(nodes, 2) assert.Len(t, nodes, 2)
for _, n := range nodes { for _, n := range nodes {
assert.Equal(structs.ServiceKindConnectProxy, n.Service.Kind) assert.Equal(t, structs.ServiceKindConnectProxy, n.Service.Kind)
assert.Equal("db", n.Service.Proxy.DestinationServiceName) assert.Equal(t, "db", n.Service.Proxy.DestinationServiceName)
} }
} }
@ -3874,34 +3866,33 @@ func TestStateStore_CheckConnectServiceNodes_Gateways(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Listing with no results returns an empty list. // Listing with no results returns an empty list.
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, nodes, err := s.CheckConnectServiceNodes(ws, "db", nil) idx, nodes, err := s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(0)) assert.Equal(t, idx, uint64(0))
assert.Len(nodes, 0) assert.Len(t, nodes, 0)
// Create some nodes and services. // Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"})) assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"})) assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
// Typical services // Typical services
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000})) assert.Nil(t, s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000})) assert.Nil(t, s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001})) assert.Nil(t, s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.False(watchFired(ws)) assert.False(t, watchFired(ws))
// Register node and service checks // Register node and service checks
testRegisterCheck(t, s, 15, "foo", "", "check1", api.HealthPassing) testRegisterCheck(t, s, 15, "foo", "", "check1", api.HealthPassing)
testRegisterCheck(t, s, 16, "bar", "", "check2", api.HealthPassing) testRegisterCheck(t, s, 16, "bar", "", "check2", api.HealthPassing)
testRegisterCheck(t, s, 17, "foo", "db", "check3", api.HealthPassing) testRegisterCheck(t, s, 17, "foo", "db", "check3", api.HealthPassing)
assert.False(watchFired(ws)) assert.False(t, watchFired(ws))
// Watch should fire when a gateway is associated with the service, even if the gateway doesn't exist yet // Watch should fire when a gateway is associated with the service, even if the gateway doesn't exist yet
assert.Nil(s.EnsureConfigEntry(18, &structs.TerminatingGatewayConfigEntry{ assert.Nil(t, s.EnsureConfigEntry(18, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway", Kind: "terminating-gateway",
Name: "gateway", Name: "gateway",
Services: []structs.LinkedService{ Services: []structs.LinkedService{
@ -3910,90 +3901,90 @@ func TestStateStore_CheckConnectServiceNodes_Gateways(t *testing.T) {
}, },
}, },
})) }))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(18)) assert.Equal(t, idx, uint64(18))
assert.Len(nodes, 0) assert.Len(t, nodes, 0)
// Watch should fire when a gateway is added // Watch should fire when a gateway is added
assert.Nil(s.EnsureService(19, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443})) assert.Nil(t, s.EnsureService(19, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443}))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
// Watch should fire when a check is added to the gateway // Watch should fire when a check is added to the gateway
testRegisterCheck(t, s, 20, "bar", "gateway", "check4", api.HealthPassing) testRegisterCheck(t, s, 20, "bar", "gateway", "check4", api.HealthPassing)
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
// Watch should fire when a different connect service is registered for db // Watch should fire when a different connect service is registered for db
assert.Nil(s.EnsureService(21, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000})) assert.Nil(t, s.EnsureService(21, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
// Read everything back. // Read everything back.
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(21)) assert.Equal(t, idx, uint64(21))
assert.Len(nodes, 2) assert.Len(t, nodes, 2)
// Check sidecar // Check sidecar
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].Service.Kind) assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].Service.Kind)
assert.Equal("foo", nodes[0].Node.Node) assert.Equal(t, "foo", nodes[0].Node.Node)
assert.Equal("proxy", nodes[0].Service.Service) assert.Equal(t, "proxy", nodes[0].Service.Service)
assert.Equal("proxy", nodes[0].Service.ID) assert.Equal(t, "proxy", nodes[0].Service.ID)
assert.Equal("db", nodes[0].Service.Proxy.DestinationServiceName) assert.Equal(t, "db", nodes[0].Service.Proxy.DestinationServiceName)
assert.Equal(8000, nodes[0].Service.Port) assert.Equal(t, 8000, nodes[0].Service.Port)
// Check gateway // Check gateway
assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind) assert.Equal(t, structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind)
assert.Equal("bar", nodes[1].Node.Node) assert.Equal(t, "bar", nodes[1].Node.Node)
assert.Equal("gateway", nodes[1].Service.Service) assert.Equal(t, "gateway", nodes[1].Service.Service)
assert.Equal("gateway", nodes[1].Service.ID) assert.Equal(t, "gateway", nodes[1].Service.ID)
assert.Equal(443, nodes[1].Service.Port) assert.Equal(t, 443, nodes[1].Service.Port)
// Watch should fire when another gateway instance is registered // Watch should fire when another gateway instance is registered
assert.Nil(s.EnsureService(22, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443})) assert.Nil(t, s.EnsureService(22, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443}))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(22)) assert.Equal(t, idx, uint64(22))
assert.Len(nodes, 3) assert.Len(t, nodes, 3)
// Watch should fire when a gateway instance is deregistered // Watch should fire when a gateway instance is deregistered
assert.Nil(s.DeleteService(23, "bar", "gateway", nil)) assert.Nil(t, s.DeleteService(23, "bar", "gateway", nil))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(23)) assert.Equal(t, idx, uint64(23))
assert.Len(nodes, 2) assert.Len(t, nodes, 2)
// Check new gateway // Check new gateway
assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind) assert.Equal(t, structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind)
assert.Equal("foo", nodes[1].Node.Node) assert.Equal(t, "foo", nodes[1].Node.Node)
assert.Equal("gateway", nodes[1].Service.Service) assert.Equal(t, "gateway", nodes[1].Service.Service)
assert.Equal("gateway-2", nodes[1].Service.ID) assert.Equal(t, "gateway-2", nodes[1].Service.ID)
assert.Equal(443, nodes[1].Service.Port) assert.Equal(t, 443, nodes[1].Service.Port)
// Index should not slide back after deleting all instances of the gateway // Index should not slide back after deleting all instances of the gateway
assert.Nil(s.DeleteService(24, "foo", "gateway-2", nil)) assert.Nil(t, s.DeleteService(24, "foo", "gateway-2", nil))
assert.True(watchFired(ws)) assert.True(t, watchFired(ws))
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil) idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(24)) assert.Equal(t, idx, uint64(24))
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
// Ensure that remaining node is the proxy and not a gateway // Ensure that remaining node is the proxy and not a gateway
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].Service.Kind) assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].Service.Kind)
assert.Equal("foo", nodes[0].Node.Node) assert.Equal(t, "foo", nodes[0].Node.Node)
assert.Equal("proxy", nodes[0].Service.Service) assert.Equal(t, "proxy", nodes[0].Service.Service)
assert.Equal("proxy", nodes[0].Service.ID) assert.Equal(t, "proxy", nodes[0].Service.ID)
assert.Equal(8000, nodes[0].Service.Port) assert.Equal(t, 8000, nodes[0].Service.Port)
} }
func BenchmarkCheckServiceNodes(b *testing.B) { func BenchmarkCheckServiceNodes(b *testing.B) {
@ -5254,14 +5245,13 @@ func TestStateStore_GatewayServices_ServiceDeletion(t *testing.T) {
func TestStateStore_CheckIngressServiceNodes(t *testing.T) { func TestStateStore_CheckIngressServiceNodes(t *testing.T) {
s := testStateStore(t) s := testStateStore(t)
ws := setupIngressState(t, s) ws := setupIngressState(t, s)
require := require.New(t)
t.Run("check service1 ingress gateway", func(t *testing.T) { t.Run("check service1 ingress gateway", func(t *testing.T) {
idx, results, err := s.CheckIngressServiceNodes(ws, "service1", nil) idx, results, err := s.CheckIngressServiceNodes(ws, "service1", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(15), idx) require.Equal(t, uint64(15), idx)
// Multiple instances of the ingress2 service // Multiple instances of the ingress2 service
require.Len(results, 4) require.Len(t, results, 4)
ids := make(map[string]struct{}) ids := make(map[string]struct{})
for _, n := range results { for _, n := range results {
@ -5272,14 +5262,14 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) {
"ingress2": {}, "ingress2": {},
"wildcardIngress": {}, "wildcardIngress": {},
} }
require.Equal(expectedIds, ids) require.Equal(t, expectedIds, ids)
}) })
t.Run("check service2 ingress gateway", func(t *testing.T) { t.Run("check service2 ingress gateway", func(t *testing.T) {
idx, results, err := s.CheckIngressServiceNodes(ws, "service2", nil) idx, results, err := s.CheckIngressServiceNodes(ws, "service2", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(15), idx) require.Equal(t, uint64(15), idx)
require.Len(results, 2) require.Len(t, results, 2)
ids := make(map[string]struct{}) ids := make(map[string]struct{})
for _, n := range results { for _, n := range results {
@ -5289,38 +5279,38 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) {
"ingress1": {}, "ingress1": {},
"wildcardIngress": {}, "wildcardIngress": {},
} }
require.Equal(expectedIds, ids) require.Equal(t, expectedIds, ids)
}) })
t.Run("check service3 ingress gateway", func(t *testing.T) { t.Run("check service3 ingress gateway", func(t *testing.T) {
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, results, err := s.CheckIngressServiceNodes(ws, "service3", nil) idx, results, err := s.CheckIngressServiceNodes(ws, "service3", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(15), idx) require.Equal(t, uint64(15), idx)
require.Len(results, 1) require.Len(t, results, 1)
require.Equal("wildcardIngress", results[0].Service.ID) require.Equal(t, "wildcardIngress", results[0].Service.ID)
}) })
t.Run("delete a wildcard entry", func(t *testing.T) { t.Run("delete a wildcard entry", func(t *testing.T) {
require.Nil(s.DeleteConfigEntry(19, "ingress-gateway", "wildcardIngress", nil)) require.Nil(t, s.DeleteConfigEntry(19, "ingress-gateway", "wildcardIngress", nil))
require.True(watchFired(ws)) require.True(t, watchFired(ws))
idx, results, err := s.CheckIngressServiceNodes(ws, "service1", nil) idx, results, err := s.CheckIngressServiceNodes(ws, "service1", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(15), idx) require.Equal(t, uint64(15), idx)
require.Len(results, 3) require.Len(t, results, 3)
idx, results, err = s.CheckIngressServiceNodes(ws, "service2", nil) idx, results, err = s.CheckIngressServiceNodes(ws, "service2", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(15), idx) require.Equal(t, uint64(15), idx)
require.Len(results, 1) require.Len(t, results, 1)
idx, results, err = s.CheckIngressServiceNodes(ws, "service3", nil) idx, results, err = s.CheckIngressServiceNodes(ws, "service3", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(15), idx) require.Equal(t, uint64(15), idx)
// TODO(ingress): index goes backward when deleting last config entry // TODO(ingress): index goes backward when deleting last config entry
// require.Equal(uint64(11), idx) // require.Equal(t,uint64(11), idx)
require.Len(results, 0) require.Len(t, results, 0)
}) })
} }
@ -5628,56 +5618,55 @@ func TestStateStore_GatewayServices_WildcardAssociation(t *testing.T) {
s := testStateStore(t) s := testStateStore(t)
setupIngressState(t, s) setupIngressState(t, s)
require := require.New(t)
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
t.Run("base case for wildcard", func(t *testing.T) { t.Run("base case for wildcard", func(t *testing.T) {
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(16), idx) require.Equal(t, uint64(16), idx)
require.Len(results, 3) require.Len(t, results, 3)
}) })
t.Run("do not associate ingress services with gateway", func(t *testing.T) { t.Run("do not associate ingress services with gateway", func(t *testing.T) {
testRegisterIngressService(t, s, 17, "node1", "testIngress") testRegisterIngressService(t, s, 17, "node1", "testIngress")
require.False(watchFired(ws)) require.False(t, watchFired(ws))
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(16), idx) require.Equal(t, uint64(16), idx)
require.Len(results, 3) require.Len(t, results, 3)
}) })
t.Run("do not associate terminating-gateway services with gateway", func(t *testing.T) { t.Run("do not associate terminating-gateway services with gateway", func(t *testing.T) {
require.Nil(s.EnsureService(18, "node1", require.Nil(t, s.EnsureService(18, "node1",
&structs.NodeService{ &structs.NodeService{
Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443, Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443,
}, },
)) ))
require.False(watchFired(ws)) require.False(t, watchFired(ws))
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(16), idx) require.Equal(t, uint64(16), idx)
require.Len(results, 3) require.Len(t, results, 3)
}) })
t.Run("do not associate connect-proxy services with gateway", func(t *testing.T) { t.Run("do not associate connect-proxy services with gateway", func(t *testing.T) {
testRegisterSidecarProxy(t, s, 19, "node1", "web") testRegisterSidecarProxy(t, s, 19, "node1", "web")
require.False(watchFired(ws)) require.False(t, watchFired(ws))
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(16), idx) require.Equal(t, uint64(16), idx)
require.Len(results, 3) require.Len(t, results, 3)
}) })
t.Run("do not associate consul services with gateway", func(t *testing.T) { t.Run("do not associate consul services with gateway", func(t *testing.T) {
require.Nil(s.EnsureService(20, "node1", require.Nil(t, s.EnsureService(20, "node1",
&structs.NodeService{ID: "consul", Service: "consul", Tags: nil}, &structs.NodeService{ID: "consul", Service: "consul", Tags: nil},
)) ))
require.False(watchFired(ws)) require.False(t, watchFired(ws))
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(16), idx) require.Equal(t, uint64(16), idx)
require.Len(results, 3) require.Len(t, results, 3)
}) })
} }
@ -5708,15 +5697,13 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
}) })
t.Run("no services from default tcp protocol", func(t *testing.T) { t.Run("no services from default tcp protocol", func(t *testing.T) {
require := require.New(t)
idx, results, err := s.GatewayServices(nil, "ingress1", nil) idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(4), idx) require.Equal(t, uint64(4), idx)
require.Len(results, 0) require.Len(t, results, 0)
}) })
t.Run("service-defaults", func(t *testing.T) { t.Run("service-defaults", func(t *testing.T) {
require := require.New(t)
expected := structs.GatewayServices{ expected := structs.GatewayServices{
{ {
Gateway: structs.NewServiceName("ingress1", nil), Gateway: structs.NewServiceName("ingress1", nil),
@ -5739,13 +5726,12 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
} }
assert.NoError(t, s.EnsureConfigEntry(5, svcDefaults)) assert.NoError(t, s.EnsureConfigEntry(5, svcDefaults))
idx, results, err := s.GatewayServices(nil, "ingress1", nil) idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(5), idx) require.Equal(t, uint64(5), idx)
require.ElementsMatch(results, expected) require.ElementsMatch(t, results, expected)
}) })
t.Run("proxy-defaults", func(t *testing.T) { t.Run("proxy-defaults", func(t *testing.T) {
require := require.New(t)
expected := structs.GatewayServices{ expected := structs.GatewayServices{
{ {
Gateway: structs.NewServiceName("ingress1", nil), Gateway: structs.NewServiceName("ingress1", nil),
@ -5783,13 +5769,12 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
assert.NoError(t, s.EnsureConfigEntry(6, proxyDefaults)) assert.NoError(t, s.EnsureConfigEntry(6, proxyDefaults))
idx, results, err := s.GatewayServices(nil, "ingress1", nil) idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(6), idx) require.Equal(t, uint64(6), idx)
require.ElementsMatch(results, expected) require.ElementsMatch(t, results, expected)
}) })
t.Run("service-defaults overrides proxy-defaults", func(t *testing.T) { t.Run("service-defaults overrides proxy-defaults", func(t *testing.T) {
require := require.New(t)
expected := structs.GatewayServices{ expected := structs.GatewayServices{
{ {
Gateway: structs.NewServiceName("ingress1", nil), Gateway: structs.NewServiceName("ingress1", nil),
@ -5813,13 +5798,12 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
assert.NoError(t, s.EnsureConfigEntry(7, svcDefaults)) assert.NoError(t, s.EnsureConfigEntry(7, svcDefaults))
idx, results, err := s.GatewayServices(nil, "ingress1", nil) idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(7), idx) require.Equal(t, uint64(7), idx)
require.ElementsMatch(results, expected) require.ElementsMatch(t, results, expected)
}) })
t.Run("change listener protocol and expect different filter", func(t *testing.T) { t.Run("change listener protocol and expect different filter", func(t *testing.T) {
require := require.New(t)
expected := structs.GatewayServices{ expected := structs.GatewayServices{
{ {
Gateway: structs.NewServiceName("ingress1", nil), Gateway: structs.NewServiceName("ingress1", nil),
@ -5853,9 +5837,9 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
assert.NoError(t, s.EnsureConfigEntry(8, ingress1)) assert.NoError(t, s.EnsureConfigEntry(8, ingress1))
idx, results, err := s.GatewayServices(nil, "ingress1", nil) idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(8), idx) require.Equal(t, uint64(8), idx)
require.ElementsMatch(results, expected) require.ElementsMatch(t, results, expected)
}) })
} }

View File

@ -12,7 +12,6 @@ import (
) )
func TestStore_ConfigEntry(t *testing.T) { func TestStore_ConfigEntry(t *testing.T) {
require := require.New(t)
s := testConfigStateStore(t) s := testConfigStateStore(t)
expected := &structs.ProxyConfigEntry{ expected := &structs.ProxyConfigEntry{
@ -24,12 +23,12 @@ func TestStore_ConfigEntry(t *testing.T) {
} }
// Create // Create
require.NoError(s.EnsureConfigEntry(0, expected)) require.NoError(t, s.EnsureConfigEntry(0, expected))
idx, config, err := s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil) idx, config, err := s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(0), idx) require.Equal(t, uint64(0), idx)
require.Equal(expected, config) require.Equal(t, expected, config)
// Update // Update
updated := &structs.ProxyConfigEntry{ updated := &structs.ProxyConfigEntry{
@ -39,44 +38,43 @@ func TestStore_ConfigEntry(t *testing.T) {
"DestinationServiceName": "bar", "DestinationServiceName": "bar",
}, },
} }
require.NoError(s.EnsureConfigEntry(1, updated)) require.NoError(t, s.EnsureConfigEntry(1, updated))
idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil) idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(1), idx) require.Equal(t, uint64(1), idx)
require.Equal(updated, config) require.Equal(t, updated, config)
// Delete // Delete
require.NoError(s.DeleteConfigEntry(2, structs.ProxyDefaults, "global", nil)) require.NoError(t, s.DeleteConfigEntry(2, structs.ProxyDefaults, "global", nil))
idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil) idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(2), idx) require.Equal(t, uint64(2), idx)
require.Nil(config) require.Nil(t, config)
// Set up a watch. // Set up a watch.
serviceConf := &structs.ServiceConfigEntry{ serviceConf := &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "foo", Name: "foo",
} }
require.NoError(s.EnsureConfigEntry(3, serviceConf)) require.NoError(t, s.EnsureConfigEntry(3, serviceConf))
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
_, _, err = s.ConfigEntry(ws, structs.ServiceDefaults, "foo", nil) _, _, err = s.ConfigEntry(ws, structs.ServiceDefaults, "foo", nil)
require.NoError(err) require.NoError(t, err)
// Make an unrelated modification and make sure the watch doesn't fire. // Make an unrelated modification and make sure the watch doesn't fire.
require.NoError(s.EnsureConfigEntry(4, updated)) require.NoError(t, s.EnsureConfigEntry(4, updated))
require.False(watchFired(ws)) require.False(t, watchFired(ws))
// Update the watched config and make sure it fires. // Update the watched config and make sure it fires.
serviceConf.Protocol = "http" serviceConf.Protocol = "http"
require.NoError(s.EnsureConfigEntry(5, serviceConf)) require.NoError(t, s.EnsureConfigEntry(5, serviceConf))
require.True(watchFired(ws)) require.True(t, watchFired(ws))
} }
func TestStore_ConfigEntryCAS(t *testing.T) { func TestStore_ConfigEntryCAS(t *testing.T) {
require := require.New(t)
s := testConfigStateStore(t) s := testConfigStateStore(t)
expected := &structs.ProxyConfigEntry{ expected := &structs.ProxyConfigEntry{
@ -88,12 +86,12 @@ func TestStore_ConfigEntryCAS(t *testing.T) {
} }
// Create // Create
require.NoError(s.EnsureConfigEntry(1, expected)) require.NoError(t, s.EnsureConfigEntry(1, expected))
idx, config, err := s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil) idx, config, err := s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(1), idx) require.Equal(t, uint64(1), idx)
require.Equal(expected, config) require.Equal(t, expected, config)
// Update with invalid index // Update with invalid index
updated := &structs.ProxyConfigEntry{ updated := &structs.ProxyConfigEntry{
@ -104,29 +102,28 @@ func TestStore_ConfigEntryCAS(t *testing.T) {
}, },
} }
ok, err := s.EnsureConfigEntryCAS(2, 99, updated) ok, err := s.EnsureConfigEntryCAS(2, 99, updated)
require.False(ok) require.False(t, ok)
require.NoError(err) require.NoError(t, err)
// Entry should not be changed // Entry should not be changed
idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil) idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(1), idx) require.Equal(t, uint64(1), idx)
require.Equal(expected, config) require.Equal(t, expected, config)
// Update with a valid index // Update with a valid index
ok, err = s.EnsureConfigEntryCAS(2, 1, updated) ok, err = s.EnsureConfigEntryCAS(2, 1, updated)
require.True(ok) require.True(t, ok)
require.NoError(err) require.NoError(t, err)
// Entry should be updated // Entry should be updated
idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil) idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(2), idx) require.Equal(t, uint64(2), idx)
require.Equal(updated, config) require.Equal(t, updated, config)
} }
func TestStore_ConfigEntry_DeleteCAS(t *testing.T) { func TestStore_ConfigEntry_DeleteCAS(t *testing.T) {
require := require.New(t)
s := testConfigStateStore(t) s := testConfigStateStore(t)
entry := &structs.ProxyConfigEntry{ entry := &structs.ProxyConfigEntry{
@ -139,31 +136,31 @@ func TestStore_ConfigEntry_DeleteCAS(t *testing.T) {
// Attempt to delete the entry before it exists. // Attempt to delete the entry before it exists.
ok, err := s.DeleteConfigEntryCAS(1, 0, entry) ok, err := s.DeleteConfigEntryCAS(1, 0, entry)
require.NoError(err) require.NoError(t, err)
require.False(ok) require.False(t, ok)
// Create the entry. // Create the entry.
require.NoError(s.EnsureConfigEntry(1, entry)) require.NoError(t, s.EnsureConfigEntry(1, entry))
// Attempt to delete with an invalid index. // Attempt to delete with an invalid index.
ok, err = s.DeleteConfigEntryCAS(2, 99, entry) ok, err = s.DeleteConfigEntryCAS(2, 99, entry)
require.NoError(err) require.NoError(t, err)
require.False(ok) require.False(t, ok)
// Entry should not be deleted. // Entry should not be deleted.
_, config, err := s.ConfigEntry(nil, entry.Kind, entry.Name, nil) _, config, err := s.ConfigEntry(nil, entry.Kind, entry.Name, nil)
require.NoError(err) require.NoError(t, err)
require.NotNil(config) require.NotNil(t, config)
// Attempt to delete with a valid index. // Attempt to delete with a valid index.
ok, err = s.DeleteConfigEntryCAS(2, 1, entry) ok, err = s.DeleteConfigEntryCAS(2, 1, entry)
require.NoError(err) require.NoError(t, err)
require.True(ok) require.True(t, ok)
// Entry should be deleted. // Entry should be deleted.
_, config, err = s.ConfigEntry(nil, entry.Kind, entry.Name, nil) _, config, err = s.ConfigEntry(nil, entry.Kind, entry.Name, nil)
require.NoError(err) require.NoError(t, err)
require.Nil(config) require.Nil(t, config)
} }
func TestStore_ConfigEntry_UpdateOver(t *testing.T) { func TestStore_ConfigEntry_UpdateOver(t *testing.T) {
@ -263,7 +260,6 @@ func TestStore_ConfigEntry_UpdateOver(t *testing.T) {
} }
func TestStore_ConfigEntries(t *testing.T) { func TestStore_ConfigEntries(t *testing.T) {
require := require.New(t)
s := testConfigStateStore(t) s := testConfigStateStore(t)
// Create some config entries. // Create some config entries.
@ -280,39 +276,39 @@ func TestStore_ConfigEntries(t *testing.T) {
Name: "test3", Name: "test3",
} }
require.NoError(s.EnsureConfigEntry(0, entry1)) require.NoError(t, s.EnsureConfigEntry(0, entry1))
require.NoError(s.EnsureConfigEntry(1, entry2)) require.NoError(t, s.EnsureConfigEntry(1, entry2))
require.NoError(s.EnsureConfigEntry(2, entry3)) require.NoError(t, s.EnsureConfigEntry(2, entry3))
// Get all entries // Get all entries
idx, entries, err := s.ConfigEntries(nil, nil) idx, entries, err := s.ConfigEntries(nil, nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(2), idx) require.Equal(t, uint64(2), idx)
require.Equal([]structs.ConfigEntry{entry1, entry2, entry3}, entries) require.Equal(t, []structs.ConfigEntry{entry1, entry2, entry3}, entries)
// Get all proxy entries // Get all proxy entries
idx, entries, err = s.ConfigEntriesByKind(nil, structs.ProxyDefaults, nil) idx, entries, err = s.ConfigEntriesByKind(nil, structs.ProxyDefaults, nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(2), idx) require.Equal(t, uint64(2), idx)
require.Equal([]structs.ConfigEntry{entry1}, entries) require.Equal(t, []structs.ConfigEntry{entry1}, entries)
// Get all service entries // Get all service entries
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, entries, err = s.ConfigEntriesByKind(ws, structs.ServiceDefaults, nil) idx, entries, err = s.ConfigEntriesByKind(ws, structs.ServiceDefaults, nil)
require.NoError(err) require.NoError(t, err)
require.Equal(uint64(2), idx) require.Equal(t, uint64(2), idx)
require.Equal([]structs.ConfigEntry{entry2, entry3}, entries) require.Equal(t, []structs.ConfigEntry{entry2, entry3}, entries)
// Watch should not have fired // Watch should not have fired
require.False(watchFired(ws)) require.False(t, watchFired(ws))
// Now make an update and make sure the watch fires. // Now make an update and make sure the watch fires.
require.NoError(s.EnsureConfigEntry(3, &structs.ServiceConfigEntry{ require.NoError(t, s.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "test2", Name: "test2",
Protocol: "tcp", Protocol: "tcp",
})) }))
require.True(watchFired(ws)) require.True(t, watchFired(ws))
} }
func TestStore_ConfigEntry_GraphValidation(t *testing.T) { func TestStore_ConfigEntry_GraphValidation(t *testing.T) {

View File

@ -184,25 +184,24 @@ func TestStore_CAConfig_Snapshot_Restore_BlankConfig(t *testing.T) {
} }
func TestStore_CARootSetList(t *testing.T) { func TestStore_CARootSetList(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Call list to populate the watch set // Call list to populate the watch set
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws) _, _, err := s.CARoots(ws)
assert.Nil(err) assert.Nil(t, err)
// Build a valid value // Build a valid value
ca1 := connect.TestCA(t, nil) ca1 := connect.TestCA(t, nil)
expected := *ca1 expected := *ca1
// Set // Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1}) ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
assert.Nil(err) assert.Nil(t, err)
assert.True(ok) assert.True(t, ok)
// Make sure the index got updated. // Make sure the index got updated.
assert.Equal(s.maxIndex(tableConnectCARoots), uint64(1)) assert.Equal(t, s.maxIndex(tableConnectCARoots), uint64(1))
assert.True(watchFired(ws), "watch fired") assert.True(t, watchFired(ws), "watch fired")
// Read it back out and verify it. // Read it back out and verify it.
@ -212,20 +211,19 @@ func TestStore_CARootSetList(t *testing.T) {
} }
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
_, roots, err := s.CARoots(ws) _, roots, err := s.CARoots(ws)
assert.Nil(err) assert.Nil(t, err)
assert.Len(roots, 1) assert.Len(t, roots, 1)
actual := roots[0] actual := roots[0]
assertDeepEqual(t, expected, *actual) assertDeepEqual(t, expected, *actual)
} }
func TestStore_CARootSet_emptyID(t *testing.T) { func TestStore_CARootSet_emptyID(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Call list to populate the watch set // Call list to populate the watch set
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws) _, _, err := s.CARoots(ws)
assert.Nil(err) assert.Nil(t, err)
// Build a valid value // Build a valid value
ca1 := connect.TestCA(t, nil) ca1 := connect.TestCA(t, nil)
@ -233,29 +231,28 @@ func TestStore_CARootSet_emptyID(t *testing.T) {
// Set // Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1}) ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
assert.NotNil(err) assert.NotNil(t, err)
assert.Contains(err.Error(), ErrMissingCARootID.Error()) assert.Contains(t, err.Error(), ErrMissingCARootID.Error())
assert.False(ok) assert.False(t, ok)
// Make sure the index got updated. // Make sure the index got updated.
assert.Equal(s.maxIndex(tableConnectCARoots), uint64(0)) assert.Equal(t, s.maxIndex(tableConnectCARoots), uint64(0))
assert.False(watchFired(ws), "watch fired") assert.False(t, watchFired(ws), "watch fired")
// Read it back out and verify it. // Read it back out and verify it.
ws = memdb.NewWatchSet() ws = memdb.NewWatchSet()
_, roots, err := s.CARoots(ws) _, roots, err := s.CARoots(ws)
assert.Nil(err) assert.Nil(t, err)
assert.Len(roots, 0) assert.Len(t, roots, 0)
} }
func TestStore_CARootSet_noActive(t *testing.T) { func TestStore_CARootSet_noActive(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Call list to populate the watch set // Call list to populate the watch set
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws) _, _, err := s.CARoots(ws)
assert.Nil(err) assert.Nil(t, err)
// Build a valid value // Build a valid value
ca1 := connect.TestCA(t, nil) ca1 := connect.TestCA(t, nil)
@ -265,19 +262,18 @@ func TestStore_CARootSet_noActive(t *testing.T) {
// Set // Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2}) ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
assert.NotNil(err) assert.NotNil(t, err)
assert.Contains(err.Error(), "exactly one active") assert.Contains(t, err.Error(), "exactly one active")
assert.False(ok) assert.False(t, ok)
} }
func TestStore_CARootSet_multipleActive(t *testing.T) { func TestStore_CARootSet_multipleActive(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Call list to populate the watch set // Call list to populate the watch set
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws) _, _, err := s.CARoots(ws)
assert.Nil(err) assert.Nil(t, err)
// Build a valid value // Build a valid value
ca1 := connect.TestCA(t, nil) ca1 := connect.TestCA(t, nil)
@ -285,13 +281,12 @@ func TestStore_CARootSet_multipleActive(t *testing.T) {
// Set // Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2}) ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
assert.NotNil(err) assert.NotNil(t, err)
assert.Contains(err.Error(), "exactly one active") assert.Contains(t, err.Error(), "exactly one active")
assert.False(ok) assert.False(t, ok)
} }
func TestStore_CARootActive_valid(t *testing.T) { func TestStore_CARootActive_valid(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Build a valid value // Build a valid value
@ -303,33 +298,31 @@ func TestStore_CARootActive_valid(t *testing.T) {
// Set // Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2, ca3}) ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2, ca3})
assert.Nil(err) assert.Nil(t, err)
assert.True(ok) assert.True(t, ok)
// Query // Query
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, res, err := s.CARootActive(ws) idx, res, err := s.CARootActive(ws)
assert.Equal(idx, uint64(1)) assert.Equal(t, idx, uint64(1))
assert.Nil(err) assert.Nil(t, err)
assert.NotNil(res) assert.NotNil(t, res)
assert.Equal(ca2.ID, res.ID) assert.Equal(t, ca2.ID, res.ID)
} }
// Test that querying the active CA returns the correct value. // Test that querying the active CA returns the correct value.
func TestStore_CARootActive_none(t *testing.T) { func TestStore_CARootActive_none(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Querying with no results returns nil. // Querying with no results returns nil.
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, res, err := s.CARootActive(ws) idx, res, err := s.CARootActive(ws)
assert.Equal(idx, uint64(0)) assert.Equal(t, idx, uint64(0))
assert.Nil(res) assert.Nil(t, res)
assert.Nil(err) assert.Nil(t, err)
} }
func TestStore_CARoot_Snapshot_Restore(t *testing.T) { func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Create some intentions. // Create some intentions.
@ -351,8 +344,8 @@ func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
// Now create // Now create
ok, err := s.CARootSetCAS(1, 0, roots) ok, err := s.CARootSetCAS(1, 0, roots)
assert.Nil(err) assert.Nil(t, err)
assert.True(ok) assert.True(t, ok)
// Snapshot the queries. // Snapshot the queries.
snap := s.Snapshot() snap := s.Snapshot()
@ -360,34 +353,33 @@ func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
// Alter the real state store. // Alter the real state store.
ok, err = s.CARootSetCAS(2, 1, roots[:1]) ok, err = s.CARootSetCAS(2, 1, roots[:1])
assert.Nil(err) assert.Nil(t, err)
assert.True(ok) assert.True(t, ok)
// Verify the snapshot. // Verify the snapshot.
assert.Equal(snap.LastIndex(), uint64(1)) assert.Equal(t, snap.LastIndex(), uint64(1))
dump, err := snap.CARoots() dump, err := snap.CARoots()
assert.Nil(err) assert.Nil(t, err)
assert.Equal(roots, dump) assert.Equal(t, roots, dump)
// Restore the values into a new state store. // Restore the values into a new state store.
func() { func() {
s := testStateStore(t) s := testStateStore(t)
restore := s.Restore() restore := s.Restore()
for _, r := range dump { for _, r := range dump {
assert.Nil(restore.CARoot(r)) assert.Nil(t, restore.CARoot(r))
} }
restore.Commit() restore.Commit()
// Read the restored values back out and verify that they match. // Read the restored values back out and verify that they match.
idx, actual, err := s.CARoots(nil) idx, actual, err := s.CARoots(nil)
assert.Nil(err) assert.Nil(t, err)
assert.Equal(idx, uint64(2)) assert.Equal(t, idx, uint64(2))
assert.Equal(roots, actual) assert.Equal(t, roots, actual)
}() }()
} }
func TestStore_CABuiltinProvider(t *testing.T) { func TestStore_CABuiltinProvider(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
{ {
@ -398,13 +390,13 @@ func TestStore_CABuiltinProvider(t *testing.T) {
} }
ok, err := s.CASetProviderState(0, expected) ok, err := s.CASetProviderState(0, expected)
assert.NoError(err) assert.NoError(t, err)
assert.True(ok) assert.True(t, ok)
idx, state, err := s.CAProviderState(expected.ID) idx, state, err := s.CAProviderState(expected.ID)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(idx, uint64(0)) assert.Equal(t, idx, uint64(0))
assert.Equal(expected, state) assert.Equal(t, expected, state)
} }
{ {
@ -415,13 +407,13 @@ func TestStore_CABuiltinProvider(t *testing.T) {
} }
ok, err := s.CASetProviderState(1, expected) ok, err := s.CASetProviderState(1, expected)
assert.NoError(err) assert.NoError(t, err)
assert.True(ok) assert.True(t, ok)
idx, state, err := s.CAProviderState(expected.ID) idx, state, err := s.CAProviderState(expected.ID)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(idx, uint64(1)) assert.Equal(t, idx, uint64(1))
assert.Equal(expected, state) assert.Equal(t, expected, state)
} }
{ {
@ -429,21 +421,20 @@ func TestStore_CABuiltinProvider(t *testing.T) {
// numbers will initialize from the max index of the provider table. // numbers will initialize from the max index of the provider table.
// That's why this first serial is 2 and not 1. // That's why this first serial is 2 and not 1.
sn, err := s.CAIncrementProviderSerialNumber(10) sn, err := s.CAIncrementProviderSerialNumber(10)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(uint64(2), sn) assert.Equal(t, uint64(2), sn)
sn, err = s.CAIncrementProviderSerialNumber(10) sn, err = s.CAIncrementProviderSerialNumber(10)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(uint64(3), sn) assert.Equal(t, uint64(3), sn)
sn, err = s.CAIncrementProviderSerialNumber(10) sn, err = s.CAIncrementProviderSerialNumber(10)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(uint64(4), sn) assert.Equal(t, uint64(4), sn)
} }
} }
func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) { func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t) s := testStateStore(t)
// Create multiple state entries. // Create multiple state entries.
@ -462,8 +453,8 @@ func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
for i, state := range before { for i, state := range before {
ok, err := s.CASetProviderState(uint64(98+i), state) ok, err := s.CASetProviderState(uint64(98+i), state)
assert.NoError(err) assert.NoError(t, err)
assert.True(ok) assert.True(t, ok)
} }
// Take a snapshot. // Take a snapshot.
@ -477,26 +468,26 @@ func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
RootCert: "d", RootCert: "d",
} }
ok, err := s.CASetProviderState(100, after) ok, err := s.CASetProviderState(100, after)
assert.NoError(err) assert.NoError(t, err)
assert.True(ok) assert.True(t, ok)
snapped, err := snap.CAProviderState() snapped, err := snap.CAProviderState()
assert.NoError(err) assert.NoError(t, err)
assert.Equal(before, snapped) assert.Equal(t, before, snapped)
// Restore onto a new state store. // Restore onto a new state store.
s2 := testStateStore(t) s2 := testStateStore(t)
restore := s2.Restore() restore := s2.Restore()
for _, entry := range snapped { for _, entry := range snapped {
assert.NoError(restore.CAProviderState(entry)) assert.NoError(t, restore.CAProviderState(entry))
} }
restore.Commit() restore.Commit()
// Verify the restored values match those from before the snapshot. // Verify the restored values match those from before the snapshot.
for _, state := range before { for _, state := range before {
idx, res, err := s2.CAProviderState(state.ID) idx, res, err := s2.CAProviderState(state.ID)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(idx, uint64(99)) assert.Equal(t, idx, uint64(99))
assert.Equal(state, res) assert.Equal(t, state, res)
} }
} }

View File

@ -46,14 +46,13 @@ func testBothIntentionFormats(t *testing.T, f func(t *testing.T, s *Store, legac
func TestStore_IntentionGet_none(t *testing.T) { func TestStore_IntentionGet_none(t *testing.T) {
testBothIntentionFormats(t, func(t *testing.T, s *Store, legacy bool) { testBothIntentionFormats(t, func(t *testing.T, s *Store, legacy bool) {
assert := assert.New(t)
// Querying with no results returns nil. // Querying with no results returns nil.
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
idx, _, res, err := s.IntentionGet(ws, testUUID()) idx, _, res, err := s.IntentionGet(ws, testUUID())
assert.Equal(uint64(1), idx) assert.Equal(t, uint64(1), idx)
assert.Nil(res) assert.Nil(t, res)
assert.Nil(err) assert.Nil(t, err)
}) })
} }

View File

@ -18,7 +18,6 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
s := testACLTokensStateStore(t) s := testACLTokensStateStore(t)
// Setup token and wait for good state // Setup token and wait for good state
@ -37,14 +36,14 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
go publisher.Run(ctx) go publisher.Run(ctx)
s.db.publisher = publisher s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription) sub, err := publisher.Subscribe(subscription)
require.NoError(err) require.NoError(t, err)
defer sub.Unsubscribe() defer sub.Unsubscribe()
eventCh := testRunSub(sub) eventCh := testRunSub(sub)
// Stream should get EndOfSnapshot // Stream should get EndOfSnapshot
e := assertEvent(t, eventCh) e := assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot()) require.True(t, e.IsEndOfSnapshot())
// Update an unrelated token. // Update an unrelated token.
token2 := &structs.ACLToken{ token2 := &structs.ACLToken{
@ -52,7 +51,7 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
SecretID: "72e81982-7a0f-491f-a60e-c9c802ac1402", SecretID: "72e81982-7a0f-491f-a60e-c9c802ac1402",
} }
token2.SetHash(false) token2.SetHash(false)
require.NoError(s.ACLTokenSet(3, token2.Clone())) require.NoError(t, s.ACLTokenSet(3, token2.Clone()))
// Ensure there's no reset event. // Ensure there's no reset event.
assertNoEvent(t, eventCh) assertNoEvent(t, eventCh)
@ -64,11 +63,11 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
Description: "something else", Description: "something else",
} }
token3.SetHash(false) token3.SetHash(false)
require.NoError(s.ACLTokenSet(4, token3.Clone())) require.NoError(t, s.ACLTokenSet(4, token3.Clone()))
// Ensure the reset event was sent. // Ensure the reset event was sent.
err = assertErr(t, eventCh) err = assertErr(t, eventCh)
require.Equal(stream.ErrSubForceClosed, err) require.Equal(t, stream.ErrSubForceClosed, err)
// Register another subscription. // Register another subscription.
subscription2 := &stream.SubscribeRequest{ subscription2 := &stream.SubscribeRequest{
@ -77,27 +76,27 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
Token: token.SecretID, Token: token.SecretID,
} }
sub2, err := publisher.Subscribe(subscription2) sub2, err := publisher.Subscribe(subscription2)
require.NoError(err) require.NoError(t, err)
defer sub2.Unsubscribe() defer sub2.Unsubscribe()
eventCh2 := testRunSub(sub2) eventCh2 := testRunSub(sub2)
// Expect initial EoS // Expect initial EoS
e = assertEvent(t, eventCh2) e = assertEvent(t, eventCh2)
require.True(e.IsEndOfSnapshot()) require.True(t, e.IsEndOfSnapshot())
// Delete the unrelated token. // Delete the unrelated token.
require.NoError(s.ACLTokenDeleteByAccessor(5, token2.AccessorID, nil)) require.NoError(t, s.ACLTokenDeleteByAccessor(5, token2.AccessorID, nil))
// Ensure there's no reset event. // Ensure there's no reset event.
assertNoEvent(t, eventCh2) assertNoEvent(t, eventCh2)
// Delete the token used by the subscriber. // Delete the token used by the subscriber.
require.NoError(s.ACLTokenDeleteByAccessor(6, token.AccessorID, nil)) require.NoError(t, s.ACLTokenDeleteByAccessor(6, token.AccessorID, nil))
// Ensure the reset event was sent. // Ensure the reset event was sent.
err = assertErr(t, eventCh2) err = assertErr(t, eventCh2)
require.Equal(stream.ErrSubForceClosed, err) require.Equal(t, stream.ErrSubForceClosed, err)
} }
func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) { func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
@ -106,7 +105,6 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
s := testACLTokensStateStore(t) s := testACLTokensStateStore(t)
// Create token and wait for good state // Create token and wait for good state
@ -125,14 +123,14 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
go publisher.Run(ctx) go publisher.Run(ctx)
s.db.publisher = publisher s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription) sub, err := publisher.Subscribe(subscription)
require.NoError(err) require.NoError(t, err)
defer sub.Unsubscribe() defer sub.Unsubscribe()
eventCh := testRunSub(sub) eventCh := testRunSub(sub)
// Ignore the end of snapshot event // Ignore the end of snapshot event
e := assertEvent(t, eventCh) e := assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot(), "event should be a EoS got %v", e) require.True(t, e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
// Update an unrelated policy. // Update an unrelated policy.
policy2 := structs.ACLPolicy{ policy2 := structs.ACLPolicy{
@ -143,7 +141,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Datacenters: []string{"dc1"}, Datacenters: []string{"dc1"},
} }
policy2.SetHash(false) policy2.SetHash(false)
require.NoError(s.ACLPolicySet(3, &policy2)) require.NoError(t, s.ACLPolicySet(3, &policy2))
// Ensure there's no reset event. // Ensure there's no reset event.
assertNoEvent(t, eventCh) assertNoEvent(t, eventCh)
@ -157,7 +155,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Datacenters: []string{"dc1"}, Datacenters: []string{"dc1"},
} }
policy3.SetHash(false) policy3.SetHash(false)
require.NoError(s.ACLPolicySet(4, &policy3)) require.NoError(t, s.ACLPolicySet(4, &policy3))
// Ensure the reset event was sent. // Ensure the reset event was sent.
assertReset(t, eventCh, true) assertReset(t, eventCh, true)
@ -169,27 +167,27 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Token: token.SecretID, Token: token.SecretID,
} }
sub, err = publisher.Subscribe(subscription2) sub, err = publisher.Subscribe(subscription2)
require.NoError(err) require.NoError(t, err)
defer sub.Unsubscribe() defer sub.Unsubscribe()
eventCh = testRunSub(sub) eventCh = testRunSub(sub)
// Ignore the end of snapshot event // Ignore the end of snapshot event
e = assertEvent(t, eventCh) e = assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot(), "event should be a EoS got %v", e) require.True(t, e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
// Delete the unrelated policy. // Delete the unrelated policy.
require.NoError(s.ACLPolicyDeleteByID(5, testPolicyID_C, nil)) require.NoError(t, s.ACLPolicyDeleteByID(5, testPolicyID_C, nil))
// Ensure there's no reload event. // Ensure there's no reload event.
assertNoEvent(t, eventCh) assertNoEvent(t, eventCh)
// Delete the policy used by the subscriber. // Delete the policy used by the subscriber.
require.NoError(s.ACLPolicyDeleteByID(6, testPolicyID_A, nil)) require.NoError(t, s.ACLPolicyDeleteByID(6, testPolicyID_A, nil))
// Ensure the reload event was sent. // Ensure the reload event was sent.
err = assertErr(t, eventCh) err = assertErr(t, eventCh)
require.Equal(stream.ErrSubForceClosed, err) require.Equal(t, stream.ErrSubForceClosed, err)
// Register another subscription. // Register another subscription.
subscription3 := &stream.SubscribeRequest{ subscription3 := &stream.SubscribeRequest{
@ -198,14 +196,14 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Token: token.SecretID, Token: token.SecretID,
} }
sub, err = publisher.Subscribe(subscription3) sub, err = publisher.Subscribe(subscription3)
require.NoError(err) require.NoError(t, err)
defer sub.Unsubscribe() defer sub.Unsubscribe()
eventCh = testRunSub(sub) eventCh = testRunSub(sub)
// Ignore the end of snapshot event // Ignore the end of snapshot event
e = assertEvent(t, eventCh) e = assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot(), "event should be a EoS got %v", e) require.True(t, e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
// Now update the policy used in role B, but not directly in the token. // Now update the policy used in role B, but not directly in the token.
policy4 := structs.ACLPolicy{ policy4 := structs.ACLPolicy{
@ -216,7 +214,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Datacenters: []string{"dc1"}, Datacenters: []string{"dc1"},
} }
policy4.SetHash(false) policy4.SetHash(false)
require.NoError(s.ACLPolicySet(7, &policy4)) require.NoError(t, s.ACLPolicySet(7, &policy4))
// Ensure the reset event was sent. // Ensure the reset event was sent.
assertReset(t, eventCh, true) assertReset(t, eventCh, true)
@ -228,7 +226,6 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
s := testACLTokensStateStore(t) s := testACLTokensStateStore(t)
// Create token and wait for good state // Create token and wait for good state
@ -247,13 +244,13 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
go publisher.Run(ctx) go publisher.Run(ctx)
s.db.publisher = publisher s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription) sub, err := publisher.Subscribe(subscription)
require.NoError(err) require.NoError(t, err)
eventCh := testRunSub(sub) eventCh := testRunSub(sub)
// Stream should get EndOfSnapshot // Stream should get EndOfSnapshot
e := assertEvent(t, eventCh) e := assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot()) require.True(t, e.IsEndOfSnapshot())
// Update an unrelated role (the token has role testRoleID_B). // Update an unrelated role (the token has role testRoleID_B).
role := structs.ACLRole{ role := structs.ACLRole{
@ -262,7 +259,7 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
Description: "test", Description: "test",
} }
role.SetHash(false) role.SetHash(false)
require.NoError(s.ACLRoleSet(3, &role)) require.NoError(t, s.ACLRoleSet(3, &role))
// Ensure there's no reload event. // Ensure there's no reload event.
assertNoEvent(t, eventCh) assertNoEvent(t, eventCh)
@ -274,7 +271,7 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
Description: "changed", Description: "changed",
} }
role2.SetHash(false) role2.SetHash(false)
require.NoError(s.ACLRoleSet(4, &role2)) require.NoError(t, s.ACLRoleSet(4, &role2))
// Ensure the reload event was sent. // Ensure the reload event was sent.
assertReset(t, eventCh, false) assertReset(t, eventCh, false)
@ -286,22 +283,22 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
Token: token.SecretID, Token: token.SecretID,
} }
sub, err = publisher.Subscribe(subscription2) sub, err = publisher.Subscribe(subscription2)
require.NoError(err) require.NoError(t, err)
eventCh = testRunSub(sub) eventCh = testRunSub(sub)
// Ignore the end of snapshot event // Ignore the end of snapshot event
e = assertEvent(t, eventCh) e = assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot(), "event should be a EoS got %v", e) require.True(t, e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
// Delete the unrelated policy. // Delete the unrelated policy.
require.NoError(s.ACLRoleDeleteByID(5, testRoleID_A, nil)) require.NoError(t, s.ACLRoleDeleteByID(5, testRoleID_A, nil))
// Ensure there's no reload event. // Ensure there's no reload event.
assertNoEvent(t, eventCh) assertNoEvent(t, eventCh)
// Delete the policy used by the subscriber. // Delete the policy used by the subscriber.
require.NoError(s.ACLRoleDeleteByID(6, testRoleID_B, nil)) require.NoError(t, s.ACLRoleDeleteByID(6, testRoleID_B, nil))
// Ensure the reload event was sent. // Ensure the reload event was sent.
assertReset(t, eventCh, false) assertReset(t, eventCh, false)

View File

@ -314,8 +314,6 @@ func TestTxn_Apply_ACLDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -333,16 +331,16 @@ func TestTxn_Apply_ACLDeny(t *testing.T) {
Key: "nope", Key: "nope",
Value: []byte("hello"), Value: []byte("hello"),
} }
require.NoError(state.KVSSet(1, d)) require.NoError(t, state.KVSSet(1, d))
node := &structs.Node{ node := &structs.Node{
ID: types.NodeID(testNodeID), ID: types.NodeID(testNodeID),
Node: "nope", Node: "nope",
} }
require.NoError(state.EnsureNode(2, node)) require.NoError(t, state.EnsureNode(2, node))
svc := structs.NodeService{ID: "nope", Service: "nope", Address: "127.0.0.1"} svc := structs.NodeService{ID: "nope", Service: "nope", Address: "127.0.0.1"}
require.NoError(state.EnsureService(3, "nope", &svc)) require.NoError(t, state.EnsureService(3, "nope", &svc))
check := structs.HealthCheck{Node: "nope", CheckID: types.CheckID("nope")} check := structs.HealthCheck{Node: "nope", CheckID: types.CheckID("nope")}
state.EnsureCheck(4, &check) state.EnsureCheck(4, &check)
@ -606,7 +604,7 @@ func TestTxn_Apply_ACLDeny(t *testing.T) {
} }
} }
require.Equal(expected, out) require.Equal(t, expected, out)
} }
func TestTxn_Apply_LockDelay(t *testing.T) { func TestTxn_Apply_LockDelay(t *testing.T) {
@ -707,8 +705,6 @@ func TestTxn_Read(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t) dir1, s1 := testServer(t)
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
@ -732,7 +728,7 @@ func TestTxn_Read(t *testing.T) {
ID: types.NodeID(testNodeID), ID: types.NodeID(testNodeID),
Node: "foo", Node: "foo",
} }
require.NoError(state.EnsureNode(2, node)) require.NoError(t, state.EnsureNode(2, node))
svc := structs.NodeService{ svc := structs.NodeService{
ID: "svc-foo", ID: "svc-foo",
@ -740,7 +736,7 @@ func TestTxn_Read(t *testing.T) {
Address: "127.0.0.1", Address: "127.0.0.1",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
require.NoError(state.EnsureService(3, "foo", &svc)) require.NoError(t, state.EnsureService(3, "foo", &svc))
check := structs.HealthCheck{ check := structs.HealthCheck{
Node: "foo", Node: "foo",
@ -823,7 +819,7 @@ func TestTxn_Read(t *testing.T) {
KnownLeader: true, KnownLeader: true,
}, },
} }
require.Equal(expected, out) require.Equal(t, expected, out)
} }
func TestTxn_Read_ACLDeny(t *testing.T) { func TestTxn_Read_ACLDeny(t *testing.T) {
@ -833,8 +829,6 @@ func TestTxn_Read_ACLDeny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) { dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true c.ACLsEnabled = true
@ -863,10 +857,10 @@ func TestTxn_Read_ACLDeny(t *testing.T) {
ID: types.NodeID(testNodeID), ID: types.NodeID(testNodeID),
Node: "nope", Node: "nope",
} }
require.NoError(state.EnsureNode(2, node)) require.NoError(t, state.EnsureNode(2, node))
svc := structs.NodeService{ID: "nope", Service: "nope", Address: "127.0.0.1"} svc := structs.NodeService{ID: "nope", Service: "nope", Address: "127.0.0.1"}
require.NoError(state.EnsureService(3, "nope", &svc)) require.NoError(t, state.EnsureService(3, "nope", &svc))
check := structs.HealthCheck{Node: "nope", CheckID: types.CheckID("nope")} check := structs.HealthCheck{Node: "nope", CheckID: types.CheckID("nope")}
state.EnsureCheck(4, &check) state.EnsureCheck(4, &check)
@ -899,10 +893,10 @@ func TestTxn_Read_ACLDeny(t *testing.T) {
var out structs.TxnReadResponse var out structs.TxnReadResponse
err := msgpackrpc.CallWithCodec(codec, "Txn.Read", &arg, &out) err := msgpackrpc.CallWithCodec(codec, "Txn.Read", &arg, &out)
require.NoError(err) require.NoError(t, err)
require.Empty(out.Results) require.Empty(t, out.Results)
require.Empty(out.Errors) require.Empty(t, out.Errors)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true") require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
}) })
t.Run("complex operations (return permission denied errors)", func(t *testing.T) { t.Run("complex operations (return permission denied errors)", func(t *testing.T) {
@ -931,11 +925,11 @@ func TestTxn_Read_ACLDeny(t *testing.T) {
var out structs.TxnReadResponse var out structs.TxnReadResponse
err := msgpackrpc.CallWithCodec(codec, "Txn.Read", &arg, &out) err := msgpackrpc.CallWithCodec(codec, "Txn.Read", &arg, &out)
require.NoError(err) require.NoError(t, err)
require.Equal(structs.TxnErrors{ require.Equal(t, structs.TxnErrors{
{OpIndex: 0, What: acl.ErrPermissionDenied.Error()}, {OpIndex: 0, What: acl.ErrPermissionDenied.Error()},
{OpIndex: 1, What: acl.ErrPermissionDenied.Error()}, {OpIndex: 1, What: acl.ErrPermissionDenied.Error()},
}, out.Errors) }, out.Errors)
require.Empty(out.Results) require.Empty(t, out.Results)
}) })
} }

View File

@ -7,14 +7,13 @@ import (
) )
func TestCollectHostInfo(t *testing.T) { func TestCollectHostInfo(t *testing.T) {
assert := assert.New(t)
host := CollectHostInfo() host := CollectHostInfo()
assert.Nil(host.Errors) assert.Nil(t, host.Errors)
assert.NotNil(host.CollectionTime) assert.NotNil(t, host.CollectionTime)
assert.NotNil(host.Host) assert.NotNil(t, host.Host)
assert.NotNil(host.Disk) assert.NotNil(t, host.Disk)
assert.NotNil(host.Memory) assert.NotNil(t, host.Memory)
} }

View File

@ -611,9 +611,6 @@ func TestHealthServiceNodes(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
assert := assert.New(t)
require := require.New(t)
req, _ := http.NewRequest("GET", "/v1/health/service/consul?dc=dc1", nil) req, _ := http.NewRequest("GET", "/v1/health/service/consul?dc=dc1", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req) obj, err := a.srv.HealthServiceNodes(resp, req)
@ -680,12 +677,12 @@ func TestHealthServiceNodes(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/health/service/test?cached", nil) req, _ := http.NewRequest("GET", "/v1/health/service/test?cached", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req) obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(err) require.NoError(t, err)
nodes := obj.(structs.CheckServiceNodes) nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
// Should be a cache miss // Should be a cache miss
assert.Equal("MISS", resp.Header().Get("X-Cache")) assert.Equal(t, "MISS", resp.Header().Get("X-Cache"))
} }
{ {
@ -693,12 +690,12 @@ func TestHealthServiceNodes(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/health/service/test?cached", nil) req, _ := http.NewRequest("GET", "/v1/health/service/test?cached", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req) obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(err) require.NoError(t, err)
nodes := obj.(structs.CheckServiceNodes) nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
// Should be a cache HIT now! // Should be a cache HIT now!
assert.Equal("HIT", resp.Header().Get("X-Cache")) assert.Equal(t, "HIT", resp.Header().Get("X-Cache"))
} }
// Ensure background refresh works // Ensure background refresh works
@ -707,7 +704,7 @@ func TestHealthServiceNodes(t *testing.T) {
args2 := args args2 := args
args2.Node = "baz" args2.Node = "baz"
args2.Address = "127.0.0.2" args2.Address = "127.0.0.2"
require.NoError(a.RPC("Catalog.Register", args, &out)) require.NoError(t, a.RPC("Catalog.Register", args, &out))
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
// List it again // List it again
@ -1414,27 +1411,26 @@ func TestHealthConnectServiceNodes(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
// Register // Register
args := structs.TestRegisterRequestProxy(t) args := structs.TestRegisterRequestProxy(t)
var out struct{} var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out)) assert.Nil(t, a.RPC("Catalog.Register", args, &out))
// Request // Request
req, _ := http.NewRequest("GET", fmt.Sprintf( req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?dc=dc1", args.Service.Proxy.DestinationServiceName), nil) "/v1/health/connect/%s?dc=dc1", args.Service.Proxy.DestinationServiceName), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req) obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err) assert.Nil(t, err)
assertIndex(t, resp) assertIndex(t, resp)
// Should be a non-nil empty list for checks // Should be a non-nil empty list for checks
nodes := obj.(structs.CheckServiceNodes) nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
assert.Len(nodes[0].Checks, 0) assert.Len(t, nodes[0].Checks, 0)
} }
func TestHealthIngressServiceNodes(t *testing.T) { func TestHealthIngressServiceNodes(t *testing.T) {
@ -1616,58 +1612,54 @@ func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) {
assert.Nil(t, a.RPC("Catalog.Register", args, &out)) assert.Nil(t, a.RPC("Catalog.Register", args, &out))
t.Run("bc_no_query_value", func(t *testing.T) { t.Run("bc_no_query_value", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf( req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing", args.Service.Proxy.DestinationServiceName), nil) "/v1/health/connect/%s?passing", args.Service.Proxy.DestinationServiceName), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req) obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err) assert.Nil(t, err)
assertIndex(t, resp) assertIndex(t, resp)
// Should be 0 health check for consul // Should be 0 health check for consul
nodes := obj.(structs.CheckServiceNodes) nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 0) assert.Len(t, nodes, 0)
}) })
t.Run("passing_true", func(t *testing.T) { t.Run("passing_true", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf( req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=true", args.Service.Proxy.DestinationServiceName), nil) "/v1/health/connect/%s?passing=true", args.Service.Proxy.DestinationServiceName), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req) obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err) assert.Nil(t, err)
assertIndex(t, resp) assertIndex(t, resp)
// Should be 0 health check for consul // Should be 0 health check for consul
nodes := obj.(structs.CheckServiceNodes) nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 0) assert.Len(t, nodes, 0)
}) })
t.Run("passing_false", func(t *testing.T) { t.Run("passing_false", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf( req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=false", args.Service.Proxy.DestinationServiceName), nil) "/v1/health/connect/%s?passing=false", args.Service.Proxy.DestinationServiceName), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req) obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err) assert.Nil(t, err)
assertIndex(t, resp) assertIndex(t, resp)
// Should be 1 // Should be 1
nodes := obj.(structs.CheckServiceNodes) nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 1) assert.Len(t, nodes, 1)
}) })
t.Run("passing_bad", func(t *testing.T) { t.Run("passing_bad", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf( req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=nope-nope", args.Service.Proxy.DestinationServiceName), nil) "/v1/health/connect/%s?passing=nope-nope", args.Service.Proxy.DestinationServiceName), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.HealthConnectServiceNodes(resp, req) a.srv.HealthConnectServiceNodes(resp, req)
assert.Equal(400, resp.Code) assert.Equal(t, 400, resp.Code)
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
assert.Nil(err) assert.Nil(t, err)
assert.True(bytes.Contains(body, []byte("Invalid value for ?passing"))) assert.True(t, bytes.Contains(body, []byte("Invalid value for ?passing")))
}) })
} }

View File

@ -907,7 +907,6 @@ func TestParseCacheControl(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
r, _ := http.NewRequest("GET", "/foo/bar", nil) r, _ := http.NewRequest("GET", "/foo/bar", nil)
if tt.headerVal != "" { if tt.headerVal != "" {
@ -919,13 +918,13 @@ func TestParseCacheControl(t *testing.T) {
failed := parseCacheControl(rr, r, &got) failed := parseCacheControl(rr, r, &got)
if tt.wantErr { if tt.wantErr {
require.True(failed) require.True(t, failed)
require.Equal(http.StatusBadRequest, rr.Code) require.Equal(t, http.StatusBadRequest, rr.Code)
} else { } else {
require.False(failed) require.False(t, failed)
} }
require.Equal(tt.want, got) require.Equal(t, tt.want, got)
}) })
} }
} }
@ -990,7 +989,6 @@ func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
dc1 := "dc1" dc1 := "dc1"
a := NewTestAgent(t, ` a := NewTestAgent(t, `
@ -1062,7 +1060,7 @@ func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) {
req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil) req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.handler(true).ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
assert.Equal(c.code, resp.Code) assert.Equal(t, c.code, resp.Code)
}) })
} }
} }

View File

@ -261,7 +261,6 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, "") a := agent.NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -289,7 +288,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
} }
a.State.AddService(srv1, "") a.State.AddService(srv1, "")
args.Service = srv1 args.Service = srv1
assert.Nil(a.RPC("Catalog.Register", args, &out)) assert.Nil(t, a.RPC("Catalog.Register", args, &out))
// Exists both, different (update) // Exists both, different (update)
srv2 := &structs.NodeService{ srv2 := &structs.NodeService{
@ -310,7 +309,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
*srv2_mod = *srv2 *srv2_mod = *srv2
srv2_mod.Port = 9000 srv2_mod.Port = 9000
args.Service = srv2_mod args.Service = srv2_mod
assert.Nil(a.RPC("Catalog.Register", args, &out)) assert.Nil(t, a.RPC("Catalog.Register", args, &out))
// Exists local (create) // Exists local (create)
srv3 := &structs.NodeService{ srv3 := &structs.NodeService{
@ -341,7 +340,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
args.Service = srv4 args.Service = srv4
assert.Nil(a.RPC("Catalog.Register", args, &out)) assert.Nil(t, a.RPC("Catalog.Register", args, &out))
// Exists local, in sync, remote missing (create) // Exists local, in sync, remote missing (create)
srv5 := &structs.NodeService{ srv5 := &structs.NodeService{
@ -361,28 +360,28 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
InSync: true, InSync: true,
}) })
assert.Nil(a.State.SyncFull()) assert.Nil(t, a.State.SyncFull())
var services structs.IndexedNodeServices var services structs.IndexedNodeServices
req := structs.NodeSpecificRequest{ req := structs.NodeSpecificRequest{
Datacenter: "dc1", Datacenter: "dc1",
Node: a.Config.NodeName, Node: a.Config.NodeName,
} }
assert.Nil(a.RPC("Catalog.NodeServices", &req, &services)) assert.Nil(t, a.RPC("Catalog.NodeServices", &req, &services))
// We should have 5 services (consul included) // We should have 5 services (consul included)
assert.Len(services.NodeServices.Services, 5) assert.Len(t, services.NodeServices.Services, 5)
// Check that virtual IPs have been set // Check that virtual IPs have been set
vips := make(map[string]struct{}) vips := make(map[string]struct{})
for _, serv := range services.NodeServices.Services { for _, serv := range services.NodeServices.Services {
if serv.TaggedAddresses != nil { if serv.TaggedAddresses != nil {
serviceVIP := serv.TaggedAddresses[structs.TaggedAddressVirtualIP].Address serviceVIP := serv.TaggedAddresses[structs.TaggedAddressVirtualIP].Address
assert.NotEmpty(serviceVIP) assert.NotEmpty(t, serviceVIP)
vips[serviceVIP] = struct{}{} vips[serviceVIP] = struct{}{}
} }
} }
assert.Len(vips, 4) assert.Len(t, vips, 4)
// All the services should match // All the services should match
// Retry to mitigate data races between local and remote state // Retry to mitigate data races between local and remote state
@ -407,26 +406,26 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
} }
}) })
assert.NoError(servicesInSync(a.State, 4, structs.DefaultEnterpriseMetaInDefaultPartition())) assert.NoError(t, servicesInSync(a.State, 4, structs.DefaultEnterpriseMetaInDefaultPartition()))
// Remove one of the services // Remove one of the services
a.State.RemoveService(structs.NewServiceID("cache-proxy", nil)) a.State.RemoveService(structs.NewServiceID("cache-proxy", nil))
assert.Nil(a.State.SyncFull()) assert.Nil(t, a.State.SyncFull())
assert.Nil(a.RPC("Catalog.NodeServices", &req, &services)) assert.Nil(t, a.RPC("Catalog.NodeServices", &req, &services))
// We should have 4 services (consul included) // We should have 4 services (consul included)
assert.Len(services.NodeServices.Services, 4) assert.Len(t, services.NodeServices.Services, 4)
// All the services should match // All the services should match
for id, serv := range services.NodeServices.Services { for id, serv := range services.NodeServices.Services {
serv.CreateIndex, serv.ModifyIndex = 0, 0 serv.CreateIndex, serv.ModifyIndex = 0, 0
switch id { switch id {
case "mysql-proxy": case "mysql-proxy":
assert.Equal(srv1, serv) assert.Equal(t, srv1, serv)
case "redis-proxy": case "redis-proxy":
assert.Equal(srv2, serv) assert.Equal(t, srv2, serv)
case "web-proxy": case "web-proxy":
assert.Equal(srv3, serv) assert.Equal(t, srv3, serv)
case structs.ConsulServiceID: case structs.ConsulServiceID:
// ignore // ignore
default: default:
@ -434,7 +433,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
} }
} }
assert.Nil(servicesInSync(a.State, 3, structs.DefaultEnterpriseMetaInDefaultPartition())) assert.Nil(t, servicesInSync(a.State, 3, structs.DefaultEnterpriseMetaInDefaultPartition()))
} }
func TestAgent_ServiceWatchCh(t *testing.T) { func TestAgent_ServiceWatchCh(t *testing.T) {
@ -447,8 +446,6 @@ func TestAgent_ServiceWatchCh(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
require := require.New(t)
// register a local service // register a local service
srv1 := &structs.NodeService{ srv1 := &structs.NodeService{
ID: "svc_id1", ID: "svc_id1",
@ -456,11 +453,11 @@ func TestAgent_ServiceWatchCh(t *testing.T) {
Tags: []string{"tag1"}, Tags: []string{"tag1"},
Port: 6100, Port: 6100,
} }
require.NoError(a.State.AddService(srv1, "")) require.NoError(t, a.State.AddService(srv1, ""))
verifyState := func(ss *local.ServiceState) { verifyState := func(ss *local.ServiceState) {
require.NotNil(ss) require.NotNil(t, ss)
require.NotNil(ss.WatchCh) require.NotNil(t, ss.WatchCh)
// Sanity check WatchCh blocks // Sanity check WatchCh blocks
select { select {
@ -478,7 +475,7 @@ func TestAgent_ServiceWatchCh(t *testing.T) {
go func() { go func() {
srv2 := srv1 srv2 := srv1
srv2.Port = 6200 srv2.Port = 6200
require.NoError(a.State.AddService(srv2, "")) require.NoError(t, a.State.AddService(srv2, ""))
}() }()
// We should observe WatchCh close // We should observe WatchCh close
@ -513,7 +510,7 @@ func TestAgent_ServiceWatchCh(t *testing.T) {
verifyState(ss) verifyState(ss)
go func() { go func() {
require.NoError(a.State.RemoveService(srv1.CompoundServiceID())) require.NoError(t, a.State.RemoveService(srv1.CompoundServiceID()))
}() }()
// We should observe WatchCh close // We should observe WatchCh close
@ -1966,20 +1963,19 @@ func TestAgent_AddCheckFailure(t *testing.T) {
func TestAgent_AliasCheck(t *testing.T) { func TestAgent_AliasCheck(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
cfg := loadRuntimeConfig(t, `bind_addr = "127.0.0.1" data_dir = "dummy" node_name = "dummy"`) cfg := loadRuntimeConfig(t, `bind_addr = "127.0.0.1" data_dir = "dummy" node_name = "dummy"`)
l := local.NewState(agent.LocalConfig(cfg), nil, new(token.Store)) l := local.NewState(agent.LocalConfig(cfg), nil, new(token.Store))
l.TriggerSyncChanges = func() {} l.TriggerSyncChanges = func() {}
// Add checks // Add checks
require.NoError(l.AddService(&structs.NodeService{Service: "s1"}, "")) require.NoError(t, l.AddService(&structs.NodeService{Service: "s1"}, ""))
require.NoError(l.AddService(&structs.NodeService{Service: "s2"}, "")) require.NoError(t, l.AddService(&structs.NodeService{Service: "s2"}, ""))
require.NoError(l.AddCheck(&structs.HealthCheck{CheckID: types.CheckID("c1"), ServiceID: "s1"}, "")) require.NoError(t, l.AddCheck(&structs.HealthCheck{CheckID: types.CheckID("c1"), ServiceID: "s1"}, ""))
require.NoError(l.AddCheck(&structs.HealthCheck{CheckID: types.CheckID("c2"), ServiceID: "s2"}, "")) require.NoError(t, l.AddCheck(&structs.HealthCheck{CheckID: types.CheckID("c2"), ServiceID: "s2"}, ""))
// Add an alias // Add an alias
notifyCh := make(chan struct{}, 1) notifyCh := make(chan struct{}, 1)
require.NoError(l.AddAliasCheck(structs.NewCheckID(types.CheckID("a1"), nil), structs.NewServiceID("s1", nil), notifyCh)) require.NoError(t, l.AddAliasCheck(structs.NewCheckID(types.CheckID("a1"), nil), structs.NewServiceID("s1", nil), notifyCh))
// Update and verify we get notified // Update and verify we get notified
l.UpdateCheck(structs.NewCheckID(types.CheckID("c1"), nil), api.HealthCritical, "") l.UpdateCheck(structs.NewCheckID(types.CheckID("c1"), nil), api.HealthCritical, "")
@ -2017,17 +2013,16 @@ func TestAgent_AliasCheck(t *testing.T) {
func TestAgent_AliasCheck_ServiceNotification(t *testing.T) { func TestAgent_AliasCheck_ServiceNotification(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
cfg := loadRuntimeConfig(t, `bind_addr = "127.0.0.1" data_dir = "dummy" node_name = "dummy"`) cfg := loadRuntimeConfig(t, `bind_addr = "127.0.0.1" data_dir = "dummy" node_name = "dummy"`)
l := local.NewState(agent.LocalConfig(cfg), nil, new(token.Store)) l := local.NewState(agent.LocalConfig(cfg), nil, new(token.Store))
l.TriggerSyncChanges = func() {} l.TriggerSyncChanges = func() {}
// Add an alias check for service s1 // Add an alias check for service s1
notifyCh := make(chan struct{}, 1) notifyCh := make(chan struct{}, 1)
require.NoError(l.AddAliasCheck(structs.NewCheckID(types.CheckID("a1"), nil), structs.NewServiceID("s1", nil), notifyCh)) require.NoError(t, l.AddAliasCheck(structs.NewCheckID(types.CheckID("a1"), nil), structs.NewServiceID("s1", nil), notifyCh))
// Add aliased service, s1, and verify we get notified // Add aliased service, s1, and verify we get notified
require.NoError(l.AddService(&structs.NodeService{Service: "s1"}, "")) require.NoError(t, l.AddService(&structs.NodeService{Service: "s1"}, ""))
select { select {
case <-notifyCh: case <-notifyCh:
default: default:
@ -2035,7 +2030,7 @@ func TestAgent_AliasCheck_ServiceNotification(t *testing.T) {
} }
// Re-adding same service should not lead to a notification // Re-adding same service should not lead to a notification
require.NoError(l.AddService(&structs.NodeService{Service: "s1"}, "")) require.NoError(t, l.AddService(&structs.NodeService{Service: "s1"}, ""))
select { select {
case <-notifyCh: case <-notifyCh:
t.Fatal("notify received") t.Fatal("notify received")
@ -2043,7 +2038,7 @@ func TestAgent_AliasCheck_ServiceNotification(t *testing.T) {
} }
// Add different service and verify we do not get notified // Add different service and verify we do not get notified
require.NoError(l.AddService(&structs.NodeService{Service: "s2"}, "")) require.NoError(t, l.AddService(&structs.NodeService{Service: "s2"}, ""))
select { select {
case <-notifyCh: case <-notifyCh:
t.Fatal("notify received") t.Fatal("notify received")
@ -2051,7 +2046,7 @@ func TestAgent_AliasCheck_ServiceNotification(t *testing.T) {
} }
// Delete service and verify we get notified // Delete service and verify we get notified
require.NoError(l.RemoveService(structs.NewServiceID("s1", nil))) require.NoError(t, l.RemoveService(structs.NewServiceID("s1", nil)))
select { select {
case <-notifyCh: case <-notifyCh:
default: default:
@ -2059,7 +2054,7 @@ func TestAgent_AliasCheck_ServiceNotification(t *testing.T) {
} }
// Delete different service and verify we do not get notified // Delete different service and verify we do not get notified
require.NoError(l.RemoveService(structs.NewServiceID("s2", nil))) require.NoError(t, l.RemoveService(structs.NewServiceID("s2", nil)))
select { select {
case <-notifyCh: case <-notifyCh:
t.Fatal("notify received") t.Fatal("notify received")
@ -2144,28 +2139,26 @@ func TestState_RemoveServiceErrorMessages(t *testing.T) {
// Stub state syncing // Stub state syncing
state.TriggerSyncChanges = func() {} state.TriggerSyncChanges = func() {}
require := require.New(t)
// Add 1 service // Add 1 service
err := state.AddService(&structs.NodeService{ err := state.AddService(&structs.NodeService{
ID: "web-id", ID: "web-id",
Service: "web-name", Service: "web-name",
}, "") }, "")
require.NoError(err) require.NoError(t, err)
// Attempt to remove service that doesn't exist // Attempt to remove service that doesn't exist
sid := structs.NewServiceID("db", nil) sid := structs.NewServiceID("db", nil)
err = state.RemoveService(sid) err = state.RemoveService(sid)
require.Contains(err.Error(), fmt.Sprintf(`Unknown service ID %q`, sid)) require.Contains(t, err.Error(), fmt.Sprintf(`Unknown service ID %q`, sid))
// Attempt to remove service by name (which isn't valid) // Attempt to remove service by name (which isn't valid)
sid2 := structs.NewServiceID("web-name", nil) sid2 := structs.NewServiceID("web-name", nil)
err = state.RemoveService(sid2) err = state.RemoveService(sid2)
require.Contains(err.Error(), fmt.Sprintf(`Unknown service ID %q`, sid2)) require.Contains(t, err.Error(), fmt.Sprintf(`Unknown service ID %q`, sid2))
// Attempt to remove service by id (valid) // Attempt to remove service by id (valid)
err = state.RemoveService(structs.NewServiceID("web-id", nil)) err = state.RemoveService(structs.NewServiceID("web-id", nil))
require.NoError(err) require.NoError(t, err)
} }
func TestState_Notify(t *testing.T) { func TestState_Notify(t *testing.T) {
@ -2180,24 +2173,21 @@ func TestState_Notify(t *testing.T) {
// Stub state syncing // Stub state syncing
state.TriggerSyncChanges = func() {} state.TriggerSyncChanges = func() {}
require := require.New(t)
assert := assert.New(t)
// Register a notifier // Register a notifier
notifyCh := make(chan struct{}, 1) notifyCh := make(chan struct{}, 1)
state.Notify(notifyCh) state.Notify(notifyCh)
defer state.StopNotify(notifyCh) defer state.StopNotify(notifyCh)
assert.Empty(notifyCh) assert.Empty(t, notifyCh)
drainCh(notifyCh) drainCh(notifyCh)
// Add a service // Add a service
err := state.AddService(&structs.NodeService{ err := state.AddService(&structs.NodeService{
Service: "web", Service: "web",
}, "fake-token-web") }, "fake-token-web")
require.NoError(err) require.NoError(t, err)
// Should have a notification // Should have a notification
assert.NotEmpty(notifyCh) assert.NotEmpty(t, notifyCh)
drainCh(notifyCh) drainCh(notifyCh)
// Re-Add same service // Re-Add same service
@ -2205,17 +2195,17 @@ func TestState_Notify(t *testing.T) {
Service: "web", Service: "web",
Port: 4444, Port: 4444,
}, "fake-token-web") }, "fake-token-web")
require.NoError(err) require.NoError(t, err)
// Should have a notification // Should have a notification
assert.NotEmpty(notifyCh) assert.NotEmpty(t, notifyCh)
drainCh(notifyCh) drainCh(notifyCh)
// Remove service // Remove service
require.NoError(state.RemoveService(structs.NewServiceID("web", nil))) require.NoError(t, state.RemoveService(structs.NewServiceID("web", nil)))
// Should have a notification // Should have a notification
assert.NotEmpty(notifyCh) assert.NotEmpty(t, notifyCh)
drainCh(notifyCh) drainCh(notifyCh)
// Stopping should... stop // Stopping should... stop
@ -2225,10 +2215,10 @@ func TestState_Notify(t *testing.T) {
err = state.AddService(&structs.NodeService{ err = state.AddService(&structs.NodeService{
Service: "web", Service: "web",
}, "fake-token-web") }, "fake-token-web")
require.NoError(err) require.NoError(t, err)
// Should NOT have a notification // Should NOT have a notification
assert.Empty(notifyCh) assert.Empty(t, notifyCh)
drainCh(notifyCh) drainCh(notifyCh)
} }

View File

@ -663,15 +663,14 @@ func TestPreparedQuery_ExecuteCached(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.PreparedQuerySpecific(resp, req) obj, err := a.srv.PreparedQuerySpecific(resp, req)
require := require.New(t) require.NoError(t, err)
require.NoError(err) require.Equal(t, 200, resp.Code)
require.Equal(200, resp.Code)
r, ok := obj.(structs.PreparedQueryExecuteResponse) r, ok := obj.(structs.PreparedQueryExecuteResponse)
require.True(ok) require.True(t, ok)
require.Equal(expectFailovers, r.Failovers) require.Equal(t, expectFailovers, r.Failovers)
require.Equal(expectCache, resp.Header().Get("X-Cache")) require.Equal(t, expectCache, resp.Header().Get("X-Cache"))
} }
// Should be a miss at first // Should be a miss at first
@ -770,22 +769,21 @@ func TestPreparedQuery_Explain(t *testing.T) {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
require := require.New(t)
m := MockPreparedQuery{ m := MockPreparedQuery{
executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error { executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
require.True(args.Connect) require.True(t, args.Connect)
return nil return nil
}, },
} }
require.NoError(a.registerEndpoint("PreparedQuery", &m)) require.NoError(t, a.registerEndpoint("PreparedQuery", &m))
body := bytes.NewBuffer(nil) body := bytes.NewBuffer(nil)
req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?connect=true", body) req, _ := http.NewRequest("GET", "/v1/query/my-id/execute?connect=true", body)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.PreparedQuerySpecific(resp, req) _, err := a.srv.PreparedQuerySpecific(resp, req)
require.NoError(err) require.NoError(t, err)
require.Equal(200, resp.Code) require.Equal(t, 200, resp.Code)
}) })
} }

View File

@ -354,7 +354,6 @@ func testManager_BasicLifecycle(
) { ) {
c := TestCacheWithTypes(t, types) c := TestCacheWithTypes(t, types)
require := require.New(t)
logger := testutil.Logger(t) logger := testutil.Logger(t)
state := local.NewState(agentConfig, logger, &token.Store{}) state := local.NewState(agentConfig, logger, &token.Store{})
source := &structs.QuerySource{Datacenter: "dc1"} source := &structs.QuerySource{Datacenter: "dc1"}
@ -370,12 +369,12 @@ func testManager_BasicLifecycle(
Source: source, Source: source,
Logger: logger, Logger: logger,
}) })
require.NoError(err) require.NoError(t, err)
// And run it // And run it
go func() { go func() {
err := m.Run() err := m.Run()
require.NoError(err) require.NoError(t, err)
}() }()
// BEFORE we register, we should be able to get a watch channel // BEFORE we register, we should be able to get a watch channel
@ -385,19 +384,19 @@ func testManager_BasicLifecycle(
// And it should block with nothing sent on it yet // And it should block with nothing sent on it yet
assertWatchChanBlocks(t, wCh) assertWatchChanBlocks(t, wCh)
require.NoError(state.AddService(webProxy, "my-token")) require.NoError(t, state.AddService(webProxy, "my-token"))
// We should see the initial config delivered but not until after the // We should see the initial config delivered but not until after the
// coalesce timeout // coalesce timeout
start := time.Now() start := time.Now()
assertWatchChanRecvs(t, wCh, expectSnap) assertWatchChanRecvs(t, wCh, expectSnap)
require.True(time.Since(start) >= coalesceTimeout) require.True(t, time.Since(start) >= coalesceTimeout)
assertLastReqArgs(t, types, "my-token", source) assertLastReqArgs(t, types, "my-token", source)
// Update NodeConfig // Update NodeConfig
webProxy.Port = 7777 webProxy.Port = 7777
require.NoError(state.AddService(webProxy, "my-token")) require.NoError(t, state.AddService(webProxy, "my-token"))
expectSnap.Port = 7777 expectSnap.Port = 7777
assertWatchChanRecvs(t, wCh, expectSnap) assertWatchChanRecvs(t, wCh, expectSnap)
@ -410,7 +409,7 @@ func testManager_BasicLifecycle(
assertWatchChanRecvs(t, wCh2, expectSnap) assertWatchChanRecvs(t, wCh2, expectSnap)
// Change token // Change token
require.NoError(state.AddService(webProxy, "other-token")) require.NoError(t, state.AddService(webProxy, "other-token"))
assertWatchChanRecvs(t, wCh, expectSnap) assertWatchChanRecvs(t, wCh, expectSnap)
assertWatchChanRecvs(t, wCh2, expectSnap) assertWatchChanRecvs(t, wCh2, expectSnap)
@ -445,7 +444,7 @@ func testManager_BasicLifecycle(
// Re-add the proxy with another new port // Re-add the proxy with another new port
webProxy.Port = 3333 webProxy.Port = 3333
require.NoError(state.AddService(webProxy, "other-token")) require.NoError(t, state.AddService(webProxy, "other-token"))
// Same watch chan should be notified again // Same watch chan should be notified again
expectSnap.Port = 3333 expectSnap.Port = 3333
@ -460,13 +459,13 @@ func testManager_BasicLifecycle(
// We specifically don't remove the proxy or cancel the second watcher to // We specifically don't remove the proxy or cancel the second watcher to
// ensure both are cleaned up by close. // ensure both are cleaned up by close.
require.NoError(m.Close()) require.NoError(t, m.Close())
// Sanity check the state is clean // Sanity check the state is clean
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
require.Len(m.proxies, 0) require.Len(t, m.proxies, 0)
require.Len(m.watchers, 0) require.Len(t, m.watchers, 0)
} }
func assertWatchChanBlocks(t *testing.T, ch <-chan *ConfigSnapshot) { func assertWatchChanBlocks(t *testing.T, ch <-chan *ConfigSnapshot) {
@ -505,10 +504,9 @@ func TestManager_deliverLatest(t *testing.T) {
}, },
Logger: logger, Logger: logger,
} }
require := require.New(t)
m, err := NewManager(cfg) m, err := NewManager(cfg)
require.NoError(err) require.NoError(t, err)
snap1 := &ConfigSnapshot{ snap1 := &ConfigSnapshot{
ProxyID: structs.NewServiceID("test-proxy", nil), ProxyID: structs.NewServiceID("test-proxy", nil),
@ -526,14 +524,14 @@ func TestManager_deliverLatest(t *testing.T) {
m.deliverLatest(snap1, ch1) m.deliverLatest(snap1, ch1)
// Check it was delivered // Check it was delivered
require.Equal(snap1, <-ch1) require.Equal(t, snap1, <-ch1)
// Now send both without reading simulating a slow client // Now send both without reading simulating a slow client
m.deliverLatest(snap1, ch1) m.deliverLatest(snap1, ch1)
m.deliverLatest(snap2, ch1) m.deliverLatest(snap2, ch1)
// Check we got the _second_ one // Check we got the _second_ one
require.Equal(snap2, <-ch1) require.Equal(t, snap2, <-ch1)
// Same again for 5-buffered chan // Same again for 5-buffered chan
ch5 := make(chan *ConfigSnapshot, 5) ch5 := make(chan *ConfigSnapshot, 5)
@ -542,7 +540,7 @@ func TestManager_deliverLatest(t *testing.T) {
m.deliverLatest(snap1, ch5) m.deliverLatest(snap1, ch5)
// Check it was delivered // Check it was delivered
require.Equal(snap1, <-ch5) require.Equal(t, snap1, <-ch5)
// Now send enough to fill the chan simulating a slow client // Now send enough to fill the chan simulating a slow client
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
@ -551,7 +549,7 @@ func TestManager_deliverLatest(t *testing.T) {
m.deliverLatest(snap2, ch5) m.deliverLatest(snap2, ch5)
// Check we got the _second_ one // Check we got the _second_ one
require.Equal(snap2, <-ch5) require.Equal(t, snap2, <-ch5)
} }
func testGenCacheKey(req cache.Request) string { func testGenCacheKey(req cache.Request) string {

View File

@ -115,11 +115,10 @@ func TestStateChanged(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
state, err := newState(tt.ns, tt.token, stateConfig{logger: hclog.New(nil)}) state, err := newState(tt.ns, tt.token, stateConfig{logger: hclog.New(nil)})
require.NoError(err) require.NoError(t, err)
otherNS, otherToken := tt.mutate(*tt.ns, tt.token) otherNS, otherToken := tt.mutate(*tt.ns, tt.token)
require.Equal(tt.want, state.Changed(otherNS, otherToken)) require.Equal(t, tt.want, state.Changed(otherNS, otherToken))
}) })
} }
} }

View File

@ -23,8 +23,6 @@ func TestServiceManager_RegisterService(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
@ -51,12 +49,12 @@ func TestServiceManager_RegisterService(t *testing.T) {
Port: 8000, Port: 8000,
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
require.NoError(a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal)) require.NoError(t, a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal))
// Verify both the service and sidecar. // Verify both the service and sidecar.
redisService := a.State.Service(structs.NewServiceID("redis", nil)) redisService := a.State.Service(structs.NewServiceID("redis", nil))
require.NotNil(redisService) require.NotNil(t, redisService)
require.Equal(&structs.NodeService{ require.Equal(t, &structs.NodeService{
ID: "redis", ID: "redis",
Service: "redis", Service: "redis",
Port: 8000, Port: 8000,
@ -74,8 +72,6 @@ func TestServiceManager_RegisterSidecar(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
@ -124,12 +120,12 @@ func TestServiceManager_RegisterSidecar(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
require.NoError(a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal)) require.NoError(t, a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal))
// Verify sidecar got global config loaded // Verify sidecar got global config loaded
sidecarService := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil)) sidecarService := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil))
require.NotNil(sidecarService) require.NotNil(t, sidecarService)
require.Equal(&structs.NodeService{ require.Equal(t, &structs.NodeService{
Kind: structs.ServiceKindConnectProxy, Kind: structs.ServiceKindConnectProxy,
ID: "web-sidecar-proxy", ID: "web-sidecar-proxy",
Service: "web-sidecar-proxy", Service: "web-sidecar-proxy",
@ -169,8 +165,6 @@ func TestServiceManager_RegisterMeshGateway(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
@ -199,12 +193,12 @@ func TestServiceManager_RegisterMeshGateway(t *testing.T) {
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
require.NoError(a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal)) require.NoError(t, a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal))
// Verify gateway got global config loaded // Verify gateway got global config loaded
gateway := a.State.Service(structs.NewServiceID("mesh-gateway", nil)) gateway := a.State.Service(structs.NewServiceID("mesh-gateway", nil))
require.NotNil(gateway) require.NotNil(t, gateway)
require.Equal(&structs.NodeService{ require.Equal(t, &structs.NodeService{
Kind: structs.ServiceKindMeshGateway, Kind: structs.ServiceKindMeshGateway,
ID: "mesh-gateway", ID: "mesh-gateway",
Service: "mesh-gateway", Service: "mesh-gateway",
@ -229,8 +223,6 @@ func TestServiceManager_RegisterTerminatingGateway(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown() defer a.Shutdown()
@ -259,12 +251,12 @@ func TestServiceManager_RegisterTerminatingGateway(t *testing.T) {
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
require.NoError(a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal)) require.NoError(t, a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal))
// Verify gateway got global config loaded // Verify gateway got global config loaded
gateway := a.State.Service(structs.NewServiceID("terminating-gateway", nil)) gateway := a.State.Service(structs.NewServiceID("terminating-gateway", nil))
require.NotNil(gateway) require.NotNil(t, gateway)
require.Equal(&structs.NodeService{ require.Equal(t, &structs.NodeService{
Kind: structs.ServiceKindTerminatingGateway, Kind: structs.ServiceKindTerminatingGateway,
ID: "terminating-gateway", ID: "terminating-gateway",
Service: "terminating-gateway", Service: "terminating-gateway",
@ -293,8 +285,6 @@ func TestServiceManager_PersistService_API(t *testing.T) {
// TestAgent_PurgeService. // TestAgent_PurgeService.
t.Parallel() t.Parallel()
require := require.New(t)
// Launch a server to manage the config entries. // Launch a server to manage the config entries.
serverAgent := NewTestAgent(t, "") serverAgent := NewTestAgent(t, "")
defer serverAgent.Shutdown() defer serverAgent.Shutdown()
@ -331,7 +321,7 @@ func TestServiceManager_PersistService_API(t *testing.T) {
_, err := a.JoinLAN([]string{ _, err := a.JoinLAN([]string{
fmt.Sprintf("127.0.0.1:%d", serverAgent.Config.SerfPortLAN), fmt.Sprintf("127.0.0.1:%d", serverAgent.Config.SerfPortLAN),
}, nil) }, nil)
require.NoError(err) require.NoError(t, err)
testrpc.WaitForLeader(t, a.RPC, "dc1") testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -401,7 +391,7 @@ func TestServiceManager_PersistService_API(t *testing.T) {
// Service is not persisted unless requested, but we always persist service configs. // Service is not persisted unless requested, but we always persist service configs.
err = a.AddService(AddServiceRequest{Service: svc, Source: ConfigSourceRemote}) err = a.AddService(AddServiceRequest{Service: svc, Source: ConfigSourceRemote})
require.NoError(err) require.NoError(t, err)
requireFileIsAbsent(t, svcFile) requireFileIsAbsent(t, svcFile)
requireFileIsPresent(t, configFile) requireFileIsPresent(t, configFile)
@ -412,7 +402,7 @@ func TestServiceManager_PersistService_API(t *testing.T) {
token: "mytoken", token: "mytoken",
Source: ConfigSourceRemote, Source: ConfigSourceRemote,
}) })
require.NoError(err) require.NoError(t, err)
requireFileIsPresent(t, svcFile) requireFileIsPresent(t, svcFile)
requireFileIsPresent(t, configFile) requireFileIsPresent(t, configFile)
@ -447,8 +437,8 @@ func TestServiceManager_PersistService_API(t *testing.T) {
// Verify in memory state. // Verify in memory state.
{ {
sidecarService := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil)) sidecarService := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil))
require.NotNil(sidecarService) require.NotNil(t, sidecarService)
require.Equal(expectState, sidecarService) require.Equal(t, expectState, sidecarService)
} }
// Updates service definition on disk // Updates service definition on disk
@ -460,7 +450,7 @@ func TestServiceManager_PersistService_API(t *testing.T) {
token: "mytoken", token: "mytoken",
Source: ConfigSourceRemote, Source: ConfigSourceRemote,
}) })
require.NoError(err) require.NoError(t, err)
requireFileIsPresent(t, svcFile) requireFileIsPresent(t, svcFile)
requireFileIsPresent(t, configFile) requireFileIsPresent(t, configFile)
@ -496,8 +486,8 @@ func TestServiceManager_PersistService_API(t *testing.T) {
expectState.Proxy.LocalServicePort = 8001 expectState.Proxy.LocalServicePort = 8001
{ {
sidecarService := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil)) sidecarService := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil))
require.NotNil(sidecarService) require.NotNil(t, sidecarService)
require.Equal(expectState, sidecarService) require.Equal(t, expectState, sidecarService)
} }
// Kill the agent to restart it. // Kill the agent to restart it.
@ -512,12 +502,12 @@ func TestServiceManager_PersistService_API(t *testing.T) {
{ {
restored := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil)) restored := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil))
require.NotNil(restored) require.NotNil(t, restored)
require.Equal(expectState, restored) require.Equal(t, expectState, restored)
} }
// Now remove it. // Now remove it.
require.NoError(a2.RemoveService(structs.NewServiceID("web-sidecar-proxy", nil))) require.NoError(t, a2.RemoveService(structs.NewServiceID("web-sidecar-proxy", nil)))
requireFileIsAbsent(t, svcFile) requireFileIsAbsent(t, svcFile)
requireFileIsAbsent(t, configFile) requireFileIsAbsent(t, configFile)
} }
@ -704,8 +694,6 @@ func TestServiceManager_Disabled(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
a := NewTestAgent(t, "enable_central_service_config = false") a := NewTestAgent(t, "enable_central_service_config = false")
defer a.Shutdown() defer a.Shutdown()
@ -752,12 +740,12 @@ func TestServiceManager_Disabled(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
require.NoError(a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal)) require.NoError(t, a.addServiceFromSource(svc, nil, false, "", ConfigSourceLocal))
// Verify sidecar got global config loaded // Verify sidecar got global config loaded
sidecarService := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil)) sidecarService := a.State.Service(structs.NewServiceID("web-sidecar-proxy", nil))
require.NotNil(sidecarService) require.NotNil(t, sidecarService)
require.Equal(&structs.NodeService{ require.Equal(t, &structs.NodeService{
Kind: structs.ServiceKindConnectProxy, Kind: structs.ServiceKindConnectProxy,
ID: "web-sidecar-proxy", ID: "web-sidecar-proxy",
Service: "web-sidecar-proxy", Service: "web-sidecar-proxy",

View File

@ -330,30 +330,29 @@ func TestAgent_sidecarServiceFromNodeService(t *testing.T) {
` `
} }
require := require.New(t)
a := StartTestAgent(t, TestAgent{Name: "jones", HCL: hcl}) a := StartTestAgent(t, TestAgent{Name: "jones", HCL: hcl})
defer a.Shutdown() defer a.Shutdown()
if tt.preRegister != nil { if tt.preRegister != nil {
err := a.addServiceFromSource(tt.preRegister.NodeService(), nil, false, "", ConfigSourceLocal) err := a.addServiceFromSource(tt.preRegister.NodeService(), nil, false, "", ConfigSourceLocal)
require.NoError(err) require.NoError(t, err)
} }
ns := tt.sd.NodeService() ns := tt.sd.NodeService()
err := ns.Validate() err := ns.Validate()
require.NoError(err, "Invalid test case - NodeService must validate") require.NoError(t, err, "Invalid test case - NodeService must validate")
gotNS, gotChecks, gotToken, err := a.sidecarServiceFromNodeService(ns, tt.token) gotNS, gotChecks, gotToken, err := a.sidecarServiceFromNodeService(ns, tt.token)
if tt.wantErr != "" { if tt.wantErr != "" {
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), tt.wantErr) require.Contains(t, err.Error(), tt.wantErr)
return return
} }
require.NoError(err) require.NoError(t, err)
require.Equal(tt.wantNS, gotNS) require.Equal(t, tt.wantNS, gotNS)
require.Equal(tt.wantChecks, gotChecks) require.Equal(t, tt.wantChecks, gotChecks)
require.Equal(tt.wantToken, gotToken) require.Equal(t, tt.wantToken, gotToken)
}) })
} }
} }

View File

@ -197,14 +197,13 @@ func TestConnectProxyConfig_MarshalJSON(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
got, err := tt.in.MarshalJSON() got, err := tt.in.MarshalJSON()
if tt.wantErr { if tt.wantErr {
require.Error(err) require.Error(t, err)
return return
} }
require.NoError(err) require.NoError(t, err)
require.JSONEq(tt.want, string(got)) require.JSONEq(t, tt.want, string(got))
}) })
} }
} }
@ -255,14 +254,13 @@ func TestUpstream_MarshalJSON(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
got, err := json.Marshal(tt.in) got, err := json.Marshal(tt.in)
if tt.wantErr { if tt.wantErr {
require.Error(err) require.Error(t, err)
return return
} }
require.NoError(err) require.NoError(t, err)
require.JSONEq(tt.want, string(got)) require.JSONEq(t, tt.want, string(got))
}) })
} }
} }

View File

@ -227,17 +227,16 @@ func TestIntentionValidate(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
ixn := TestIntention(t) ixn := TestIntention(t)
tc.Modify(ixn) tc.Modify(ixn)
err := ixn.Validate() err := ixn.Validate()
assert.Equal(err != nil, tc.Err != "", err) assert.Equal(t, err != nil, tc.Err != "", err)
if err == nil { if err == nil {
return return
} }
assert.Contains(strings.ToLower(err.Error()), strings.ToLower(tc.Err)) assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.Err))
}) })
} }
} }
@ -301,7 +300,6 @@ func TestIntentionPrecedenceSorter(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
var input Intentions var input Intentions
for _, v := range tc.Input { for _, v := range tc.Input {
@ -331,7 +329,7 @@ func TestIntentionPrecedenceSorter(t *testing.T) {
v.DestinationName, v.DestinationName,
}) })
} }
assert.Equal(tc.Expected, actual) assert.Equal(t, tc.Expected, actual)
}) })
} }
} }

View File

@ -71,16 +71,15 @@ func TestServiceDefinitionValidate(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
service := TestServiceDefinition(t) service := TestServiceDefinition(t)
tc.Modify(service) tc.Modify(service)
err := service.Validate() err := service.Validate()
if tc.Err == "" { if tc.Err == "" {
require.NoError(err) require.NoError(t, err)
} else { } else {
require.Error(err) require.Error(t, err)
require.Contains(strings.ToLower(err.Error()), strings.ToLower(tc.Err)) require.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.Err))
} }
}) })
} }

View File

@ -941,17 +941,16 @@ func TestStructs_NodeService_ValidateConnectProxy(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
ns := TestNodeServiceProxy(t) ns := TestNodeServiceProxy(t)
tc.Modify(ns) tc.Modify(ns)
err := ns.Validate() err := ns.Validate()
assert.Equal(err != nil, tc.Err != "", err) assert.Equal(t, err != nil, tc.Err != "", err)
if err == nil { if err == nil {
return return
} }
assert.Contains(strings.ToLower(err.Error()), strings.ToLower(tc.Err)) assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.Err))
}) })
} }
} }
@ -1000,17 +999,16 @@ func TestStructs_NodeService_ValidateConnectProxy_In_Partition(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
ns := TestNodeServiceProxyInPartition(t, "bar") ns := TestNodeServiceProxyInPartition(t, "bar")
tc.Modify(ns) tc.Modify(ns)
err := ns.Validate() err := ns.Validate()
assert.Equal(err != nil, tc.Err != "", err) assert.Equal(t, err != nil, tc.Err != "", err)
if err == nil { if err == nil {
return return
} }
assert.Contains(strings.ToLower(err.Error()), strings.ToLower(tc.Err)) assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.Err))
}) })
} }
} }
@ -1046,17 +1044,16 @@ func TestStructs_NodeService_ValidateSidecarService(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
ns := TestNodeServiceSidecar(t) ns := TestNodeServiceSidecar(t)
tc.Modify(ns) tc.Modify(ns)
err := ns.Validate() err := ns.Validate()
assert.Equal(err != nil, tc.Err != "", err) assert.Equal(t, err != nil, tc.Err != "", err)
if err == nil { if err == nil {
return return
} }
assert.Contains(strings.ToLower(err.Error()), strings.ToLower(tc.Err)) assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.Err))
}) })
} }
} }

View File

@ -762,8 +762,6 @@ func TestAPI_AgentService(t *testing.T) {
agent := c.Agent() agent := c.Agent()
require := require.New(t)
reg := &AgentServiceRegistration{ reg := &AgentServiceRegistration{
Name: "foo", Name: "foo",
Tags: []string{"bar", "baz"}, Tags: []string{"bar", "baz"},
@ -777,10 +775,10 @@ func TestAPI_AgentService(t *testing.T) {
}, },
}, },
} }
require.NoError(agent.ServiceRegister(reg)) require.NoError(t, agent.ServiceRegister(reg))
got, qm, err := agent.Service("foo", nil) got, qm, err := agent.Service("foo", nil)
require.NoError(err) require.NoError(t, err)
expect := &AgentService{ expect := &AgentService{
ID: "foo", ID: "foo",
@ -797,8 +795,8 @@ func TestAPI_AgentService(t *testing.T) {
Partition: defaultPartition, Partition: defaultPartition,
Datacenter: "dc1", Datacenter: "dc1",
} }
require.Equal(expect, got) require.Equal(t, expect, got)
require.Equal(expect.ContentHash, qm.LastContentHash) require.Equal(t, expect.ContentHash, qm.LastContentHash)
// Sanity check blocking behavior - this is more thoroughly tested in the // Sanity check blocking behavior - this is more thoroughly tested in the
// agent endpoint tests but this ensures that the API package is at least // agent endpoint tests but this ensures that the API package is at least
@ -810,8 +808,8 @@ func TestAPI_AgentService(t *testing.T) {
start := time.Now() start := time.Now()
_, _, err = agent.Service("foo", &opts) _, _, err = agent.Service("foo", &opts)
elapsed := time.Since(start) elapsed := time.Since(start)
require.NoError(err) require.NoError(t, err)
require.True(elapsed >= opts.WaitTime) require.True(t, elapsed >= opts.WaitTime)
} }
func TestAPI_AgentSetTTLStatus(t *testing.T) { func TestAPI_AgentSetTTLStatus(t *testing.T) {
@ -1616,7 +1614,6 @@ func TestAPI_AgentUpdateToken(t *testing.T) {
func TestAPI_AgentConnectCARoots_empty(t *testing.T) { func TestAPI_AgentConnectCARoots_empty(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
c, s := makeClientWithConfig(t, nil, func(c *testutil.TestServerConfig) { c, s := makeClientWithConfig(t, nil, func(c *testutil.TestServerConfig) {
c.Connect = nil // disable connect to prevent CA being bootstrapped c.Connect = nil // disable connect to prevent CA being bootstrapped
}) })
@ -1624,29 +1621,27 @@ func TestAPI_AgentConnectCARoots_empty(t *testing.T) {
agent := c.Agent() agent := c.Agent()
_, _, err := agent.ConnectCARoots(nil) _, _, err := agent.ConnectCARoots(nil)
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "Connect must be enabled") require.Contains(t, err.Error(), "Connect must be enabled")
} }
func TestAPI_AgentConnectCARoots_list(t *testing.T) { func TestAPI_AgentConnectCARoots_list(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
c, s := makeClient(t) c, s := makeClient(t)
defer s.Stop() defer s.Stop()
agent := c.Agent() agent := c.Agent()
s.WaitForSerfCheck(t) s.WaitForSerfCheck(t)
list, meta, err := agent.ConnectCARoots(nil) list, meta, err := agent.ConnectCARoots(nil)
require.NoError(err) require.NoError(t, err)
require.True(meta.LastIndex > 0) require.True(t, meta.LastIndex > 0)
require.Len(list.Roots, 1) require.Len(t, list.Roots, 1)
} }
func TestAPI_AgentConnectCALeaf(t *testing.T) { func TestAPI_AgentConnectCALeaf(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
c, s := makeClient(t) c, s := makeClient(t)
defer s.Stop() defer s.Stop()
@ -1660,26 +1655,25 @@ func TestAPI_AgentConnectCALeaf(t *testing.T) {
Tags: []string{"bar", "baz"}, Tags: []string{"bar", "baz"},
Port: 8000, Port: 8000,
} }
require.NoError(agent.ServiceRegister(reg)) require.NoError(t, agent.ServiceRegister(reg))
leaf, meta, err := agent.ConnectCALeaf("foo", nil) leaf, meta, err := agent.ConnectCALeaf("foo", nil)
require.NoError(err) require.NoError(t, err)
require.True(meta.LastIndex > 0) require.True(t, meta.LastIndex > 0)
// Sanity checks here as we have actual certificate validation checks at many // Sanity checks here as we have actual certificate validation checks at many
// other levels. // other levels.
require.NotEmpty(leaf.SerialNumber) require.NotEmpty(t, leaf.SerialNumber)
require.NotEmpty(leaf.CertPEM) require.NotEmpty(t, leaf.CertPEM)
require.NotEmpty(leaf.PrivateKeyPEM) require.NotEmpty(t, leaf.PrivateKeyPEM)
require.Equal("foo", leaf.Service) require.Equal(t, "foo", leaf.Service)
require.True(strings.HasSuffix(leaf.ServiceURI, "/svc/foo")) require.True(t, strings.HasSuffix(leaf.ServiceURI, "/svc/foo"))
require.True(leaf.ModifyIndex > 0) require.True(t, leaf.ModifyIndex > 0)
require.True(leaf.ValidAfter.Before(time.Now())) require.True(t, leaf.ValidAfter.Before(time.Now()))
require.True(leaf.ValidBefore.After(time.Now())) require.True(t, leaf.ValidBefore.After(time.Now()))
} }
func TestAPI_AgentConnectAuthorize(t *testing.T) { func TestAPI_AgentConnectAuthorize(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
c, s := makeClient(t) c, s := makeClient(t)
defer s.Stop() defer s.Stop()
@ -1692,9 +1686,9 @@ func TestAPI_AgentConnectAuthorize(t *testing.T) {
ClientCertURI: "spiffe://11111111-2222-3333-4444-555555555555.consul/ns/default/dc/ny1/svc/web", ClientCertURI: "spiffe://11111111-2222-3333-4444-555555555555.consul/ns/default/dc/ny1/svc/web",
} }
auth, err := agent.ConnectAuthorize(params) auth, err := agent.ConnectAuthorize(params)
require.Nil(err) require.Nil(t, err)
require.True(auth.Authorized) require.True(t, auth.Authorized)
require.Equal(auth.Reason, "Default behavior configured by ACLs") require.Equal(t, auth.Reason, "Default behavior configured by ACLs")
} }
func TestAPI_AgentHealthServiceOpts(t *testing.T) { func TestAPI_AgentHealthServiceOpts(t *testing.T) {

View File

@ -738,8 +738,6 @@ func TestAPI_SetQueryOptions(t *testing.T) {
c, s := makeClient(t) c, s := makeClient(t)
defer s.Stop() defer s.Stop()
assert := assert.New(t)
r := c.newRequest("GET", "/v1/kv/foo") r := c.newRequest("GET", "/v1/kv/foo")
q := &QueryOptions{ q := &QueryOptions{
Namespace: "operator", Namespace: "operator",
@ -785,7 +783,7 @@ func TestAPI_SetQueryOptions(t *testing.T) {
if r.params.Get("local-only") != "true" { if r.params.Get("local-only") != "true" {
t.Fatalf("bad: %v", r.params) t.Fatalf("bad: %v", r.params)
} }
assert.Equal("", r.header.Get("Cache-Control")) assert.Equal(t, "", r.header.Get("Cache-Control"))
r = c.newRequest("GET", "/v1/kv/foo") r = c.newRequest("GET", "/v1/kv/foo")
q = &QueryOptions{ q = &QueryOptions{
@ -796,8 +794,8 @@ func TestAPI_SetQueryOptions(t *testing.T) {
r.setQueryOptions(q) r.setQueryOptions(q)
_, ok := r.params["cached"] _, ok := r.params["cached"]
assert.True(ok) assert.True(t, ok)
assert.Equal("max-age=30, stale-if-error=346", r.header.Get("Cache-Control")) assert.Equal(t, "max-age=30, stale-if-error=346", r.header.Get("Cache-Control"))
} }
func TestAPI_SetWriteOptions(t *testing.T) { func TestAPI_SetWriteOptions(t *testing.T) {

View File

@ -318,13 +318,11 @@ func TestAPI_CatalogServiceCached(t *testing.T) {
} }
}) })
require := require.New(t)
// Got success, next hit must be cache hit // Got success, next hit must be cache hit
_, meta, err := catalog.Service("consul", "", q) _, meta, err := catalog.Service("consul", "", q)
require.NoError(err) require.NoError(t, err)
require.True(meta.CacheHit) require.True(t, meta.CacheHit)
require.Equal(time.Duration(0), meta.CacheAge) require.Equal(t, time.Duration(0), meta.CacheAge)
} }
func TestAPI_CatalogService_SingleTag(t *testing.T) { func TestAPI_CatalogService_SingleTag(t *testing.T) {

View File

@ -250,7 +250,6 @@ func TestAPI_ConfigEntries(t *testing.T) {
}) })
t.Run("CAS deletion", func(t *testing.T) { t.Run("CAS deletion", func(t *testing.T) {
require := require.New(t)
entry := &ProxyConfigEntry{ entry := &ProxyConfigEntry{
Kind: ProxyDefaults, Kind: ProxyDefaults,
@ -262,23 +261,23 @@ func TestAPI_ConfigEntries(t *testing.T) {
// Create a config entry. // Create a config entry.
created, _, err := config_entries.Set(entry, nil) created, _, err := config_entries.Set(entry, nil)
require.NoError(err) require.NoError(t, err)
require.True(created, "entry should have been created") require.True(t, created, "entry should have been created")
// Read it back to get the ModifyIndex. // Read it back to get the ModifyIndex.
result, _, err := config_entries.Get(entry.Kind, entry.Name, nil) result, _, err := config_entries.Get(entry.Kind, entry.Name, nil)
require.NoError(err) require.NoError(t, err)
require.NotNil(entry) require.NotNil(t, entry)
// Attempt a deletion with an invalid index. // Attempt a deletion with an invalid index.
deleted, _, err := config_entries.DeleteCAS(entry.Kind, entry.Name, result.GetModifyIndex()-1, nil) deleted, _, err := config_entries.DeleteCAS(entry.Kind, entry.Name, result.GetModifyIndex()-1, nil)
require.NoError(err) require.NoError(t, err)
require.False(deleted, "entry should not have been deleted") require.False(t, deleted, "entry should not have been deleted")
// Attempt a deletion with a valid index. // Attempt a deletion with a valid index.
deleted, _, err = config_entries.DeleteCAS(entry.Kind, entry.Name, result.GetModifyIndex(), nil) deleted, _, err = config_entries.DeleteCAS(entry.Kind, entry.Name, result.GetModifyIndex(), nil)
require.NoError(err) require.NoError(t, err)
require.True(deleted, "entry should have been deleted") require.True(t, deleted, "entry should have been deleted")
}) })
} }

View File

@ -13,7 +13,6 @@ import (
func TestAPI_ConnectCARoots_empty(t *testing.T) { func TestAPI_ConnectCARoots_empty(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
c, s := makeClientWithConfig(t, nil, func(c *testutil.TestServerConfig) { c, s := makeClientWithConfig(t, nil, func(c *testutil.TestServerConfig) {
// Don't bootstrap CA // Don't bootstrap CA
c.Connect = nil c.Connect = nil
@ -25,8 +24,8 @@ func TestAPI_ConnectCARoots_empty(t *testing.T) {
connect := c.Connect() connect := c.Connect()
_, _, err := connect.CARoots(nil) _, _, err := connect.CARoots(nil)
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "Connect must be enabled") require.Contains(t, err.Error(), "Connect must be enabled")
} }
func TestAPI_ConnectCARoots_list(t *testing.T) { func TestAPI_ConnectCARoots_list(t *testing.T) {

View File

@ -83,7 +83,6 @@ func TestAPI_StatusPeersWithQueryOptions(t *testing.T) {
func TestAPI_StatusLeader_WrongDC(t *testing.T) { func TestAPI_StatusLeader_WrongDC(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
c, s := makeClient(t) c, s := makeClient(t)
defer s.Stop() defer s.Stop()
@ -96,13 +95,12 @@ func TestAPI_StatusLeader_WrongDC(t *testing.T) {
} }
_, err := status.LeaderWithQueryOptions(&opts) _, err := status.LeaderWithQueryOptions(&opts)
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "No path to datacenter") require.Contains(t, err.Error(), "No path to datacenter")
} }
func TestAPI_StatusPeers_WrongDC(t *testing.T) { func TestAPI_StatusPeers_WrongDC(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
c, s := makeClient(t) c, s := makeClient(t)
defer s.Stop() defer s.Stop()
@ -114,6 +112,6 @@ func TestAPI_StatusPeers_WrongDC(t *testing.T) {
Datacenter: "wrong_dc1", Datacenter: "wrong_dc1",
} }
_, err := status.PeersWithQueryOptions(&opts) _, err := status.PeersWithQueryOptions(&opts)
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), "No path to datacenter") require.Contains(t, err.Error(), "No path to datacenter")
} }

View File

@ -25,7 +25,6 @@ func TestAgentTokensCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -50,7 +49,7 @@ func TestAgentTokensCommand(t *testing.T) {
&api.ACLToken{Description: "test"}, &api.ACLToken{Description: "test"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
// default token // default token
{ {
@ -61,8 +60,8 @@ func TestAgentTokensCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
} }
// agent token // agent token
@ -74,8 +73,8 @@ func TestAgentTokensCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
} }
// recovery token // recovery token
@ -87,8 +86,8 @@ func TestAgentTokensCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
} }
// replication token // replication token
@ -100,7 +99,7 @@ func TestAgentTokensCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
} }
} }

View File

@ -27,7 +27,6 @@ func TestBootstrapCommand_Pretty(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -46,11 +45,11 @@ func TestBootstrapCommand_Pretty(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
assert.Contains(output, "Bootstrap Token") assert.Contains(t, output, "Bootstrap Token")
assert.Contains(output, structs.ACLPolicyGlobalManagementID) assert.Contains(t, output, structs.ACLPolicyGlobalManagementID)
} }
func TestBootstrapCommand_JSON(t *testing.T) { func TestBootstrapCommand_JSON(t *testing.T) {
@ -59,7 +58,6 @@ func TestBootstrapCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -79,11 +77,11 @@ func TestBootstrapCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
assert.Contains(output, "Bootstrap Token") assert.Contains(t, output, "Bootstrap Token")
assert.Contains(output, structs.ACLPolicyGlobalManagementID) assert.Contains(t, output, structs.ACLPolicyGlobalManagementID)
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
err := json.Unmarshal([]byte(output), &jsonOutput) err := json.Unmarshal([]byte(output), &jsonOutput)

View File

@ -28,7 +28,6 @@ func TestPolicyCreateCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
testDir := testutil.TempDir(t, "acl") testDir := testutil.TempDir(t, "acl")
@ -49,7 +48,7 @@ func TestPolicyCreateCommand(t *testing.T) {
rules := []byte("service \"\" { policy = \"write\" }") rules := []byte("service \"\" { policy = \"write\" }")
err := ioutil.WriteFile(testDir+"/rules.hcl", rules, 0644) err := ioutil.WriteFile(testDir+"/rules.hcl", rules, 0644)
require.NoError(err) require.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -59,8 +58,8 @@ func TestPolicyCreateCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Equal(code, 0) require.Equal(t, code, 0)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
} }
func TestPolicyCreateCommand_JSON(t *testing.T) { func TestPolicyCreateCommand_JSON(t *testing.T) {
@ -69,7 +68,6 @@ func TestPolicyCreateCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
testDir := testutil.TempDir(t, "acl") testDir := testutil.TempDir(t, "acl")
@ -90,7 +88,7 @@ func TestPolicyCreateCommand_JSON(t *testing.T) {
rules := []byte("service \"\" { policy = \"write\" }") rules := []byte("service \"\" { policy = \"write\" }")
err := ioutil.WriteFile(testDir+"/rules.hcl", rules, 0644) err := ioutil.WriteFile(testDir+"/rules.hcl", rules, 0644)
require.NoError(err) require.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -101,8 +99,8 @@ func TestPolicyCreateCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Equal(code, 0) require.Equal(t, code, 0)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
err = json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput) err = json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput)

View File

@ -26,7 +26,6 @@ func TestPolicyDeleteCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -50,7 +49,7 @@ func TestPolicyDeleteCommand(t *testing.T) {
&api.ACLPolicy{Name: "test-policy"}, &api.ACLPolicy{Name: "test-policy"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -59,16 +58,16 @@ func TestPolicyDeleteCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
assert.Contains(output, fmt.Sprintf("deleted successfully")) assert.Contains(t, output, fmt.Sprintf("deleted successfully"))
assert.Contains(output, policy.ID) assert.Contains(t, output, policy.ID)
_, _, err = client.ACL().PolicyRead( _, _, err = client.ACL().PolicyRead(
policy.ID, policy.ID,
&api.QueryOptions{Token: "root"}, &api.QueryOptions{Token: "root"},
) )
assert.EqualError(err, "Unexpected response code: 403 (ACL not found)") assert.EqualError(t, err, "Unexpected response code: 403 (ACL not found)")
} }

View File

@ -27,7 +27,6 @@ func TestPolicyListCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -57,7 +56,7 @@ func TestPolicyListCommand(t *testing.T) {
) )
policyIDs = append(policyIDs, policy.ID) policyIDs = append(policyIDs, policy.ID)
assert.NoError(err) assert.NoError(t, err)
} }
args := []string{ args := []string{
@ -66,13 +65,13 @@ func TestPolicyListCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
for i, v := range policyIDs { for i, v := range policyIDs {
assert.Contains(output, fmt.Sprintf("test-policy-%d", i)) assert.Contains(t, output, fmt.Sprintf("test-policy-%d", i))
assert.Contains(output, v) assert.Contains(t, output, v)
} }
} }
@ -82,7 +81,6 @@ func TestPolicyListCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -112,7 +110,7 @@ func TestPolicyListCommand_JSON(t *testing.T) {
) )
policyIDs = append(policyIDs, policy.ID) policyIDs = append(policyIDs, policy.ID)
assert.NoError(err) assert.NoError(t, err)
} }
args := []string{ args := []string{
@ -122,16 +120,16 @@ func TestPolicyListCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
for i, v := range policyIDs { for i, v := range policyIDs {
assert.Contains(output, fmt.Sprintf("test-policy-%d", i)) assert.Contains(t, output, fmt.Sprintf("test-policy-%d", i))
assert.Contains(output, v) assert.Contains(t, output, v)
} }
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
err := json.Unmarshal([]byte(output), &jsonOutput) err := json.Unmarshal([]byte(output), &jsonOutput)
assert.NoError(err) assert.NoError(t, err)
} }

View File

@ -27,7 +27,6 @@ func TestPolicyReadCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -51,7 +50,7 @@ func TestPolicyReadCommand(t *testing.T) {
&api.ACLPolicy{Name: "test-policy"}, &api.ACLPolicy{Name: "test-policy"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
// Test querying by id field // Test querying by id field
args := []string{ args := []string{
@ -61,12 +60,12 @@ func TestPolicyReadCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
assert.Contains(output, fmt.Sprintf("test-policy")) assert.Contains(t, output, fmt.Sprintf("test-policy"))
assert.Contains(output, policy.ID) assert.Contains(t, output, policy.ID)
// Test querying by name field // Test querying by name field
argsName := []string{ argsName := []string{
@ -77,12 +76,12 @@ func TestPolicyReadCommand(t *testing.T) {
cmd = New(ui) cmd = New(ui)
code = cmd.Run(argsName) code = cmd.Run(argsName)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output = ui.OutputWriter.String() output = ui.OutputWriter.String()
assert.Contains(output, fmt.Sprintf("test-policy")) assert.Contains(t, output, fmt.Sprintf("test-policy"))
assert.Contains(output, policy.ID) assert.Contains(t, output, policy.ID)
} }
func TestPolicyReadCommand_JSON(t *testing.T) { func TestPolicyReadCommand_JSON(t *testing.T) {
@ -91,7 +90,6 @@ func TestPolicyReadCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -115,7 +113,7 @@ func TestPolicyReadCommand_JSON(t *testing.T) {
&api.ACLPolicy{Name: "test-policy"}, &api.ACLPolicy{Name: "test-policy"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -125,14 +123,14 @@ func TestPolicyReadCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
assert.Contains(output, fmt.Sprintf("test-policy")) assert.Contains(t, output, fmt.Sprintf("test-policy"))
assert.Contains(output, policy.ID) assert.Contains(t, output, policy.ID)
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
err = json.Unmarshal([]byte(output), &jsonOutput) err = json.Unmarshal([]byte(output), &jsonOutput)
assert.NoError(err) assert.NoError(t, err)
} }

View File

@ -28,7 +28,6 @@ func TestPolicyUpdateCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
testDir := testutil.TempDir(t, "acl") testDir := testutil.TempDir(t, "acl")
@ -49,7 +48,7 @@ func TestPolicyUpdateCommand(t *testing.T) {
rules := []byte("service \"\" { policy = \"write\" }") rules := []byte("service \"\" { policy = \"write\" }")
err := ioutil.WriteFile(testDir+"/rules.hcl", rules, 0644) err := ioutil.WriteFile(testDir+"/rules.hcl", rules, 0644)
assert.NoError(err) assert.NoError(t, err)
// Create a policy // Create a policy
client := a.Client() client := a.Client()
@ -58,7 +57,7 @@ func TestPolicyUpdateCommand(t *testing.T) {
&api.ACLPolicy{Name: "test-policy"}, &api.ACLPolicy{Name: "test-policy"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -69,8 +68,8 @@ func TestPolicyUpdateCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
} }
func TestPolicyUpdateCommand_JSON(t *testing.T) { func TestPolicyUpdateCommand_JSON(t *testing.T) {
@ -79,7 +78,6 @@ func TestPolicyUpdateCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
testDir := testutil.TempDir(t, "acl") testDir := testutil.TempDir(t, "acl")
@ -100,7 +98,7 @@ func TestPolicyUpdateCommand_JSON(t *testing.T) {
rules := []byte("service \"\" { policy = \"write\" }") rules := []byte("service \"\" { policy = \"write\" }")
err := ioutil.WriteFile(testDir+"/rules.hcl", rules, 0644) err := ioutil.WriteFile(testDir+"/rules.hcl", rules, 0644)
assert.NoError(err) assert.NoError(t, err)
// Create a policy // Create a policy
client := a.Client() client := a.Client()
@ -109,7 +107,7 @@ func TestPolicyUpdateCommand_JSON(t *testing.T) {
&api.ACLPolicy{Name: "test-policy"}, &api.ACLPolicy{Name: "test-policy"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -121,10 +119,10 @@ func TestPolicyUpdateCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
err = json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput) err = json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput)
assert.NoError(err) assert.NoError(t, err)
} }

View File

@ -28,7 +28,6 @@ func TestRoleListCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -61,7 +60,7 @@ func TestRoleListCommand(t *testing.T) {
) )
roleIDs = append(roleIDs, role.ID) roleIDs = append(roleIDs, role.ID)
require.NoError(err) require.NoError(t, err)
} }
args := []string{ args := []string{
@ -70,13 +69,13 @@ func TestRoleListCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Equal(code, 0) require.Equal(t, code, 0)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
for i, v := range roleIDs { for i, v := range roleIDs {
require.Contains(output, fmt.Sprintf("test-role-%d", i)) require.Contains(t, output, fmt.Sprintf("test-role-%d", i))
require.Contains(output, v) require.Contains(t, output, v)
} }
} }
@ -86,7 +85,6 @@ func TestRoleListCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -119,7 +117,7 @@ func TestRoleListCommand_JSON(t *testing.T) {
) )
roleIDs = append(roleIDs, role.ID) roleIDs = append(roleIDs, role.ID)
require.NoError(err) require.NoError(t, err)
} }
args := []string{ args := []string{
@ -129,13 +127,13 @@ func TestRoleListCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Equal(code, 0) require.Equal(t, code, 0)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
for i, v := range roleIDs { for i, v := range roleIDs {
require.Contains(output, fmt.Sprintf("test-role-%d", i)) require.Contains(t, output, fmt.Sprintf("test-role-%d", i))
require.Contains(output, v) require.Contains(t, output, v)
} }
var jsonOutput json.RawMessage var jsonOutput json.RawMessage

View File

@ -65,7 +65,6 @@ func TestTokenCloneCommand_Pretty(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -86,14 +85,14 @@ func TestTokenCloneCommand_Pretty(t *testing.T) {
&api.ACLPolicy{Name: "test-policy"}, &api.ACLPolicy{Name: "test-policy"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
require.NoError(err) require.NoError(t, err)
// create a token // create a token
token, _, err := client.ACL().TokenCreate( token, _, err := client.ACL().TokenCreate(
&api.ACLToken{Description: "test", Policies: []*api.ACLTokenPolicyLink{{Name: "test-policy"}}}, &api.ACLToken{Description: "test", Policies: []*api.ACLTokenPolicyLink{{Name: "test-policy"}}},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
require.NoError(err) require.NoError(t, err)
// clone with description // clone with description
t.Run("Description", func(t *testing.T) { t.Run("Description", func(t *testing.T) {
@ -108,27 +107,27 @@ func TestTokenCloneCommand_Pretty(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
require.Equal(code, 0) require.Equal(t, code, 0)
cloned := parseCloneOutput(t, ui.OutputWriter.String()) cloned := parseCloneOutput(t, ui.OutputWriter.String())
require.Equal("test cloned", cloned.Description) require.Equal(t, "test cloned", cloned.Description)
require.Len(cloned.Policies, 1) require.Len(t, cloned.Policies, 1)
apiToken, _, err := client.ACL().TokenRead( apiToken, _, err := client.ACL().TokenRead(
cloned.AccessorID, cloned.AccessorID,
&api.QueryOptions{Token: "root"}, &api.QueryOptions{Token: "root"},
) )
require.NoError(err) require.NoError(t, err)
require.NotNil(apiToken) require.NotNil(t, apiToken)
require.Equal(cloned.AccessorID, apiToken.AccessorID) require.Equal(t, cloned.AccessorID, apiToken.AccessorID)
require.Equal(cloned.SecretID, apiToken.SecretID) require.Equal(t, cloned.SecretID, apiToken.SecretID)
require.Equal(cloned.Description, apiToken.Description) require.Equal(t, cloned.Description, apiToken.Description)
require.Equal(cloned.Local, apiToken.Local) require.Equal(t, cloned.Local, apiToken.Local)
require.Equal(cloned.Policies, apiToken.Policies) require.Equal(t, cloned.Policies, apiToken.Policies)
}) })
// clone without description // clone without description
@ -143,27 +142,27 @@ func TestTokenCloneCommand_Pretty(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Equal(code, 0) require.Equal(t, code, 0)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
cloned := parseCloneOutput(t, ui.OutputWriter.String()) cloned := parseCloneOutput(t, ui.OutputWriter.String())
require.Equal("test", cloned.Description) require.Equal(t, "test", cloned.Description)
require.Len(cloned.Policies, 1) require.Len(t, cloned.Policies, 1)
apiToken, _, err := client.ACL().TokenRead( apiToken, _, err := client.ACL().TokenRead(
cloned.AccessorID, cloned.AccessorID,
&api.QueryOptions{Token: "root"}, &api.QueryOptions{Token: "root"},
) )
require.NoError(err) require.NoError(t, err)
require.NotNil(apiToken) require.NotNil(t, apiToken)
require.Equal(cloned.AccessorID, apiToken.AccessorID) require.Equal(t, cloned.AccessorID, apiToken.AccessorID)
require.Equal(cloned.SecretID, apiToken.SecretID) require.Equal(t, cloned.SecretID, apiToken.SecretID)
require.Equal(cloned.Description, apiToken.Description) require.Equal(t, cloned.Description, apiToken.Description)
require.Equal(cloned.Local, apiToken.Local) require.Equal(t, cloned.Local, apiToken.Local)
require.Equal(cloned.Policies, apiToken.Policies) require.Equal(t, cloned.Policies, apiToken.Policies)
}) })
} }
@ -173,7 +172,6 @@ func TestTokenCloneCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -194,14 +192,14 @@ func TestTokenCloneCommand_JSON(t *testing.T) {
&api.ACLPolicy{Name: "test-policy"}, &api.ACLPolicy{Name: "test-policy"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
require.NoError(err) require.NoError(t, err)
// create a token // create a token
token, _, err := client.ACL().TokenCreate( token, _, err := client.ACL().TokenCreate(
&api.ACLToken{Description: "test", Policies: []*api.ACLTokenPolicyLink{{Name: "test-policy"}}}, &api.ACLToken{Description: "test", Policies: []*api.ACLTokenPolicyLink{{Name: "test-policy"}}},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
require.NoError(err) require.NoError(t, err)
// clone with description // clone with description
t.Run("Description", func(t *testing.T) { t.Run("Description", func(t *testing.T) {
@ -217,8 +215,8 @@ func TestTokenCloneCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
require.Equal(code, 0) require.Equal(t, code, 0)
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
@ -239,8 +237,8 @@ func TestTokenCloneCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
require.Equal(code, 0) require.Equal(t, code, 0)
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
var jsonOutput json.RawMessage var jsonOutput json.RawMessage

View File

@ -124,7 +124,6 @@ func TestTokenCreateCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -148,7 +147,7 @@ func TestTokenCreateCommand_JSON(t *testing.T) {
&api.ACLPolicy{Name: "test-policy"}, &api.ACLPolicy{Name: "test-policy"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
require.NoError(err) require.NoError(t, err)
// create with policy by name // create with policy by name
{ {
@ -161,11 +160,11 @@ func TestTokenCreateCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
require.Equal(code, 0) require.Equal(t, code, 0)
require.Empty(ui.ErrorWriter.String()) require.Empty(t, ui.ErrorWriter.String())
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
err = json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput) err = json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput)
require.NoError(err, "token unmarshalling error") require.NoError(t, err, "token unmarshalling error")
} }
} }

View File

@ -26,7 +26,6 @@ func TestTokenDeleteCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -50,7 +49,7 @@ func TestTokenDeleteCommand(t *testing.T) {
&api.ACLToken{Description: "test"}, &api.ACLToken{Description: "test"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -59,16 +58,16 @@ func TestTokenDeleteCommand(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
assert.Contains(output, fmt.Sprintf("deleted successfully")) assert.Contains(t, output, fmt.Sprintf("deleted successfully"))
assert.Contains(output, token.AccessorID) assert.Contains(t, output, token.AccessorID)
_, _, err = client.ACL().TokenRead( _, _, err = client.ACL().TokenRead(
token.AccessorID, token.AccessorID,
&api.QueryOptions{Token: "root"}, &api.QueryOptions{Token: "root"},
) )
assert.EqualError(err, "Unexpected response code: 403 (ACL not found)") assert.EqualError(t, err, "Unexpected response code: 403 (ACL not found)")
} }

View File

@ -28,7 +28,6 @@ func TestTokenListCommand_Pretty(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -58,7 +57,7 @@ func TestTokenListCommand_Pretty(t *testing.T) {
) )
tokenIds = append(tokenIds, token.AccessorID) tokenIds = append(tokenIds, token.AccessorID)
assert.NoError(err) assert.NoError(t, err)
} }
args := []string{ args := []string{
@ -67,13 +66,13 @@ func TestTokenListCommand_Pretty(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
for i, v := range tokenIds { for i, v := range tokenIds {
assert.Contains(output, fmt.Sprintf("test token %d", i)) assert.Contains(t, output, fmt.Sprintf("test token %d", i))
assert.Contains(output, v) assert.Contains(t, output, v)
} }
} }
@ -83,7 +82,6 @@ func TestTokenListCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -113,7 +111,7 @@ func TestTokenListCommand_JSON(t *testing.T) {
) )
tokenIds = append(tokenIds, token.AccessorID) tokenIds = append(tokenIds, token.AccessorID)
assert.NoError(err) assert.NoError(t, err)
} }
args := []string{ args := []string{
@ -123,8 +121,8 @@ func TestTokenListCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
var jsonOutput []api.ACLTokenListEntry var jsonOutput []api.ACLTokenListEntry
err := json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput) err := json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput)

View File

@ -28,7 +28,6 @@ func TestTokenReadCommand_Pretty(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -52,7 +51,7 @@ func TestTokenReadCommand_Pretty(t *testing.T) {
&api.ACLToken{Description: "test"}, &api.ACLToken{Description: "test"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -61,13 +60,13 @@ func TestTokenReadCommand_Pretty(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
output := ui.OutputWriter.String() output := ui.OutputWriter.String()
assert.Contains(output, fmt.Sprintf("test")) assert.Contains(t, output, fmt.Sprintf("test"))
assert.Contains(output, token.AccessorID) assert.Contains(t, output, token.AccessorID)
assert.Contains(output, token.SecretID) assert.Contains(t, output, token.SecretID)
} }
func TestTokenReadCommand_JSON(t *testing.T) { func TestTokenReadCommand_JSON(t *testing.T) {
@ -76,7 +75,6 @@ func TestTokenReadCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -100,7 +98,7 @@ func TestTokenReadCommand_JSON(t *testing.T) {
&api.ACLToken{Description: "test"}, &api.ACLToken{Description: "test"},
&api.WriteOptions{Token: "root"}, &api.WriteOptions{Token: "root"},
) )
assert.NoError(err) assert.NoError(t, err)
args := []string{ args := []string{
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
@ -110,8 +108,8 @@ func TestTokenReadCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
err = json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput) err = json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput)

View File

@ -157,7 +157,6 @@ func TestTokenUpdateCommand_JSON(t *testing.T) {
} }
t.Parallel() t.Parallel()
assert := assert.New(t)
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
primary_datacenter = "dc1" primary_datacenter = "dc1"
@ -201,8 +200,8 @@ func TestTokenUpdateCommand_JSON(t *testing.T) {
} }
code := cmd.Run(args) code := cmd.Run(args)
assert.Equal(code, 0) assert.Equal(t, code, 0)
assert.Empty(ui.ErrorWriter.String()) assert.Empty(t, ui.ErrorWriter.String())
var jsonOutput json.RawMessage var jsonOutput json.RawMessage
err := json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput) err := json.Unmarshal([]byte(ui.OutputWriter.String()), &jsonOutput)

View File

@ -28,7 +28,6 @@ func TestConnectCASetConfigCommand(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
@ -49,10 +48,10 @@ func TestConnectCASetConfigCommand(t *testing.T) {
Datacenter: "dc1", Datacenter: "dc1",
} }
var reply structs.CAConfiguration var reply structs.CAConfiguration
require.NoError(a.RPC("ConnectCA.ConfigurationGet", &req, &reply)) require.NoError(t, a.RPC("ConnectCA.ConfigurationGet", &req, &reply))
require.Equal("consul", reply.Provider) require.Equal(t, "consul", reply.Provider)
parsed, err := ca.ParseConsulCAConfig(reply.Config) parsed, err := ca.ParseConsulCAConfig(reply.Config)
require.NoError(err) require.NoError(t, err)
require.Equal(288*time.Hour, parsed.IntermediateCertTTL) require.Equal(t, 288*time.Hour, parsed.IntermediateCertTTL)
} }

View File

@ -850,14 +850,13 @@ func TestGenerateConfig(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
testDir := testutil.TempDir(t, "envoytest") testDir := testutil.TempDir(t, "envoytest")
if len(tc.Files) > 0 { if len(tc.Files) > 0 {
for fn, fv := range tc.Files { for fn, fv := range tc.Files {
fullname := filepath.Join(testDir, fn) fullname := filepath.Join(testDir, fn)
require.NoError(ioutil.WriteFile(fullname, []byte(fv), 0600)) require.NoError(t, ioutil.WriteFile(fullname, []byte(fv), 0600))
} }
} }
@ -876,7 +875,7 @@ func TestGenerateConfig(t *testing.T) {
defer testSetAndResetEnv(t, myEnv)() defer testSetAndResetEnv(t, myEnv)()
client, err := api.NewClient(&api.Config{Address: srv.URL, TLSConfig: api.TLSConfig{InsecureSkipVerify: true}}) client, err := api.NewClient(&api.Config{Address: srv.URL, TLSConfig: api.TLSConfig{InsecureSkipVerify: true}})
require.NoError(err) require.NoError(t, err)
ui := cli.NewMockUi() ui := cli.NewMockUi()
c := New(ui) c := New(ui)
@ -887,21 +886,21 @@ func TestGenerateConfig(t *testing.T) {
myFlags := copyAndReplaceAll(tc.Flags, "@@TEMPDIR@@", testDirPrefix) myFlags := copyAndReplaceAll(tc.Flags, "@@TEMPDIR@@", testDirPrefix)
args := append([]string{"-bootstrap"}, myFlags...) args := append([]string{"-bootstrap"}, myFlags...)
require.NoError(c.flags.Parse(args)) require.NoError(t, c.flags.Parse(args))
code := c.run(c.flags.Args()) code := c.run(c.flags.Args())
if tc.WantErr == "" { if tc.WantErr == "" {
require.Equal(0, code, ui.ErrorWriter.String()) require.Equal(t, 0, code, ui.ErrorWriter.String())
} else { } else {
require.Equal(1, code, ui.ErrorWriter.String()) require.Equal(t, 1, code, ui.ErrorWriter.String())
require.Contains(ui.ErrorWriter.String(), tc.WantErr) require.Contains(t, ui.ErrorWriter.String(), tc.WantErr)
return return
} }
// Verify we handled the env and flags right first to get correct template // Verify we handled the env and flags right first to get correct template
// args. // args.
got, err := c.templateArgs() got, err := c.templateArgs()
require.NoError(err) // Error cases should have returned above require.NoError(t, err) // Error cases should have returned above
require.Equal(&tc.WantArgs, got) require.Equal(t, &tc.WantArgs, got)
actual := ui.OutputWriter.Bytes() actual := ui.OutputWriter.Bytes()
@ -912,8 +911,8 @@ func TestGenerateConfig(t *testing.T) {
} }
expected, err := ioutil.ReadFile(golden) expected, err := ioutil.ReadFile(golden)
require.NoError(err) require.NoError(t, err)
require.Equal(string(expected), string(actual)) require.Equal(t, string(expected), string(actual))
}) })
} }
} }

View File

@ -105,7 +105,6 @@ func TestExecEnvoy(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
args := append([]string{"exec-fake-envoy"}, tc.Args...) args := append([]string{"exec-fake-envoy"}, tc.Args...)
cmd, destroy := helperProcess(args...) cmd, destroy := helperProcess(args...)
@ -113,10 +112,10 @@ func TestExecEnvoy(t *testing.T) {
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
outBytes, err := cmd.Output() outBytes, err := cmd.Output()
require.NoError(err) require.NoError(t, err)
var got FakeEnvoyExecData var got FakeEnvoyExecData
require.NoError(json.Unmarshal(outBytes, &got)) require.NoError(t, json.Unmarshal(outBytes, &got))
expectConfigData := fakeEnvoyTestData expectConfigData := fakeEnvoyTestData
@ -126,11 +125,11 @@ func TestExecEnvoy(t *testing.T) {
"{{ got.ConfigPath }}", got.ConfigPath, 1) "{{ got.ConfigPath }}", got.ConfigPath, 1)
} }
require.Equal(tc.WantArgs, got.Args) require.Equal(t, tc.WantArgs, got.Args)
require.Equal(expectConfigData, got.ConfigData) require.Equal(t, expectConfigData, got.ConfigData)
// Sanity check the config path in a non-brittle way since we used it to // Sanity check the config path in a non-brittle way since we used it to
// generate expectation for the args. // generate expectation for the args.
require.Regexp(`-bootstrap.json$`, got.ConfigPath) require.Regexp(t, `-bootstrap.json$`, got.ConfigPath)
}) })
} }
} }

View File

@ -16,7 +16,6 @@ func TestConnectExpose(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
client := a.Client() client := a.Client()
defer a.Shutdown() defer a.Shutdown()
@ -41,7 +40,7 @@ func TestConnectExpose(t *testing.T) {
// Make sure the config entry and intention have been created. // Make sure the config entry and intention have been created.
entry, _, err := client.ConfigEntries().Get(api.IngressGateway, "ingress", nil) entry, _, err := client.ConfigEntries().Get(api.IngressGateway, "ingress", nil)
require.NoError(err) require.NoError(t, err)
ns := entry.(*api.IngressGatewayConfigEntry).Namespace ns := entry.(*api.IngressGatewayConfigEntry).Namespace
ap := entry.(*api.IngressGatewayConfigEntry).Partition ap := entry.(*api.IngressGatewayConfigEntry).Partition
expected := &api.IngressGatewayConfigEntry{ expected := &api.IngressGatewayConfigEntry{
@ -64,13 +63,13 @@ func TestConnectExpose(t *testing.T) {
} }
expected.CreateIndex = entry.GetCreateIndex() expected.CreateIndex = entry.GetCreateIndex()
expected.ModifyIndex = entry.GetModifyIndex() expected.ModifyIndex = entry.GetModifyIndex()
require.Equal(expected, entry) require.Equal(t, expected, entry)
ixns, _, err := client.Connect().Intentions(nil) ixns, _, err := client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("ingress", ixns[0].SourceName) require.Equal(t, "ingress", ixns[0].SourceName)
require.Equal("foo", ixns[0].DestinationName) require.Equal(t, "foo", ixns[0].DestinationName)
// Run the command again with a different port, make sure the config entry // Run the command again with a different port, make sure the config entry
// is updated while intentions are unmodified. // is updated while intentions are unmodified.
@ -104,15 +103,15 @@ func TestConnectExpose(t *testing.T) {
// Make sure the config entry/intention weren't affected. // Make sure the config entry/intention weren't affected.
entry, _, err = client.ConfigEntries().Get(api.IngressGateway, "ingress", nil) entry, _, err = client.ConfigEntries().Get(api.IngressGateway, "ingress", nil)
require.NoError(err) require.NoError(t, err)
expected.ModifyIndex = entry.GetModifyIndex() expected.ModifyIndex = entry.GetModifyIndex()
require.Equal(expected, entry) require.Equal(t, expected, entry)
ixns, _, err = client.Connect().Intentions(nil) ixns, _, err = client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("ingress", ixns[0].SourceName) require.Equal(t, "ingress", ixns[0].SourceName)
require.Equal("foo", ixns[0].DestinationName) require.Equal(t, "foo", ixns[0].DestinationName)
} }
// Run the command again with a conflicting protocol, should exit with an error and // Run the command again with a conflicting protocol, should exit with an error and
@ -132,18 +131,18 @@ func TestConnectExpose(t *testing.T) {
if code != 1 { if code != 1 {
t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String()) t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
} }
require.Contains(ui.ErrorWriter.String(), `conflicting protocol "tcp"`) require.Contains(t, ui.ErrorWriter.String(), `conflicting protocol "tcp"`)
// Make sure the config entry/intention weren't affected. // Make sure the config entry/intention weren't affected.
entry, _, err = client.ConfigEntries().Get(api.IngressGateway, "ingress", nil) entry, _, err = client.ConfigEntries().Get(api.IngressGateway, "ingress", nil)
require.NoError(err) require.NoError(t, err)
require.Equal(expected, entry) require.Equal(t, expected, entry)
ixns, _, err = client.Connect().Intentions(nil) ixns, _, err = client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("ingress", ixns[0].SourceName) require.Equal(t, "ingress", ixns[0].SourceName)
require.Equal("foo", ixns[0].DestinationName) require.Equal(t, "foo", ixns[0].DestinationName)
} }
} }
@ -153,7 +152,6 @@ func TestConnectExpose_invalidFlags(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
@ -169,7 +167,7 @@ func TestConnectExpose_invalidFlags(t *testing.T) {
if code != 1 { if code != 1 {
t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String()) t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
} }
require.Contains(ui.ErrorWriter.String(), "A service name must be given") require.Contains(t, ui.ErrorWriter.String(), "A service name must be given")
}) })
t.Run("missing gateway", func(t *testing.T) { t.Run("missing gateway", func(t *testing.T) {
ui := cli.NewMockUi() ui := cli.NewMockUi()
@ -183,7 +181,7 @@ func TestConnectExpose_invalidFlags(t *testing.T) {
if code != 1 { if code != 1 {
t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String()) t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
} }
require.Contains(ui.ErrorWriter.String(), "An ingress gateway service must be given") require.Contains(t, ui.ErrorWriter.String(), "An ingress gateway service must be given")
}) })
t.Run("missing port", func(t *testing.T) { t.Run("missing port", func(t *testing.T) {
ui := cli.NewMockUi() ui := cli.NewMockUi()
@ -198,7 +196,7 @@ func TestConnectExpose_invalidFlags(t *testing.T) {
if code != 1 { if code != 1 {
t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String()) t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
} }
require.Contains(ui.ErrorWriter.String(), "A port must be provided") require.Contains(t, ui.ErrorWriter.String(), "A port must be provided")
}) })
} }
@ -208,7 +206,6 @@ func TestConnectExpose_existingConfig(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
client := a.Client() client := a.Client()
defer a.Shutdown() defer a.Shutdown()
@ -220,7 +217,7 @@ func TestConnectExpose_existingConfig(t *testing.T) {
Name: service, Name: service,
Protocol: "http", Protocol: "http",
}, nil) }, nil)
require.NoError(err) require.NoError(t, err)
} }
// Create an existing ingress config entry with some services. // Create an existing ingress config entry with some services.
@ -249,7 +246,7 @@ func TestConnectExpose_existingConfig(t *testing.T) {
}, },
} }
_, _, err := client.ConfigEntries().Set(ingressConf, nil) _, _, err := client.ConfigEntries().Set(ingressConf, nil)
require.NoError(err) require.NoError(t, err)
// Add a service on a new port. // Add a service on a new port.
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -271,7 +268,7 @@ func TestConnectExpose_existingConfig(t *testing.T) {
// Make sure the ingress config was updated and existing services preserved. // Make sure the ingress config was updated and existing services preserved.
entry, _, err := client.ConfigEntries().Get(api.IngressGateway, "ingress", nil) entry, _, err := client.ConfigEntries().Get(api.IngressGateway, "ingress", nil)
require.NoError(err) require.NoError(t, err)
entryConf := entry.(*api.IngressGatewayConfigEntry) entryConf := entry.(*api.IngressGatewayConfigEntry)
ingressConf.Listeners = append(ingressConf.Listeners, api.IngressListener{ ingressConf.Listeners = append(ingressConf.Listeners, api.IngressListener{
@ -290,7 +287,7 @@ func TestConnectExpose_existingConfig(t *testing.T) {
} }
ingressConf.CreateIndex = entry.GetCreateIndex() ingressConf.CreateIndex = entry.GetCreateIndex()
ingressConf.ModifyIndex = entry.GetModifyIndex() ingressConf.ModifyIndex = entry.GetModifyIndex()
require.Equal(ingressConf, entry) require.Equal(t, ingressConf, entry)
} }
// Add an service on a port shared with an existing listener. // Add an service on a port shared with an existing listener.
@ -315,7 +312,7 @@ func TestConnectExpose_existingConfig(t *testing.T) {
// Make sure the ingress config was updated and existing services preserved. // Make sure the ingress config was updated and existing services preserved.
entry, _, err := client.ConfigEntries().Get(api.IngressGateway, "ingress", nil) entry, _, err := client.ConfigEntries().Get(api.IngressGateway, "ingress", nil)
require.NoError(err) require.NoError(t, err)
entryConf := entry.(*api.IngressGatewayConfigEntry) entryConf := entry.(*api.IngressGatewayConfigEntry)
ingressConf.Listeners[1].Services = append(ingressConf.Listeners[1].Services, api.IngressService{ ingressConf.Listeners[1].Services = append(ingressConf.Listeners[1].Services, api.IngressService{
@ -326,7 +323,7 @@ func TestConnectExpose_existingConfig(t *testing.T) {
}) })
ingressConf.CreateIndex = entry.GetCreateIndex() ingressConf.CreateIndex = entry.GetCreateIndex()
ingressConf.ModifyIndex = entry.GetModifyIndex() ingressConf.ModifyIndex = entry.GetModifyIndex()
require.Equal(ingressConf, entry) require.Equal(t, ingressConf, entry)
} }
// Update the bar service and add a custom host. // Update the bar service and add a custom host.
@ -350,11 +347,11 @@ func TestConnectExpose_existingConfig(t *testing.T) {
// Make sure the ingress config was updated and existing services preserved. // Make sure the ingress config was updated and existing services preserved.
entry, _, err := client.ConfigEntries().Get(api.IngressGateway, "ingress", nil) entry, _, err := client.ConfigEntries().Get(api.IngressGateway, "ingress", nil)
require.NoError(err) require.NoError(t, err)
ingressConf.Listeners[1].Services[0].Hosts = []string{"bar.com"} ingressConf.Listeners[1].Services[0].Hosts = []string{"bar.com"}
ingressConf.CreateIndex = entry.GetCreateIndex() ingressConf.CreateIndex = entry.GetCreateIndex()
ingressConf.ModifyIndex = entry.GetModifyIndex() ingressConf.ModifyIndex = entry.GetModifyIndex()
require.Equal(ingressConf, entry) require.Equal(t, ingressConf, entry)
} }
} }

View File

@ -103,7 +103,6 @@ func TestFlagUpstreams(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
var actual map[string]proxy.UpstreamConfig var actual map[string]proxy.UpstreamConfig
f := (*FlagUpstreams)(&actual) f := (*FlagUpstreams)(&actual)
@ -115,12 +114,12 @@ func TestFlagUpstreams(t *testing.T) {
// test failures confusing but it shouldn't be too bad. // test failures confusing but it shouldn't be too bad.
} }
if tc.Error != "" { if tc.Error != "" {
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), tc.Error) require.Contains(t, err.Error(), tc.Error)
return return
} }
require.Equal(tc.Expected, actual) require.Equal(t, tc.Expected, actual)
}) })
} }
} }

View File

@ -114,7 +114,6 @@ func TestCommandConfigWatcher(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
// Register a few services with 0, 1 and 2 sidecars // Register a few services with 0, 1 and 2 sidecars
a := agent.NewTestAgent(t, ` a := agent.NewTestAgent(t, `
@ -160,16 +159,16 @@ func TestCommandConfigWatcher(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
}, tc.Flags...)) }, tc.Flags...))
if tc.WantErr == "" { if tc.WantErr == "" {
require.Equal(0, code, ui.ErrorWriter.String()) require.Equal(t, 0, code, ui.ErrorWriter.String())
} else { } else {
require.Equal(1, code, ui.ErrorWriter.String()) require.Equal(t, 1, code, ui.ErrorWriter.String())
require.Contains(ui.ErrorWriter.String(), tc.WantErr) require.Contains(t, ui.ErrorWriter.String(), tc.WantErr)
return return
} }
// Get the configuration watcher // Get the configuration watcher
cw, err := c.configWatcher(client) cw, err := c.configWatcher(client)
require.NoError(err) require.NoError(t, err)
if tc.Test != nil { if tc.Test != nil {
tc.Test(t, testConfig(t, cw)) tc.Test(t, testConfig(t, cw))
} }

View File

@ -18,7 +18,6 @@ func TestRegisterMonitor_good(t *testing.T) {
} }
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
@ -28,16 +27,16 @@ func TestRegisterMonitor_good(t *testing.T) {
defer m.Close() defer m.Close()
// Verify the settings // Verify the settings
require.Equal(api.ServiceKindConnectProxy, service.Kind) require.Equal(t, api.ServiceKindConnectProxy, service.Kind)
require.Equal("foo", service.Proxy.DestinationServiceName) require.Equal(t, "foo", service.Proxy.DestinationServiceName)
require.Equal("127.0.0.1", service.Address) require.Equal(t, "127.0.0.1", service.Address)
require.Equal(1234, service.Port) require.Equal(t, 1234, service.Port)
// Stop should deregister the service // Stop should deregister the service
require.NoError(m.Close()) require.NoError(t, m.Close())
services, err := client.Agent().Services() services, err := client.Agent().Services()
require.NoError(err) require.NoError(t, err)
require.NotContains(services, m.serviceID()) require.NotContains(t, services, m.serviceID())
} }
func TestRegisterMonitor_heartbeat(t *testing.T) { func TestRegisterMonitor_heartbeat(t *testing.T) {

View File

@ -8,8 +8,7 @@ import (
func TestHTTPFlagsSetToken(t *testing.T) { func TestHTTPFlagsSetToken(t *testing.T) {
var f HTTPFlags var f HTTPFlags
require := require.New(t) require.Empty(t, f.Token())
require.Empty(f.Token()) require.NoError(t, f.SetToken("foo"))
require.NoError(f.SetToken("foo")) require.Equal(t, "foo", f.Token())
require.Equal("foo", f.Token())
} }

View File

@ -46,7 +46,6 @@ func TestIntentionCheck_Validation(t *testing.T) {
for name, tc := range cases { for name, tc := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := require.New(t)
c.init() c.init()
@ -58,9 +57,9 @@ func TestIntentionCheck_Validation(t *testing.T) {
ui.OutputWriter.Reset() ui.OutputWriter.Reset()
} }
require.Equal(2, c.Run(tc.args)) require.Equal(t, 2, c.Run(tc.args))
output := ui.ErrorWriter.String() output := ui.ErrorWriter.String()
require.Contains(output, tc.output) require.Contains(t, output, tc.output)
}) })
} }
} }
@ -72,7 +71,6 @@ func TestIntentionCheck(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -87,7 +85,7 @@ func TestIntentionCheck(t *testing.T) {
DestinationName: "db", DestinationName: "db",
Action: api.IntentionActionDeny, Action: api.IntentionActionDeny,
}, nil) }, nil)
require.NoError(err) require.NoError(t, err)
} }
// Get it // Get it
@ -99,8 +97,8 @@ func TestIntentionCheck(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
"foo", "db", "foo", "db",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.OutputWriter.String(), "Allow") require.Contains(t, ui.OutputWriter.String(), "Allow")
} }
{ {
@ -111,7 +109,7 @@ func TestIntentionCheck(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
"web", "db", "web", "db",
} }
require.Equal(1, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 1, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.OutputWriter.String(), "Denied") require.Contains(t, ui.OutputWriter.String(), "Denied")
} }
} }

View File

@ -37,7 +37,6 @@ func TestIntentionCreate_Validation(t *testing.T) {
for name, tc := range cases { for name, tc := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := require.New(t)
c.init() c.init()
@ -49,9 +48,9 @@ func TestIntentionCreate_Validation(t *testing.T) {
ui.OutputWriter.Reset() ui.OutputWriter.Reset()
} }
require.Equal(1, c.Run(tc.args)) require.Equal(t, 1, c.Run(tc.args))
output := ui.ErrorWriter.String() output := ui.ErrorWriter.String()
require.Contains(output, tc.output) require.Contains(t, output, tc.output)
}) })
} }
} }
@ -63,7 +62,6 @@ func TestIntentionCreate(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -77,14 +75,14 @@ func TestIntentionCreate(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
"foo", "bar", "foo", "bar",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
ixns, _, err := client.Connect().Intentions(nil) ixns, _, err := client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("foo", ixns[0].SourceName) require.Equal(t, "foo", ixns[0].SourceName)
require.Equal("bar", ixns[0].DestinationName) require.Equal(t, "bar", ixns[0].DestinationName)
require.Equal(api.IntentionActionAllow, ixns[0].Action) require.Equal(t, api.IntentionActionAllow, ixns[0].Action)
} }
func TestIntentionCreate_deny(t *testing.T) { func TestIntentionCreate_deny(t *testing.T) {
@ -94,7 +92,6 @@ func TestIntentionCreate_deny(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -109,14 +106,14 @@ func TestIntentionCreate_deny(t *testing.T) {
"-deny", "-deny",
"foo", "bar", "foo", "bar",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
ixns, _, err := client.Connect().Intentions(nil) ixns, _, err := client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("foo", ixns[0].SourceName) require.Equal(t, "foo", ixns[0].SourceName)
require.Equal("bar", ixns[0].DestinationName) require.Equal(t, "bar", ixns[0].DestinationName)
require.Equal(api.IntentionActionDeny, ixns[0].Action) require.Equal(t, api.IntentionActionDeny, ixns[0].Action)
} }
func TestIntentionCreate_meta(t *testing.T) { func TestIntentionCreate_meta(t *testing.T) {
@ -126,7 +123,6 @@ func TestIntentionCreate_meta(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -141,14 +137,14 @@ func TestIntentionCreate_meta(t *testing.T) {
"-meta", "hello=world", "-meta", "hello=world",
"foo", "bar", "foo", "bar",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
ixns, _, err := client.Connect().Intentions(nil) ixns, _, err := client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("foo", ixns[0].SourceName) require.Equal(t, "foo", ixns[0].SourceName)
require.Equal("bar", ixns[0].DestinationName) require.Equal(t, "bar", ixns[0].DestinationName)
require.Equal(map[string]string{"hello": "world"}, ixns[0].Meta) require.Equal(t, map[string]string{"hello": "world"}, ixns[0].Meta)
} }
func TestIntentionCreate_File(t *testing.T) { func TestIntentionCreate_File(t *testing.T) {
@ -158,7 +154,6 @@ func TestIntentionCreate_File(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -180,14 +175,14 @@ func TestIntentionCreate_File(t *testing.T) {
f.Name(), f.Name(),
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
ixns, _, err := client.Connect().Intentions(nil) ixns, _, err := client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("foo", ixns[0].SourceName) require.Equal(t, "foo", ixns[0].SourceName)
require.Equal("bar", ixns[0].DestinationName) require.Equal(t, "bar", ixns[0].DestinationName)
require.Equal(api.IntentionActionAllow, ixns[0].Action) require.Equal(t, api.IntentionActionAllow, ixns[0].Action)
} }
func TestIntentionCreate_File_L7_fails(t *testing.T) { func TestIntentionCreate_File_L7_fails(t *testing.T) {
@ -197,7 +192,6 @@ func TestIntentionCreate_File_L7_fails(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
@ -231,8 +225,8 @@ func TestIntentionCreate_File_L7_fails(t *testing.T) {
f.Name(), f.Name(),
} }
require.Equal(1, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 1, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.ErrorWriter.String(), "cannot create L7 intention from file") require.Contains(t, ui.ErrorWriter.String(), "cannot create L7 intention from file")
} }
func TestIntentionCreate_FileNoExist(t *testing.T) { func TestIntentionCreate_FileNoExist(t *testing.T) {
@ -242,7 +236,6 @@ func TestIntentionCreate_FileNoExist(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
@ -257,8 +250,8 @@ func TestIntentionCreate_FileNoExist(t *testing.T) {
"shouldnotexist.txt", "shouldnotexist.txt",
} }
require.Equal(1, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 1, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.ErrorWriter.String(), "no such file") require.Contains(t, ui.ErrorWriter.String(), "no such file")
} }
func TestIntentionCreate_replace(t *testing.T) { func TestIntentionCreate_replace(t *testing.T) {
@ -268,7 +261,6 @@ func TestIntentionCreate_replace(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -284,14 +276,14 @@ func TestIntentionCreate_replace(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
"foo", "bar", "foo", "bar",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
ixns, _, err := client.Connect().Intentions(nil) ixns, _, err := client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("foo", ixns[0].SourceName) require.Equal(t, "foo", ixns[0].SourceName)
require.Equal("bar", ixns[0].DestinationName) require.Equal(t, "bar", ixns[0].DestinationName)
require.Equal(api.IntentionActionAllow, ixns[0].Action) require.Equal(t, api.IntentionActionAllow, ixns[0].Action)
} }
// Don't replace, should be an error // Don't replace, should be an error
@ -304,8 +296,8 @@ func TestIntentionCreate_replace(t *testing.T) {
"-deny", "-deny",
"foo", "bar", "foo", "bar",
} }
require.Equal(1, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 1, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.ErrorWriter.String(), "more than once") require.Contains(t, ui.ErrorWriter.String(), "more than once")
} }
// Replace it // Replace it
@ -319,13 +311,13 @@ func TestIntentionCreate_replace(t *testing.T) {
"-deny", "-deny",
"foo", "bar", "foo", "bar",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
ixns, _, err := client.Connect().Intentions(nil) ixns, _, err := client.Connect().Intentions(nil)
require.NoError(err) require.NoError(t, err)
require.Len(ixns, 1) require.Len(t, ixns, 1)
require.Equal("foo", ixns[0].SourceName) require.Equal(t, "foo", ixns[0].SourceName)
require.Equal("bar", ixns[0].DestinationName) require.Equal(t, "bar", ixns[0].DestinationName)
require.Equal(api.IntentionActionDeny, ixns[0].Action) require.Equal(t, api.IntentionActionDeny, ixns[0].Action)
} }
} }

View File

@ -43,7 +43,6 @@ func TestIntentionGet_Validation(t *testing.T) {
for name, tc := range cases { for name, tc := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := require.New(t)
c.init() c.init()
@ -55,9 +54,9 @@ func TestIntentionGet_Validation(t *testing.T) {
ui.OutputWriter.Reset() ui.OutputWriter.Reset()
} }
require.Equal(1, c.Run(tc.args)) require.Equal(t, 1, c.Run(tc.args))
output := ui.ErrorWriter.String() output := ui.ErrorWriter.String()
require.Contains(output, tc.output) require.Contains(t, output, tc.output)
}) })
} }
} }
@ -69,7 +68,6 @@ func TestIntentionGet_id(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -86,7 +84,7 @@ func TestIntentionGet_id(t *testing.T) {
DestinationName: "db", DestinationName: "db",
Action: api.IntentionActionAllow, Action: api.IntentionActionAllow,
}, nil) }, nil)
require.NoError(err) require.NoError(t, err)
} }
// Get it // Get it
@ -97,8 +95,8 @@ func TestIntentionGet_id(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
id, id,
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.OutputWriter.String(), id) require.Contains(t, ui.OutputWriter.String(), id)
} }
func TestIntentionGet_srcDst(t *testing.T) { func TestIntentionGet_srcDst(t *testing.T) {
@ -108,7 +106,6 @@ func TestIntentionGet_srcDst(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -125,7 +122,7 @@ func TestIntentionGet_srcDst(t *testing.T) {
DestinationName: "db", DestinationName: "db",
Action: api.IntentionActionAllow, Action: api.IntentionActionAllow,
}, nil) }, nil)
require.NoError(err) require.NoError(t, err)
} }
// Get it // Get it
@ -136,8 +133,8 @@ func TestIntentionGet_srcDst(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
"web", "db", "web", "db",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.OutputWriter.String(), id) require.Contains(t, ui.OutputWriter.String(), id)
} }
func TestIntentionGet_verticalBar(t *testing.T) { func TestIntentionGet_verticalBar(t *testing.T) {
@ -147,7 +144,6 @@ func TestIntentionGet_verticalBar(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -166,7 +162,7 @@ func TestIntentionGet_verticalBar(t *testing.T) {
DestinationName: "db", DestinationName: "db",
Action: api.IntentionActionAllow, Action: api.IntentionActionAllow,
}, nil) }, nil)
require.NoError(err) require.NoError(t, err)
} }
// Get it // Get it
@ -177,9 +173,9 @@ func TestIntentionGet_verticalBar(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
id, id,
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
// Check for sourceName presense because it should not be parsed by // Check for sourceName presense because it should not be parsed by
// columnize // columnize
require.Contains(ui.OutputWriter.String(), sourceName) require.Contains(t, ui.OutputWriter.String(), sourceName)
} }

View File

@ -46,7 +46,6 @@ func TestIntentionMatch_Validation(t *testing.T) {
for name, tc := range cases { for name, tc := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := require.New(t)
c.init() c.init()
@ -58,9 +57,9 @@ func TestIntentionMatch_Validation(t *testing.T) {
ui.OutputWriter.Reset() ui.OutputWriter.Reset()
} }
require.Equal(1, c.Run(tc.args)) require.Equal(t, 1, c.Run(tc.args))
output := ui.ErrorWriter.String() output := ui.ErrorWriter.String()
require.Contains(output, tc.output) require.Contains(t, output, tc.output)
}) })
} }
} }
@ -72,7 +71,6 @@ func TestIntentionMatch_matchDst(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -94,8 +92,8 @@ func TestIntentionMatch_matchDst(t *testing.T) {
DestinationName: v[1], DestinationName: v[1],
Action: api.IntentionActionDeny, Action: api.IntentionActionDeny,
}, nil) }, nil)
require.NoError(err) require.NoError(t, err)
require.NotEmpty(id) require.NotEmpty(t, id)
} }
} }
@ -108,10 +106,10 @@ func TestIntentionMatch_matchDst(t *testing.T) {
"-http-addr=" + a.HTTPAddr(), "-http-addr=" + a.HTTPAddr(),
"db", "db",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.OutputWriter.String(), "web") require.Contains(t, ui.OutputWriter.String(), "web")
require.Contains(ui.OutputWriter.String(), "db") require.Contains(t, ui.OutputWriter.String(), "db")
require.Contains(ui.OutputWriter.String(), "*") require.Contains(t, ui.OutputWriter.String(), "*")
} }
} }
@ -122,7 +120,6 @@ func TestIntentionMatch_matchSource(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -144,8 +141,8 @@ func TestIntentionMatch_matchSource(t *testing.T) {
DestinationName: v[1], DestinationName: v[1],
Action: api.IntentionActionDeny, Action: api.IntentionActionDeny,
}, nil) }, nil)
require.NoError(err) require.NoError(t, err)
require.NotEmpty(id) require.NotEmpty(t, id)
} }
} }
@ -159,8 +156,8 @@ func TestIntentionMatch_matchSource(t *testing.T) {
"-source", "-source",
"foo", "foo",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
require.Contains(ui.OutputWriter.String(), "db") require.Contains(t, ui.OutputWriter.String(), "db")
require.NotContains(ui.OutputWriter.String(), "web") require.NotContains(t, ui.OutputWriter.String(), "web")
} }
} }

View File

@ -176,10 +176,9 @@ func TestStructsToAgentService(t *testing.T) {
tc := tt tc := tt
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
actual, err := serviceToAgentService(tc.Input) actual, err := serviceToAgentService(tc.Input)
require.NoError(err) require.NoError(t, err)
require.Equal(tc.Output, actual) require.Equal(t, tc.Output, actual)
}) })
} }
} }

View File

@ -41,7 +41,6 @@ func TestCommand_Validation(t *testing.T) {
for name, tc := range cases { for name, tc := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := require.New(t)
c.init() c.init()
@ -53,9 +52,9 @@ func TestCommand_Validation(t *testing.T) {
ui.OutputWriter.Reset() ui.OutputWriter.Reset()
} }
require.Equal(1, c.Run(tc.args)) require.Equal(t, 1, c.Run(tc.args))
output := ui.ErrorWriter.String() output := ui.ErrorWriter.String()
require.Contains(output, tc.output) require.Contains(t, output, tc.output)
}) })
} }
} }
@ -67,15 +66,14 @@ func TestCommand_File_id(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
// Register a service // Register a service
require.NoError(client.Agent().ServiceRegister(&api.AgentServiceRegistration{ require.NoError(t, client.Agent().ServiceRegister(&api.AgentServiceRegistration{
Name: "web"})) Name: "web"}))
require.NoError(client.Agent().ServiceRegister(&api.AgentServiceRegistration{ require.NoError(t, client.Agent().ServiceRegister(&api.AgentServiceRegistration{
Name: "db"})) Name: "db"}))
ui := cli.NewMockUi() ui := cli.NewMockUi()
@ -93,12 +91,12 @@ func TestCommand_File_id(t *testing.T) {
f.Name(), f.Name(),
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
svcs, err := client.Agent().Services() svcs, err := client.Agent().Services()
require.NoError(err) require.NoError(t, err)
require.Len(svcs, 1) require.Len(t, svcs, 1)
require.NotNil(svcs["db"]) require.NotNil(t, svcs["db"])
} }
func TestCommand_File_nameOnly(t *testing.T) { func TestCommand_File_nameOnly(t *testing.T) {
@ -108,15 +106,14 @@ func TestCommand_File_nameOnly(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
// Register a service // Register a service
require.NoError(client.Agent().ServiceRegister(&api.AgentServiceRegistration{ require.NoError(t, client.Agent().ServiceRegister(&api.AgentServiceRegistration{
Name: "web"})) Name: "web"}))
require.NoError(client.Agent().ServiceRegister(&api.AgentServiceRegistration{ require.NoError(t, client.Agent().ServiceRegister(&api.AgentServiceRegistration{
Name: "db"})) Name: "db"}))
ui := cli.NewMockUi() ui := cli.NewMockUi()
@ -134,12 +131,12 @@ func TestCommand_File_nameOnly(t *testing.T) {
f.Name(), f.Name(),
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
svcs, err := client.Agent().Services() svcs, err := client.Agent().Services()
require.NoError(err) require.NoError(t, err)
require.Len(svcs, 1) require.Len(t, svcs, 1)
require.NotNil(svcs["db"]) require.NotNil(t, svcs["db"])
} }
func TestCommand_Flag(t *testing.T) { func TestCommand_Flag(t *testing.T) {
@ -149,15 +146,14 @@ func TestCommand_Flag(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
// Register a service // Register a service
require.NoError(client.Agent().ServiceRegister(&api.AgentServiceRegistration{ require.NoError(t, client.Agent().ServiceRegister(&api.AgentServiceRegistration{
Name: "web"})) Name: "web"}))
require.NoError(client.Agent().ServiceRegister(&api.AgentServiceRegistration{ require.NoError(t, client.Agent().ServiceRegister(&api.AgentServiceRegistration{
Name: "db"})) Name: "db"}))
ui := cli.NewMockUi() ui := cli.NewMockUi()
@ -168,12 +164,12 @@ func TestCommand_Flag(t *testing.T) {
"-id", "web", "-id", "web",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
svcs, err := client.Agent().Services() svcs, err := client.Agent().Services()
require.NoError(err) require.NoError(t, err)
require.Len(svcs, 1) require.Len(t, svcs, 1)
require.NotNil(svcs["db"]) require.NotNil(t, svcs["db"])
} }
func testFile(t *testing.T, suffix string) *os.File { func testFile(t *testing.T, suffix string) *os.File {

View File

@ -40,7 +40,6 @@ func TestCommand_Validation(t *testing.T) {
for name, tc := range cases { for name, tc := range cases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := require.New(t)
c.init() c.init()
@ -52,9 +51,9 @@ func TestCommand_Validation(t *testing.T) {
ui.OutputWriter.Reset() ui.OutputWriter.Reset()
} }
require.Equal(1, c.Run(tc.args)) require.Equal(t, 1, c.Run(tc.args))
output := ui.ErrorWriter.String() output := ui.ErrorWriter.String()
require.Contains(output, tc.output) require.Contains(t, output, tc.output)
}) })
} }
} }
@ -66,7 +65,6 @@ func TestCommand_File(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -86,14 +84,14 @@ func TestCommand_File(t *testing.T) {
f.Name(), f.Name(),
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
svcs, err := client.Agent().Services() svcs, err := client.Agent().Services()
require.NoError(err) require.NoError(t, err)
require.Len(svcs, 1) require.Len(t, svcs, 1)
svc := svcs["web"] svc := svcs["web"]
require.NotNil(svc) require.NotNil(t, svc)
} }
func TestCommand_Flags(t *testing.T) { func TestCommand_Flags(t *testing.T) {
@ -103,7 +101,6 @@ func TestCommand_Flags(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -116,14 +113,14 @@ func TestCommand_Flags(t *testing.T) {
"-name", "web", "-name", "web",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
svcs, err := client.Agent().Services() svcs, err := client.Agent().Services()
require.NoError(err) require.NoError(t, err)
require.Len(svcs, 1) require.Len(t, svcs, 1)
svc := svcs["web"] svc := svcs["web"]
require.NotNil(svc) require.NotNil(t, svc)
} }
func TestCommand_Flags_TaggedAddresses(t *testing.T) { func TestCommand_Flags_TaggedAddresses(t *testing.T) {
@ -133,7 +130,6 @@ func TestCommand_Flags_TaggedAddresses(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -148,21 +144,21 @@ func TestCommand_Flags_TaggedAddresses(t *testing.T) {
"-tagged-address", "v6=[2001:db8::12]:1234", "-tagged-address", "v6=[2001:db8::12]:1234",
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
svcs, err := client.Agent().Services() svcs, err := client.Agent().Services()
require.NoError(err) require.NoError(t, err)
require.Len(svcs, 1) require.Len(t, svcs, 1)
svc := svcs["web"] svc := svcs["web"]
require.NotNil(svc) require.NotNil(t, svc)
require.Len(svc.TaggedAddresses, 2) require.Len(t, svc.TaggedAddresses, 2)
require.Contains(svc.TaggedAddresses, "lan") require.Contains(t, svc.TaggedAddresses, "lan")
require.Contains(svc.TaggedAddresses, "v6") require.Contains(t, svc.TaggedAddresses, "v6")
require.Equal(svc.TaggedAddresses["lan"].Address, "127.0.0.1") require.Equal(t, svc.TaggedAddresses["lan"].Address, "127.0.0.1")
require.Equal(svc.TaggedAddresses["lan"].Port, 1234) require.Equal(t, svc.TaggedAddresses["lan"].Port, 1234)
require.Equal(svc.TaggedAddresses["v6"].Address, "2001:db8::12") require.Equal(t, svc.TaggedAddresses["v6"].Address, "2001:db8::12")
require.Equal(svc.TaggedAddresses["v6"].Port, 1234) require.Equal(t, svc.TaggedAddresses["v6"].Port, 1234)
} }
func TestCommand_FileWithUnnamedCheck(t *testing.T) { func TestCommand_FileWithUnnamedCheck(t *testing.T) {
@ -172,7 +168,6 @@ func TestCommand_FileWithUnnamedCheck(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t)
a := agent.NewTestAgent(t, ``) a := agent.NewTestAgent(t, ``)
defer a.Shutdown() defer a.Shutdown()
client := a.Client() client := a.Client()
@ -192,18 +187,18 @@ func TestCommand_FileWithUnnamedCheck(t *testing.T) {
f.Name(), f.Name(),
} }
require.Equal(0, c.Run(args), ui.ErrorWriter.String()) require.Equal(t, 0, c.Run(args), ui.ErrorWriter.String())
svcs, err := client.Agent().Services() svcs, err := client.Agent().Services()
require.NoError(err) require.NoError(t, err)
require.Len(svcs, 1) require.Len(t, svcs, 1)
svc := svcs["web"] svc := svcs["web"]
require.NotNil(svc) require.NotNil(t, svc)
checks, err := client.Agent().Checks() checks, err := client.Agent().Checks()
require.NoError(err) require.NoError(t, err)
require.Len(checks, 1) require.Len(t, checks, 1)
} }
func testFile(t *testing.T, suffix string) *os.File { func testFile(t *testing.T, suffix string) *os.File {

View File

@ -32,10 +32,9 @@ func TestStaticResolver_Resolve(t *testing.T) {
CertURI: tt.fields.CertURI, CertURI: tt.fields.CertURI,
} }
addr, certURI, err := sr.Resolve(context.Background()) addr, certURI, err := sr.Resolve(context.Background())
require := require.New(t) require.Nil(t, err)
require.Nil(err) require.Equal(t, sr.Addr, addr)
require.Equal(sr.Addr, addr) require.Equal(t, sr.CertURI, certURI)
require.Equal(sr.CertURI, certURI)
}) })
} }
} }
@ -201,7 +200,6 @@ func TestConsulResolver_Resolve(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
cr := &ConsulResolver{ cr := &ConsulResolver{
Client: client, Client: client,
Namespace: tt.fields.Namespace, Namespace: tt.fields.Namespace,
@ -218,14 +216,14 @@ func TestConsulResolver_Resolve(t *testing.T) {
defer cancel() defer cancel()
gotAddr, gotCertURI, err := cr.Resolve(ctx) gotAddr, gotCertURI, err := cr.Resolve(ctx)
if tt.wantErr { if tt.wantErr {
require.NotNil(err) require.NotNil(t, err)
return return
} }
require.Nil(err) require.Nil(t, err)
require.Equal(tt.wantCertURI, gotCertURI) require.Equal(t, tt.wantCertURI, gotCertURI)
if len(tt.addrs) > 0 { if len(tt.addrs) > 0 {
require.Contains(tt.addrs, gotAddr) require.Contains(t, tt.addrs, gotAddr)
} }
}) })
} }
@ -323,16 +321,15 @@ func TestConsulResolverFromAddrFunc(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
fn := ConsulResolverFromAddrFunc(client) fn := ConsulResolverFromAddrFunc(client)
got, gotErr := fn(tt.addr) got, gotErr := fn(tt.addr)
if tt.wantErr != "" { if tt.wantErr != "" {
require.Error(gotErr) require.Error(t, gotErr)
require.Contains(gotErr.Error(), tt.wantErr) require.Contains(t, gotErr.Error(), tt.wantErr)
} else { } else {
require.NoError(gotErr) require.NoError(t, gotErr)
require.Equal(tt.want, got) require.Equal(t, tt.want, got)
} }
}) })
} }

View File

@ -77,7 +77,6 @@ func TestService_Dial(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
s := TestService(t, "web", ca) s := TestService(t, "web", ca)
@ -91,7 +90,7 @@ func TestService_Dial(t *testing.T) {
if tt.accept { if tt.accept {
go func() { go func() {
err := testSvr.Serve() err := testSvr.Serve()
require.NoError(err) require.NoError(t, err)
}() }()
<-testSvr.Listening <-testSvr.Listening
defer testSvr.Close() defer testSvr.Close()
@ -114,11 +113,11 @@ func TestService_Dial(t *testing.T) {
testTimer.Stop() testTimer.Stop()
if tt.wantErr == "" { if tt.wantErr == "" {
require.NoError(err) require.NoError(t, err)
require.IsType(&tls.Conn{}, conn) require.IsType(t, &tls.Conn{}, conn)
} else { } else {
require.Error(err) require.Error(t, err)
require.Contains(err.Error(), tt.wantErr) require.Contains(t, err.Error(), tt.wantErr)
} }
if err == nil { if err == nil {
@ -133,8 +132,6 @@ func TestService_ServerTLSConfig(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
a := agent.StartTestAgent(t, agent.TestAgent{Name: "007", Overrides: ` a := agent.StartTestAgent(t, agent.TestAgent{Name: "007", Overrides: `
connect { connect {
test_ca_leaf_root_change_spread = "1ns" test_ca_leaf_root_change_spread = "1ns"
@ -153,12 +150,12 @@ func TestService_ServerTLSConfig(t *testing.T) {
Port: 8080, Port: 8080,
} }
err := agent.ServiceRegister(reg) err := agent.ServiceRegister(reg)
require.NoError(err) require.NoError(t, err)
// Now we should be able to create a service that will eventually get it's TLS // Now we should be able to create a service that will eventually get it's TLS
// all by itself! // all by itself!
service, err := NewService("web", client) service, err := NewService("web", client)
require.NoError(err) require.NoError(t, err)
// Wait for it to be ready // Wait for it to be ready
select { select {
@ -172,17 +169,17 @@ func TestService_ServerTLSConfig(t *testing.T) {
// Sanity check it has a leaf with the right ServiceID and that validates with // Sanity check it has a leaf with the right ServiceID and that validates with
// the given roots. // the given roots.
require.NotNil(tlsCfg.GetCertificate) require.NotNil(t, tlsCfg.GetCertificate)
leaf, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{}) leaf, err := tlsCfg.GetCertificate(&tls.ClientHelloInfo{})
require.NoError(err) require.NoError(t, err)
cert, err := x509.ParseCertificate(leaf.Certificate[0]) cert, err := x509.ParseCertificate(leaf.Certificate[0])
require.NoError(err) require.NoError(t, err)
require.Len(cert.URIs, 1) require.Len(t, cert.URIs, 1)
require.True(strings.HasSuffix(cert.URIs[0].String(), "/svc/web")) require.True(t, strings.HasSuffix(cert.URIs[0].String(), "/svc/web"))
// Verify it as a client would // Verify it as a client would
err = clientSideVerifier(tlsCfg, leaf.Certificate) err = clientSideVerifier(tlsCfg, leaf.Certificate)
require.NoError(err) require.NoError(t, err)
// Now test that rotating the root updates // Now test that rotating the root updates
{ {
@ -242,7 +239,7 @@ func TestService_HTTPClient(t *testing.T) {
// Hook the service resolver to avoid needing full agent setup. // Hook the service resolver to avoid needing full agent setup.
s.httpResolverFromAddr = func(addr string) (Resolver, error) { s.httpResolverFromAddr = func(addr string) (Resolver, error) {
// Require in this goroutine seems to block causing a timeout on the Get. // Require in this goroutine seems to block causing a timeout on the Get.
//require.Equal("https://backend.service.consul:443", addr) //require.Equal(t,"https://backend.service.consul:443", addr)
return &StaticResolver{ return &StaticResolver{
Addr: testSvr.Addr, Addr: testSvr.Addr,
CertURI: connect.TestSpiffeIDService(t, "backend"), CertURI: connect.TestSpiffeIDService(t, "backend"),

View File

@ -123,13 +123,12 @@ func TestClientSideVerifier(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
err := clientSideVerifier(tt.tlsCfg, tt.rawCerts) err := clientSideVerifier(tt.tlsCfg, tt.rawCerts)
if tt.wantErr == "" { if tt.wantErr == "" {
require.Nil(err) require.Nil(t, err)
} else { } else {
require.NotNil(err) require.NotNil(t, err)
require.Contains(err.Error(), tt.wantErr) require.Contains(t, err.Error(), tt.wantErr)
} }
}) })
} }
@ -265,33 +264,32 @@ func TestServerSideVerifier(t *testing.T) {
// cmp.Diff fail on tls.Config due to unexported fields in each. expectLeaf // cmp.Diff fail on tls.Config due to unexported fields in each. expectLeaf
// allows expecting a leaf cert different from the one in expect // allows expecting a leaf cert different from the one in expect
func requireEqualTLSConfig(t *testing.T, expect, got *tls.Config) { func requireEqualTLSConfig(t *testing.T, expect, got *tls.Config) {
require := require.New(t) require.Equal(t, expect.RootCAs, got.RootCAs)
require.Equal(expect.RootCAs, got.RootCAs)
assertDeepEqual(t, expect.ClientCAs, got.ClientCAs, cmpCertPool) assertDeepEqual(t, expect.ClientCAs, got.ClientCAs, cmpCertPool)
require.Equal(expect.InsecureSkipVerify, got.InsecureSkipVerify) require.Equal(t, expect.InsecureSkipVerify, got.InsecureSkipVerify)
require.Equal(expect.MinVersion, got.MinVersion) require.Equal(t, expect.MinVersion, got.MinVersion)
require.Equal(expect.CipherSuites, got.CipherSuites) require.Equal(t, expect.CipherSuites, got.CipherSuites)
require.NotNil(got.GetCertificate) require.NotNil(t, got.GetCertificate)
require.NotNil(got.GetClientCertificate) require.NotNil(t, got.GetClientCertificate)
require.NotNil(got.GetConfigForClient) require.NotNil(t, got.GetConfigForClient)
require.Contains(got.NextProtos, "h2") require.Contains(t, got.NextProtos, "h2")
var expectLeaf *tls.Certificate var expectLeaf *tls.Certificate
var err error var err error
if expect.GetCertificate != nil { if expect.GetCertificate != nil {
expectLeaf, err = expect.GetCertificate(nil) expectLeaf, err = expect.GetCertificate(nil)
require.Nil(err) require.Nil(t, err)
} else if len(expect.Certificates) > 0 { } else if len(expect.Certificates) > 0 {
expectLeaf = &expect.Certificates[0] expectLeaf = &expect.Certificates[0]
} }
gotLeaf, err := got.GetCertificate(nil) gotLeaf, err := got.GetCertificate(nil)
require.Nil(err) require.Nil(t, err)
require.Equal(expectLeaf, gotLeaf) require.Equal(t, expectLeaf, gotLeaf)
gotLeaf, err = got.GetClientCertificate(nil) gotLeaf, err = got.GetClientCertificate(nil)
require.Nil(err) require.Nil(t, err)
require.Equal(expectLeaf, gotLeaf) require.Equal(t, expectLeaf, gotLeaf)
} }
// cmpCertPool is a custom comparison for x509.CertPool, because CertPool.lazyCerts // cmpCertPool is a custom comparison for x509.CertPool, because CertPool.lazyCerts
@ -324,7 +322,6 @@ func requireCorrectVerifier(t *testing.T, expect, got *tls.Config,
} }
func TestDynamicTLSConfig(t *testing.T) { func TestDynamicTLSConfig(t *testing.T) {
require := require.New(t)
ca1 := connect.TestCA(t, nil) ca1 := connect.TestCA(t, nil)
ca2 := connect.TestCA(t, nil) ca2 := connect.TestCA(t, nil)
@ -334,8 +331,8 @@ func TestDynamicTLSConfig(t *testing.T) {
c := newDynamicTLSConfig(baseCfg, nil) c := newDynamicTLSConfig(baseCfg, nil)
// Should set them from the base config // Should set them from the base config
require.Equal(c.Leaf(), &baseCfg.Certificates[0]) require.Equal(t, c.Leaf(), &baseCfg.Certificates[0])
require.Equal(c.Roots(), baseCfg.RootCAs) require.Equal(t, c.Roots(), baseCfg.RootCAs)
// Create verifiers we can assert are set and run correctly. // Create verifiers we can assert are set and run correctly.
v1Ch := make(chan *tls.Config, 1) v1Ch := make(chan *tls.Config, 1)
@ -361,7 +358,7 @@ func TestDynamicTLSConfig(t *testing.T) {
// Now change the roots as if we just loaded new roots from Consul // Now change the roots as if we just loaded new roots from Consul
err := c.SetRoots(newCfg.RootCAs) err := c.SetRoots(newCfg.RootCAs)
require.Nil(err) require.Nil(t, err)
// The dynamic config should have the new roots, but old leaf // The dynamic config should have the new roots, but old leaf
gotAfter := c.Get(verify2) gotAfter := c.Get(verify2)
@ -378,7 +375,7 @@ func TestDynamicTLSConfig(t *testing.T) {
// Now change the leaf // Now change the leaf
err = c.SetLeaf(&newCfg.Certificates[0]) err = c.SetLeaf(&newCfg.Certificates[0])
require.Nil(err) require.Nil(t, err)
// The dynamic config should have the new roots, AND new leaf // The dynamic config should have the new roots, AND new leaf
gotAfterLeaf := c.Get(verify3) gotAfterLeaf := c.Get(verify3)
@ -392,7 +389,6 @@ func TestDynamicTLSConfig(t *testing.T) {
} }
func TestDynamicTLSConfig_Ready(t *testing.T) { func TestDynamicTLSConfig_Ready(t *testing.T) {
require := require.New(t)
ca1 := connect.TestCA(t, nil) ca1 := connect.TestCA(t, nil)
baseCfg := TestTLSConfig(t, "web", ca1) baseCfg := TestTLSConfig(t, "web", ca1)
@ -400,28 +396,28 @@ func TestDynamicTLSConfig_Ready(t *testing.T) {
c := newDynamicTLSConfig(defaultTLSConfig(), nil) c := newDynamicTLSConfig(defaultTLSConfig(), nil)
readyCh := c.ReadyWait() readyCh := c.ReadyWait()
assertBlocked(t, readyCh) assertBlocked(t, readyCh)
require.False(c.Ready(), "no roots or leaf, should not be ready") require.False(t, c.Ready(), "no roots or leaf, should not be ready")
err := c.SetLeaf(&baseCfg.Certificates[0]) err := c.SetLeaf(&baseCfg.Certificates[0])
require.NoError(err) require.NoError(t, err)
assertBlocked(t, readyCh) assertBlocked(t, readyCh)
require.False(c.Ready(), "no roots, should not be ready") require.False(t, c.Ready(), "no roots, should not be ready")
err = c.SetRoots(baseCfg.RootCAs) err = c.SetRoots(baseCfg.RootCAs)
require.NoError(err) require.NoError(t, err)
assertNotBlocked(t, readyCh) assertNotBlocked(t, readyCh)
require.True(c.Ready(), "should be ready") require.True(t, c.Ready(), "should be ready")
ca2 := connect.TestCA(t, nil) ca2 := connect.TestCA(t, nil)
ca2cfg := TestTLSConfig(t, "web", ca2) ca2cfg := TestTLSConfig(t, "web", ca2)
require.NoError(c.SetRoots(ca2cfg.RootCAs)) require.NoError(t, c.SetRoots(ca2cfg.RootCAs))
assertNotBlocked(t, readyCh) assertNotBlocked(t, readyCh)
require.False(c.Ready(), "invalid leaf, should not be ready") require.False(t, c.Ready(), "invalid leaf, should not be ready")
require.NoError(c.SetRoots(baseCfg.RootCAs)) require.NoError(t, c.SetRoots(baseCfg.RootCAs))
assertNotBlocked(t, readyCh) assertNotBlocked(t, readyCh)
require.True(c.Ready(), "should be ready") require.True(t, c.Ready(), "should be ready")
} }
func assertBlocked(t *testing.T, ch <-chan struct{}) { func assertBlocked(t *testing.T, ch <-chan struct{}) {

View File

@ -13,9 +13,8 @@ import (
// tests that it just writes the file properly. I would love to test this // tests that it just writes the file properly. I would love to test this
// better but I'm not sure how. -mitchellh // better but I'm not sure how. -mitchellh
func TestWriteAtomic(t *testing.T) { func TestWriteAtomic(t *testing.T) {
require := require.New(t)
td, err := ioutil.TempDir("", "lib-file") td, err := ioutil.TempDir("", "lib-file")
require.NoError(err) require.NoError(t, err)
defer os.RemoveAll(td) defer os.RemoveAll(td)
// Create a subdir that doesn't exist to test that it is created // Create a subdir that doesn't exist to test that it is created
@ -23,10 +22,10 @@ func TestWriteAtomic(t *testing.T) {
// Write // Write
expected := []byte("hello") expected := []byte("hello")
require.NoError(WriteAtomic(path, expected)) require.NoError(t, WriteAtomic(path, expected))
// Read and verify // Read and verify
actual, err := ioutil.ReadFile(path) actual, err := ioutil.ReadFile(path)
require.NoError(err) require.NoError(t, err)
require.Equal(expected, actual) require.Equal(t, expected, actual)
} }

View File

@ -12,12 +12,11 @@ import (
) )
func TestLogger_SetupBasic(t *testing.T) { func TestLogger_SetupBasic(t *testing.T) {
require := require.New(t)
cfg := Config{LogLevel: "INFO"} cfg := Config{LogLevel: "INFO"}
logger, err := Setup(cfg, nil) logger, err := Setup(cfg, nil)
require.NoError(err) require.NoError(t, err)
require.NotNil(logger) require.NotNil(t, logger)
} }
func TestLogger_SetupInvalidLogLevel(t *testing.T) { func TestLogger_SetupInvalidLogLevel(t *testing.T) {
@ -52,44 +51,41 @@ func TestLogger_SetupLoggerErrorLevel(t *testing.T) {
var cfg Config var cfg Config
c.before(&cfg) c.before(&cfg)
require := require.New(t)
var buf bytes.Buffer var buf bytes.Buffer
logger, err := Setup(cfg, &buf) logger, err := Setup(cfg, &buf)
require.NoError(err) require.NoError(t, err)
require.NotNil(logger) require.NotNil(t, logger)
logger.Error("test error msg") logger.Error("test error msg")
logger.Info("test info msg") logger.Info("test info msg")
output := buf.String() output := buf.String()
require.Contains(output, "[ERROR] test error msg") require.Contains(t, output, "[ERROR] test error msg")
require.NotContains(output, "[INFO] test info msg") require.NotContains(t, output, "[INFO] test info msg")
}) })
} }
} }
func TestLogger_SetupLoggerDebugLevel(t *testing.T) { func TestLogger_SetupLoggerDebugLevel(t *testing.T) {
require := require.New(t)
cfg := Config{LogLevel: "DEBUG"} cfg := Config{LogLevel: "DEBUG"}
var buf bytes.Buffer var buf bytes.Buffer
logger, err := Setup(cfg, &buf) logger, err := Setup(cfg, &buf)
require.NoError(err) require.NoError(t, err)
require.NotNil(logger) require.NotNil(t, logger)
logger.Info("test info msg") logger.Info("test info msg")
logger.Debug("test debug msg") logger.Debug("test debug msg")
output := buf.String() output := buf.String()
require.Contains(output, "[INFO] test info msg") require.Contains(t, output, "[INFO] test info msg")
require.Contains(output, "[DEBUG] test debug msg") require.Contains(t, output, "[DEBUG] test debug msg")
} }
func TestLogger_SetupLoggerWithName(t *testing.T) { func TestLogger_SetupLoggerWithName(t *testing.T) {
require := require.New(t)
cfg := Config{ cfg := Config{
LogLevel: "DEBUG", LogLevel: "DEBUG",
Name: "test-system", Name: "test-system",
@ -97,16 +93,15 @@ func TestLogger_SetupLoggerWithName(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
logger, err := Setup(cfg, &buf) logger, err := Setup(cfg, &buf)
require.NoError(err) require.NoError(t, err)
require.NotNil(logger) require.NotNil(t, logger)
logger.Warn("test warn msg") logger.Warn("test warn msg")
require.Contains(buf.String(), "[WARN] test-system: test warn msg") require.Contains(t, buf.String(), "[WARN] test-system: test warn msg")
} }
func TestLogger_SetupLoggerWithJSON(t *testing.T) { func TestLogger_SetupLoggerWithJSON(t *testing.T) {
require := require.New(t)
cfg := Config{ cfg := Config{
LogLevel: "DEBUG", LogLevel: "DEBUG",
LogJSON: true, LogJSON: true,
@ -115,22 +110,21 @@ func TestLogger_SetupLoggerWithJSON(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
logger, err := Setup(cfg, &buf) logger, err := Setup(cfg, &buf)
require.NoError(err) require.NoError(t, err)
require.NotNil(logger) require.NotNil(t, logger)
logger.Warn("test warn msg") logger.Warn("test warn msg")
var jsonOutput map[string]string var jsonOutput map[string]string
err = json.Unmarshal(buf.Bytes(), &jsonOutput) err = json.Unmarshal(buf.Bytes(), &jsonOutput)
require.NoError(err) require.NoError(t, err)
require.Contains(jsonOutput, "@level") require.Contains(t, jsonOutput, "@level")
require.Equal(jsonOutput["@level"], "warn") require.Equal(t, jsonOutput["@level"], "warn")
require.Contains(jsonOutput, "@message") require.Contains(t, jsonOutput, "@message")
require.Equal(jsonOutput["@message"], "test warn msg") require.Equal(t, jsonOutput["@message"], "test warn msg")
} }
func TestLogger_SetupLoggerWithValidLogPath(t *testing.T) { func TestLogger_SetupLoggerWithValidLogPath(t *testing.T) {
require := require.New(t)
tmpDir := testutil.TempDir(t, t.Name()) tmpDir := testutil.TempDir(t, t.Name())
@ -141,12 +135,11 @@ func TestLogger_SetupLoggerWithValidLogPath(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
logger, err := Setup(cfg, &buf) logger, err := Setup(cfg, &buf)
require.NoError(err) require.NoError(t, err)
require.NotNil(logger) require.NotNil(t, logger)
} }
func TestLogger_SetupLoggerWithInValidLogPath(t *testing.T) { func TestLogger_SetupLoggerWithInValidLogPath(t *testing.T) {
require := require.New(t)
cfg := Config{ cfg := Config{
LogLevel: "INFO", LogLevel: "INFO",
@ -155,13 +148,12 @@ func TestLogger_SetupLoggerWithInValidLogPath(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
logger, err := Setup(cfg, &buf) logger, err := Setup(cfg, &buf)
require.Error(err) require.Error(t, err)
require.True(errors.Is(err, os.ErrNotExist)) require.True(t, errors.Is(err, os.ErrNotExist))
require.Nil(logger) require.Nil(t, logger)
} }
func TestLogger_SetupLoggerWithInValidLogPathPermission(t *testing.T) { func TestLogger_SetupLoggerWithInValidLogPathPermission(t *testing.T) {
require := require.New(t)
tmpDir := "/tmp/" + t.Name() tmpDir := "/tmp/" + t.Name()
@ -175,7 +167,7 @@ func TestLogger_SetupLoggerWithInValidLogPathPermission(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
logger, err := Setup(cfg, &buf) logger, err := Setup(cfg, &buf)
require.Error(err) require.Error(t, err)
require.True(errors.Is(err, os.ErrPermission)) require.True(t, errors.Is(err, os.ErrPermission))
require.Nil(logger) require.Nil(t, logger)
} }

View File

@ -10,7 +10,6 @@ import (
) )
func TestMonitor_Start(t *testing.T) { func TestMonitor_Start(t *testing.T) {
require := require.New(t)
logger := log.NewInterceptLogger(&log.LoggerOptions{ logger := log.NewInterceptLogger(&log.LoggerOptions{
Level: log.Error, Level: log.Error,
@ -31,7 +30,7 @@ func TestMonitor_Start(t *testing.T) {
for { for {
select { select {
case log := <-logCh: case log := <-logCh:
require.Contains(string(log), "[DEBUG] test log") require.Contains(t, string(log), "[DEBUG] test log")
return return
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
t.Fatal("Expected to receive from log channel") t.Fatal("Expected to receive from log channel")
@ -44,8 +43,6 @@ func TestMonitor_Stop(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
logger := log.NewInterceptLogger(&log.LoggerOptions{ logger := log.NewInterceptLogger(&log.LoggerOptions{
Level: log.Error, Level: log.Error,
}) })
@ -60,7 +57,7 @@ func TestMonitor_Stop(t *testing.T) {
logCh := m.Start() logCh := m.Start()
logger.Debug("test log") logger.Debug("test log")
require.Eventually(func() bool { require.Eventually(t, func() bool {
return len(logCh) == 1 return len(logCh) == 1
}, 3*time.Second, 100*time.Millisecond, "expected logCh to have 1 log in it") }, 3*time.Second, 100*time.Millisecond, "expected logCh to have 1 log in it")
@ -73,7 +70,7 @@ func TestMonitor_Stop(t *testing.T) {
select { select {
case log := <-logCh: case log := <-logCh:
if string(log) != "" { if string(log) != "" {
require.Contains(string(log), "[DEBUG] test log") require.Contains(t, string(log), "[DEBUG] test log")
} else { } else {
return return
} }
@ -88,8 +85,6 @@ func TestMonitor_DroppedMessages(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
logger := log.NewInterceptLogger(&log.LoggerOptions{ logger := log.NewInterceptLogger(&log.LoggerOptions{
Level: log.Warn, Level: log.Warn,
}) })
@ -118,7 +113,7 @@ func TestMonitor_DroppedMessages(t *testing.T) {
} }
// Make sure we do not stop before the goroutines have time to process. // Make sure we do not stop before the goroutines have time to process.
require.Eventually(func() bool { require.Eventually(t, func() bool {
return len(logCh) == mcfg.BufferSize return len(logCh) == mcfg.BufferSize
}, 3*time.Second, 100*time.Millisecond, "expected logCh to have a full log buffer") }, 3*time.Second, 100*time.Millisecond, "expected logCh to have a full log buffer")
@ -126,7 +121,7 @@ func TestMonitor_DroppedMessages(t *testing.T) {
// The number of dropped messages is non-deterministic, so we only assert // The number of dropped messages is non-deterministic, so we only assert
// that we dropped at least 1. // that we dropped at least 1.
require.GreaterOrEqual(dropped, 1) require.GreaterOrEqual(t, dropped, 1)
} }
func TestMonitor_ZeroBufSizeDefault(t *testing.T) { func TestMonitor_ZeroBufSizeDefault(t *testing.T) {
@ -134,8 +129,6 @@ func TestMonitor_ZeroBufSizeDefault(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
require := require.New(t)
logger := log.NewInterceptLogger(&log.LoggerOptions{ logger := log.NewInterceptLogger(&log.LoggerOptions{
Level: log.Error, Level: log.Error,
}) })
@ -154,14 +147,14 @@ func TestMonitor_ZeroBufSizeDefault(t *testing.T) {
// If we do not default the buffer size, the monitor will be unable to buffer // If we do not default the buffer size, the monitor will be unable to buffer
// a log line. // a log line.
require.Eventually(func() bool { require.Eventually(t, func() bool {
return len(logCh) == 1 return len(logCh) == 1
}, 3*time.Second, 100*time.Millisecond, "expected logCh to have 1 log buffered") }, 3*time.Second, 100*time.Millisecond, "expected logCh to have 1 log buffered")
for { for {
select { select {
case log := <-logCh: case log := <-logCh:
require.Contains(string(log), "[DEBUG] test log") require.Contains(t, string(log), "[DEBUG] test log")
return return
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
t.Fatal("Expected to receive from log channel") t.Fatal("Expected to receive from log channel")
@ -170,7 +163,6 @@ func TestMonitor_ZeroBufSizeDefault(t *testing.T) {
} }
func TestMonitor_WriteStopped(t *testing.T) { func TestMonitor_WriteStopped(t *testing.T) {
require := require.New(t)
logger := log.NewInterceptLogger(&log.LoggerOptions{ logger := log.NewInterceptLogger(&log.LoggerOptions{
Level: log.Error, Level: log.Error,
@ -183,6 +175,6 @@ func TestMonitor_WriteStopped(t *testing.T) {
mwriter.Stop() mwriter.Stop()
n, err := mwriter.Write([]byte("write after close")) n, err := mwriter.Write([]byte("write after close"))
require.Equal(n, 0) require.Equal(t, n, 0)
require.EqualError(err, "monitor stopped") require.EqualError(t, err, "monitor stopped")
} }