peering: initial sync (#12842)

- Add endpoints related to peering: read, list, generate token, initiate peering
- Update node/service/check table indexing to account for peers
- Foundational changes for pushing service updates to a peer
- Plumb peer name through Health.ServiceNodes path

see: ENT-1765, ENT-1280, ENT-1283, ENT-1283, ENT-1756, ENT-1739, ENT-1750, ENT-1679,
     ENT-1709, ENT-1704, ENT-1690, ENT-1689, ENT-1702, ENT-1701, ENT-1683, ENT-1663,
     ENT-1650, ENT-1678, ENT-1628, ENT-1658, ENT-1640, ENT-1637, ENT-1597, ENT-1634,
     ENT-1613, ENT-1616, ENT-1617, ENT-1591, ENT-1588, ENT-1596, ENT-1572, ENT-1555

Co-authored-by: R.B. Boyer <rb@hashicorp.com>
Co-authored-by: freddygv <freddy@hashicorp.com>
Co-authored-by: Chris S. Kim <ckim@hashicorp.com>
Co-authored-by: Evan Culver <eculver@hashicorp.com>
Co-authored-by: Nitya Dhanushkodi <nitya@hashicorp.com>
This commit is contained in:
R.B. Boyer 2022-04-21 17:34:40 -05:00 committed by GitHub
parent 45ffdc360e
commit 809344a6f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
140 changed files with 14159 additions and 2128 deletions

3
.changelog/_1679.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:breaking-change
config-entry: Exporting a specific service name across all namespace is invalid.
```

View File

@ -3,7 +3,9 @@
package acl package acl
const DefaultPartitionName = "" const (
DefaultPartitionName = ""
)
// Reviewer Note: This is a little bit strange; one might want it to be "" like partition name // Reviewer Note: This is a little bit strange; one might want it to be "" like partition name
// However in consul/structs/intention.go we define IntentionDefaultNamespace as 'default' and so // However in consul/structs/intention.go we define IntentionDefaultNamespace as 'default' and so

View File

@ -106,3 +106,7 @@ func NewEnterpriseMetaWithPartition(_, _ string) EnterpriseMeta {
// FillAuthzContext stub // FillAuthzContext stub
func (_ *EnterpriseMeta) FillAuthzContext(_ *AuthorizerContext) {} func (_ *EnterpriseMeta) FillAuthzContext(_ *AuthorizerContext) {}
func NormalizeNamespace(_ string) string {
return ""
}

View File

@ -20,6 +20,7 @@ import (
"github.com/armon/go-metrics" "github.com/armon/go-metrics"
"github.com/armon/go-metrics/prometheus" "github.com/armon/go-metrics/prometheus"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/go-connlimit" "github.com/hashicorp/go-connlimit"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
@ -357,6 +358,8 @@ type Agent struct {
// into Agent, which will allow us to remove this field. // into Agent, which will allow us to remove this field.
rpcClientHealth *health.Client rpcClientHealth *health.Client
rpcClientPeering pbpeering.PeeringServiceClient
// routineManager is responsible for managing longer running go routines // routineManager is responsible for managing longer running go routines
// run by the Agent // run by the Agent
routineManager *routine.Manager routineManager *routine.Manager
@ -434,6 +437,8 @@ func New(bd BaseDeps) (*Agent, error) {
QueryOptionDefaults: config.ApplyDefaultQueryOptions(a.config), QueryOptionDefaults: config.ApplyDefaultQueryOptions(a.config),
} }
a.rpcClientPeering = pbpeering.NewPeeringServiceClient(conn)
a.serviceManager = NewServiceManager(&a) a.serviceManager = NewServiceManager(&a)
// We used to do this in the Start method. However it doesn't need to go // We used to do this in the Start method. However it doesn't need to go

View File

@ -27,7 +27,7 @@ type DirectRPC interface {
// agent/cache.Cache struct that we care about // agent/cache.Cache struct that we care about
type Cache interface { type Cache interface {
Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error Notify(ctx context.Context, t string, r cache.Request, correlationID string, ch chan<- cache.UpdateEvent) error
Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error Prepopulate(t string, result cache.FetchResult, dc string, peerName string, token string, key string) error
} }
// ServerProvider is an interface that can be used to find one server in the local DC known to // ServerProvider is an interface that can be used to find one server in the local DC known to

View File

@ -137,7 +137,7 @@ func (m *mockCache) Notify(ctx context.Context, t string, r cache.Request, corre
return err return err
} }
func (m *mockCache) Prepopulate(t string, result cache.FetchResult, dc string, token string, key string) error { func (m *mockCache) Prepopulate(t string, result cache.FetchResult, dc string, peerName string, token string, key string) error {
var restore string var restore string
cert, ok := result.Value.(*structs.IssuedCert) cert, ok := result.Value.(*structs.IssuedCert)
if ok { if ok {
@ -147,7 +147,7 @@ func (m *mockCache) Prepopulate(t string, result cache.FetchResult, dc string, t
cert.PrivateKeyPEM = "redacted" cert.PrivateKeyPEM = "redacted"
} }
ret := m.Called(t, result, dc, token, key) ret := m.Called(t, result, dc, peerName, token, key)
if ok && restore != "" { if ok && restore != "" {
cert.PrivateKeyPEM = restore cert.PrivateKeyPEM = restore
@ -304,6 +304,7 @@ func (m *mockedConfig) expectInitialTLS(t *testing.T, agentName, datacenter, tok
rootRes, rootRes,
datacenter, datacenter,
"", "",
"",
rootsReq.CacheInfo().Key, rootsReq.CacheInfo().Key,
).Return(nil).Once() ).Return(nil).Once()
@ -330,6 +331,7 @@ func (m *mockedConfig) expectInitialTLS(t *testing.T, agentName, datacenter, tok
cachetype.ConnectCALeafName, cachetype.ConnectCALeafName,
leafRes, leafRes,
datacenter, datacenter,
"",
token, token,
leafReq.Key(), leafReq.Key(),
).Return(nil).Once() ).Return(nil).Once()

View File

@ -96,7 +96,7 @@ func (ac *AutoConfig) populateCertificateCache(certs *structs.SignedResponse) er
rootRes := cache.FetchResult{Value: &certs.ConnectCARoots, Index: certs.ConnectCARoots.QueryMeta.Index} rootRes := cache.FetchResult{Value: &certs.ConnectCARoots, Index: certs.ConnectCARoots.QueryMeta.Index}
rootsReq := ac.caRootsRequest() rootsReq := ac.caRootsRequest()
// getting the roots doesn't require a token so in order to potentially share the cache with another // getting the roots doesn't require a token so in order to potentially share the cache with another
if err := ac.acConfig.Cache.Prepopulate(cachetype.ConnectCARootName, rootRes, ac.config.Datacenter, "", rootsReq.CacheInfo().Key); err != nil { if err := ac.acConfig.Cache.Prepopulate(cachetype.ConnectCARootName, rootRes, ac.config.Datacenter, structs.DefaultPeerKeyword, "", rootsReq.CacheInfo().Key); err != nil {
return err return err
} }
@ -108,7 +108,7 @@ func (ac *AutoConfig) populateCertificateCache(certs *structs.SignedResponse) er
Index: certs.IssuedCert.RaftIndex.ModifyIndex, Index: certs.IssuedCert.RaftIndex.ModifyIndex,
State: cachetype.ConnectCALeafSuccess(connect.EncodeSigningKeyID(cert.AuthorityKeyId)), State: cachetype.ConnectCALeafSuccess(connect.EncodeSigningKeyID(cert.AuthorityKeyId)),
} }
if err := ac.acConfig.Cache.Prepopulate(cachetype.ConnectCALeafName, certRes, leafReq.Datacenter, leafReq.Token, leafReq.Key()); err != nil { if err := ac.acConfig.Cache.Prepopulate(cachetype.ConnectCALeafName, certRes, leafReq.Datacenter, structs.DefaultPeerKeyword, leafReq.Token, leafReq.Key()); err != nil {
return err return err
} }

View File

@ -5,10 +5,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
) )
func TestCatalogListServices(t *testing.T) { func TestCatalogListServices(t *testing.T) {
@ -104,7 +105,7 @@ func TestCatalogListServices_IntegrationWithCache_NotModifiedResponse(t *testing
}, },
} }
err := c.Prepopulate(CatalogListServicesName, last, "dc1", "token", req.CacheInfo().Key) err := c.Prepopulate(CatalogListServicesName, last, "dc1", "", "token", req.CacheInfo().Key)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())

View File

@ -0,0 +1,92 @@
// Code generated by mockery v2.11.0. DO NOT EDIT.
package cachetype
import (
local "github.com/hashicorp/consul/agent/local"
memdb "github.com/hashicorp/go-memdb"
mock "github.com/stretchr/testify/mock"
structs "github.com/hashicorp/consul/agent/structs"
testing "testing"
time "time"
)
// MockAgent is an autogenerated mock type for the Agent type
type MockAgent struct {
mock.Mock
}
// LocalBlockingQuery provides a mock function with given fields: alwaysBlock, hash, wait, fn
func (_m *MockAgent) LocalBlockingQuery(alwaysBlock bool, hash string, wait time.Duration, fn func(memdb.WatchSet) (string, interface{}, error)) (string, interface{}, error) {
ret := _m.Called(alwaysBlock, hash, wait, fn)
var r0 string
if rf, ok := ret.Get(0).(func(bool, string, time.Duration, func(memdb.WatchSet) (string, interface{}, error)) string); ok {
r0 = rf(alwaysBlock, hash, wait, fn)
} else {
r0 = ret.Get(0).(string)
}
var r1 interface{}
if rf, ok := ret.Get(1).(func(bool, string, time.Duration, func(memdb.WatchSet) (string, interface{}, error)) interface{}); ok {
r1 = rf(alwaysBlock, hash, wait, fn)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(interface{})
}
}
var r2 error
if rf, ok := ret.Get(2).(func(bool, string, time.Duration, func(memdb.WatchSet) (string, interface{}, error)) error); ok {
r2 = rf(alwaysBlock, hash, wait, fn)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// LocalState provides a mock function with given fields:
func (_m *MockAgent) LocalState() *local.State {
ret := _m.Called()
var r0 *local.State
if rf, ok := ret.Get(0).(func() *local.State); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*local.State)
}
}
return r0
}
// ServiceHTTPBasedChecks provides a mock function with given fields: id
func (_m *MockAgent) ServiceHTTPBasedChecks(id structs.ServiceID) []structs.CheckType {
ret := _m.Called(id)
var r0 []structs.CheckType
if rf, ok := ret.Get(0).(func(structs.ServiceID) []structs.CheckType); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]structs.CheckType)
}
}
return r0
}
// NewMockAgent creates a new instance of MockAgent. It also registers a cleanup function to assert the mocks expectations.
func NewMockAgent(t testing.TB) *MockAgent {
mock := &MockAgent{}
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

14
agent/cache/cache.go vendored
View File

@ -91,7 +91,7 @@ const (
// struct in agent/structs. This API makes cache usage a mostly drop-in // struct in agent/structs. This API makes cache usage a mostly drop-in
// replacement for non-cached RPC calls. // replacement for non-cached RPC calls.
// //
// The cache is partitioned by ACL and datacenter. This allows the cache // The cache is partitioned by ACL and datacenter/peer. This allows the cache
// to be safe for multi-DC queries and for queries where the data is modified // to be safe for multi-DC queries and for queries where the data is modified
// due to ACLs all without the cache having to have any clever logic, at // due to ACLs all without the cache having to have any clever logic, at
// the slight expense of a less perfect cache. // the slight expense of a less perfect cache.
@ -406,7 +406,7 @@ func (c *Cache) getWithIndex(ctx context.Context, r getOptions) (interface{}, Re
return result.Value, ResultMeta{}, err return result.Value, ResultMeta{}, err
} }
key := makeEntryKey(r.TypeEntry.Name, r.Info.Datacenter, r.Info.Token, r.Info.Key) key := makeEntryKey(r.TypeEntry.Name, r.Info.Datacenter, r.Info.PeerName, r.Info.Token, r.Info.Key)
// First time through // First time through
first := true first := true
@ -526,7 +526,11 @@ RETRY_GET:
} }
} }
func makeEntryKey(t, dc, token, key string) string { func makeEntryKey(t, dc, peerName, token, key string) string {
// TODO(peering): figure out if this is the desired format
if peerName != "" {
return fmt.Sprintf("%s/%s/%s/%s", t, "peer:"+peerName, token, key)
}
return fmt.Sprintf("%s/%s/%s/%s", t, dc, token, key) return fmt.Sprintf("%s/%s/%s/%s", t, dc, token, key)
} }
@ -884,8 +888,8 @@ func (c *Cache) Close() error {
// on startup. It is used to set the ConnectRootCA and AgentLeafCert when // on startup. It is used to set the ConnectRootCA and AgentLeafCert when
// AutoEncrypt.TLS is turned on. The cache itself cannot fetch that the first // AutoEncrypt.TLS is turned on. The cache itself cannot fetch that the first
// time because it requires a special RPCType. Subsequent runs are fine though. // time because it requires a special RPCType. Subsequent runs are fine though.
func (c *Cache) Prepopulate(t string, res FetchResult, dc, token, k string) error { func (c *Cache) Prepopulate(t string, res FetchResult, dc, peerName, token, k string) error {
key := makeEntryKey(t, dc, token, k) key := makeEntryKey(t, dc, peerName, token, k)
newEntry := cacheEntry{ newEntry := cacheEntry{
Valid: true, Valid: true,
Value: res.Value, Value: res.Value,

View File

@ -1545,7 +1545,7 @@ func TestCacheReload(t *testing.T) {
c.entriesLock.Lock() c.entriesLock.Lock()
tEntry, ok := c.types["t1"] tEntry, ok := c.types["t1"]
require.True(t, ok) require.True(t, ok)
keyName := makeEntryKey("t1", "", "", "hello1") keyName := makeEntryKey("t1", "", "", "", "hello1")
ok, entryValid, entry := c.getEntryLocked(tEntry, keyName, RequestInfo{}) ok, entryValid, entry := c.getEntryLocked(tEntry, keyName, RequestInfo{})
require.True(t, ok) require.True(t, ok)
require.True(t, entryValid) require.True(t, entryValid)
@ -1687,7 +1687,7 @@ func TestCache_Prepopulate(t *testing.T) {
c := New(Options{}) c := New(Options{})
c.RegisterType("t", typ) c.RegisterType("t", typ)
c.Prepopulate("t", FetchResult{Value: 17, Index: 1}, "dc1", "token", "v1") c.Prepopulate("t", FetchResult{Value: 17, Index: 1}, "dc1", "", "token", "v1")
ctx := context.Background() ctx := context.Background()
req := fakeRequest{ req := fakeRequest{
@ -1740,7 +1740,7 @@ func TestCache_RefreshLifeCycle(t *testing.T) {
c := New(Options{}) c := New(Options{})
c.RegisterType("t", typ) c.RegisterType("t", typ)
key := makeEntryKey("t", "dc1", "token", "v1") key := makeEntryKey("t", "dc1", "", "token", "v1")
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()

View File

@ -16,6 +16,9 @@ type Request interface {
// RequestInfo represents cache information for a request. The caching // RequestInfo represents cache information for a request. The caching
// framework uses this to control the behavior of caching and to determine // framework uses this to control the behavior of caching and to determine
// cacheability. // cacheability.
//
// TODO(peering): finish ensuring everything that sets a Datacenter sets or doesn't set PeerName.
// TODO(peering): also make sure the peer name is present in the cache key likely in lieu of the datacenter somehow.
type RequestInfo struct { type RequestInfo struct {
// Key is a unique cache key for this request. This key should // Key is a unique cache key for this request. This key should
// be globally unique to identify this request, since any conflicting // be globally unique to identify this request, since any conflicting
@ -28,14 +31,17 @@ type RequestInfo struct {
// //
// Datacenter is the datacenter that the request is targeting. // Datacenter is the datacenter that the request is targeting.
// //
// Both of these values are used to partition the cache. The cache framework // PeerName is the peer that the request is targeting.
//
// All of these values are used to partition the cache. The cache framework
// today partitions data on these values to simplify behavior: by // today partitions data on these values to simplify behavior: by
// partitioning ACL tokens, the cache doesn't need to be smart about // partitioning ACL tokens, the cache doesn't need to be smart about
// filtering results. By filtering datacenter results, the cache can // filtering results. By filtering datacenter/peer results, the cache can
// service the multi-DC nature of Consul. This comes at the expense of // service the multi-DC/multi-peer nature of Consul. This comes at the expense of
// working set size, but in general the effect is minimal. // working set size, but in general the effect is minimal.
Token string Token string
Datacenter string Datacenter string
PeerName string
// MinIndex is the minimum index being queried. This is used to // MinIndex is the minimum index being queried. This is used to
// determine if we already have data satisfying the query or if we need // determine if we already have data satisfying the query or if we need

View File

@ -1174,7 +1174,21 @@ func (r *ACLResolver) ACLsEnabled() bool {
return true return true
} }
func (r *ACLResolver) ResolveTokenAndDefaultMeta(token string, entMeta *acl.EnterpriseMeta, authzContext *acl.AuthorizerContext) (ACLResolveResult, error) { // TODO(peering): fix all calls to use the new signature and rename it back
func (r *ACLResolver) ResolveTokenAndDefaultMeta(
token string,
entMeta *acl.EnterpriseMeta,
authzContext *acl.AuthorizerContext,
) (ACLResolveResult, error) {
return r.ResolveTokenAndDefaultMetaWithPeerName(token, entMeta, structs.DefaultPeerKeyword, authzContext)
}
func (r *ACLResolver) ResolveTokenAndDefaultMetaWithPeerName(
token string,
entMeta *acl.EnterpriseMeta,
peerName string,
authzContext *acl.AuthorizerContext,
) (ACLResolveResult, error) {
result, err := r.ResolveToken(token) result, err := r.ResolveToken(token)
if err != nil { if err != nil {
return ACLResolveResult{}, err return ACLResolveResult{}, err
@ -1186,9 +1200,19 @@ func (r *ACLResolver) ResolveTokenAndDefaultMeta(token string, entMeta *acl.Ente
// Default the EnterpriseMeta based on the Tokens meta or actual defaults // Default the EnterpriseMeta based on the Tokens meta or actual defaults
// in the case of unknown identity // in the case of unknown identity
if result.ACLIdentity != nil { switch {
case peerName == "" && result.ACLIdentity != nil:
entMeta.Merge(result.ACLIdentity.EnterpriseMetadata()) entMeta.Merge(result.ACLIdentity.EnterpriseMetadata())
} else { case result.ACLIdentity != nil:
// We _do not_ normalize the enterprise meta from the token when a peer
// name was specified because namespaces across clusters are not
// equivalent. A local namespace is _never_ correct for a remote query.
entMeta.Merge(
structs.DefaultEnterpriseMetaInPartition(
result.ACLIdentity.EnterpriseMetadata().PartitionOrDefault(),
),
)
default:
entMeta.Merge(structs.DefaultEnterpriseMetaInDefaultPartition()) entMeta.Merge(structs.DefaultEnterpriseMetaInDefaultPartition())
} }

View File

@ -11,12 +11,11 @@ import (
"testing" "testing"
"time" "time"
msgpackrpc "github.com/hashicorp/consul-net-rpc/net-rpc-msgpackrpc"
"github.com/hashicorp/memberlist" "github.com/hashicorp/memberlist"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
msgpackrpc "github.com/hashicorp/consul-net-rpc/net-rpc-msgpackrpc"
"github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest" "github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"

View File

@ -142,7 +142,7 @@ func (s *Server) autopilotServerFromMetadata(srv *metadata.Server) (*autopilot.S
// populate the node meta if there is any. When a node first joins or if // populate the node meta if there is any. When a node first joins or if
// there are ACL issues then this could be empty if the server has not // there are ACL issues then this could be empty if the server has not
// yet been able to register itself in the catalog // yet been able to register itself in the catalog
_, node, err := s.fsm.State().GetNodeID(types.NodeID(srv.ID), structs.NodeEnterpriseMetaInDefaultPartition()) _, node, err := s.fsm.State().GetNodeID(types.NodeID(srv.ID), structs.NodeEnterpriseMetaInDefaultPartition(), structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return nil, fmt.Errorf("error retrieving node from state store: %w", err) return nil, fmt.Errorf("error retrieving node from state store: %w", err)
} }

View File

@ -18,20 +18,20 @@ type MockStateStore struct {
mock.Mock mock.Mock
} }
// GetNodeID provides a mock function with given fields: _a0, _a1 // GetNodeID provides a mock function with given fields: _a0, _a1, _a2
func (_m *MockStateStore) GetNodeID(_a0 types.NodeID, _a1 *acl.EnterpriseMeta) (uint64, *structs.Node, error) { func (_m *MockStateStore) GetNodeID(_a0 types.NodeID, _a1 *acl.EnterpriseMeta, _a2 string) (uint64, *structs.Node, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1, _a2)
var r0 uint64 var r0 uint64
if rf, ok := ret.Get(0).(func(types.NodeID, *acl.EnterpriseMeta) uint64); ok { if rf, ok := ret.Get(0).(func(types.NodeID, *acl.EnterpriseMeta, string) uint64); ok {
r0 = rf(_a0, _a1) r0 = rf(_a0, _a1, _a2)
} else { } else {
r0 = ret.Get(0).(uint64) r0 = ret.Get(0).(uint64)
} }
var r1 *structs.Node var r1 *structs.Node
if rf, ok := ret.Get(1).(func(types.NodeID, *acl.EnterpriseMeta) *structs.Node); ok { if rf, ok := ret.Get(1).(func(types.NodeID, *acl.EnterpriseMeta, string) *structs.Node); ok {
r1 = rf(_a0, _a1) r1 = rf(_a0, _a1, _a2)
} else { } else {
if ret.Get(1) != nil { if ret.Get(1) != nil {
r1 = ret.Get(1).(*structs.Node) r1 = ret.Get(1).(*structs.Node)
@ -39,8 +39,8 @@ func (_m *MockStateStore) GetNodeID(_a0 types.NodeID, _a1 *acl.EnterpriseMeta) (
} }
var r2 error var r2 error
if rf, ok := ret.Get(2).(func(types.NodeID, *acl.EnterpriseMeta) error); ok { if rf, ok := ret.Get(2).(func(types.NodeID, *acl.EnterpriseMeta, string) error); ok {
r2 = rf(_a0, _a1) r2 = rf(_a0, _a1, _a2)
} else { } else {
r2 = ret.Error(2) r2 = ret.Error(2)
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbsubscribe"
"github.com/hashicorp/consul/types" "github.com/hashicorp/consul/types"
) )
@ -70,6 +71,12 @@ func (e EventPayloadReadyServers) HasReadPermission(authz acl.Authorizer) bool {
return authz.ServiceWriteAny(&authzContext) == acl.Allow return authz.ServiceWriteAny(&authzContext) == acl.Allow
} }
func (e EventPayloadReadyServers) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
// TODO(peering) is this right?
// TODO(agentless) is this right?
panic("EventPayloadReadyServers does not implement ToSubscriptionEvent")
}
func ExtractEventPayload(event stream.Event) (EventPayloadReadyServers, error) { func ExtractEventPayload(event stream.Event) (EventPayloadReadyServers, error) {
if event.Topic != EventTopicReadyServers { if event.Topic != EventTopicReadyServers {
return nil, fmt.Errorf("unexpected topic (%q) for a %q event", event.Topic, EventTopicReadyServers) return nil, fmt.Errorf("unexpected topic (%q) for a %q event", event.Topic, EventTopicReadyServers)
@ -114,7 +121,7 @@ func NewReadyServersEventPublisher(config Config) *ReadyServersEventPublisher {
//go:generate mockery --name StateStore --inpackage --testonly //go:generate mockery --name StateStore --inpackage --testonly
type StateStore interface { type StateStore interface {
GetNodeID(types.NodeID, *acl.EnterpriseMeta) (uint64, *structs.Node, error) GetNodeID(types.NodeID, *acl.EnterpriseMeta, string) (uint64, *structs.Node, error)
} }
//go:generate mockery --name Publisher --inpackage --testonly //go:generate mockery --name Publisher --inpackage --testonly
@ -245,7 +252,7 @@ func (r *ReadyServersEventPublisher) getTaggedAddresses(srv *autopilot.ServerSta
// from the catalog at that often and publish the events. So while its not quite // from the catalog at that often and publish the events. So while its not quite
// as responsive as actually watching for the Catalog changes, its MUCH simpler to // as responsive as actually watching for the Catalog changes, its MUCH simpler to
// code and reason about and having those addresses be updated within 30s is good enough. // code and reason about and having those addresses be updated within 30s is good enough.
_, node, err := r.GetStore().GetNodeID(types.NodeID(srv.Server.ID), structs.NodeEnterpriseMetaInDefaultPartition()) _, node, err := r.GetStore().GetNodeID(types.NodeID(srv.Server.ID), structs.NodeEnterpriseMetaInDefaultPartition(), structs.DefaultPeerKeyword)
if err != nil || node == nil { if err != nil || node == nil {
// no catalog information means we should return a nil addres map // no catalog information means we should return a nil addres map
return nil return nil

View File

@ -4,14 +4,16 @@ import (
"testing" "testing"
time "time" time "time"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream"
structs "github.com/hashicorp/consul/agent/structs"
types "github.com/hashicorp/consul/types"
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
autopilot "github.com/hashicorp/raft-autopilot" autopilot "github.com/hashicorp/raft-autopilot"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream"
structs "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbsubscribe"
types "github.com/hashicorp/consul/types"
) )
var testTime = time.Date(2022, 4, 14, 10, 56, 00, 0, time.UTC) var testTime = time.Date(2022, 4, 14, 10, 56, 00, 0, time.UTC)
@ -161,6 +163,7 @@ func TestAutopilotStateToReadyServersWithTaggedAddresses(t *testing.T) {
store.On("GetNodeID", store.On("GetNodeID",
types.NodeID("792ae13c-d765-470b-852c-e073fdb6e849"), types.NodeID("792ae13c-d765-470b-852c-e073fdb6e849"),
structs.NodeEnterpriseMetaInDefaultPartition(), structs.NodeEnterpriseMetaInDefaultPartition(),
structs.DefaultPeerKeyword,
).Once().Return( ).Once().Return(
uint64(0), uint64(0),
&structs.Node{TaggedAddresses: map[string]string{"wan": "5.4.3.2"}}, &structs.Node{TaggedAddresses: map[string]string{"wan": "5.4.3.2"}},
@ -170,6 +173,7 @@ func TestAutopilotStateToReadyServersWithTaggedAddresses(t *testing.T) {
store.On("GetNodeID", store.On("GetNodeID",
types.NodeID("65e79ff4-bbce-467b-a9d6-725c709fa985"), types.NodeID("65e79ff4-bbce-467b-a9d6-725c709fa985"),
structs.NodeEnterpriseMetaInDefaultPartition(), structs.NodeEnterpriseMetaInDefaultPartition(),
structs.DefaultPeerKeyword,
).Once().Return( ).Once().Return(
uint64(0), uint64(0),
&structs.Node{TaggedAddresses: map[string]string{"wan": "1.2.3.4"}}, &structs.Node{TaggedAddresses: map[string]string{"wan": "1.2.3.4"}},
@ -179,6 +183,7 @@ func TestAutopilotStateToReadyServersWithTaggedAddresses(t *testing.T) {
store.On("GetNodeID", store.On("GetNodeID",
types.NodeID("db11f0ac-0cbe-4215-80cc-b4e843f4df1e"), types.NodeID("db11f0ac-0cbe-4215-80cc-b4e843f4df1e"),
structs.NodeEnterpriseMetaInDefaultPartition(), structs.NodeEnterpriseMetaInDefaultPartition(),
structs.DefaultPeerKeyword,
).Once().Return( ).Once().Return(
uint64(0), uint64(0),
&structs.Node{TaggedAddresses: map[string]string{"wan": "9.8.7.6"}}, &structs.Node{TaggedAddresses: map[string]string{"wan": "9.8.7.6"}},
@ -487,6 +492,7 @@ func TestReadyServerEventsSnapshotHandler(t *testing.T) {
store.On("GetNodeID", store.On("GetNodeID",
types.NodeID("792ae13c-d765-470b-852c-e073fdb6e849"), types.NodeID("792ae13c-d765-470b-852c-e073fdb6e849"),
structs.NodeEnterpriseMetaInDefaultPartition(), structs.NodeEnterpriseMetaInDefaultPartition(),
structs.DefaultPeerKeyword,
).Once().Return( ).Once().Return(
uint64(0), uint64(0),
&structs.Node{TaggedAddresses: map[string]string{"wan": "5.4.3.2"}}, &structs.Node{TaggedAddresses: map[string]string{"wan": "5.4.3.2"}},
@ -496,6 +502,7 @@ func TestReadyServerEventsSnapshotHandler(t *testing.T) {
store.On("GetNodeID", store.On("GetNodeID",
types.NodeID("65e79ff4-bbce-467b-a9d6-725c709fa985"), types.NodeID("65e79ff4-bbce-467b-a9d6-725c709fa985"),
structs.NodeEnterpriseMetaInDefaultPartition(), structs.NodeEnterpriseMetaInDefaultPartition(),
structs.DefaultPeerKeyword,
).Once().Return( ).Once().Return(
uint64(0), uint64(0),
&structs.Node{TaggedAddresses: map[string]string{"wan": "1.2.3.4"}}, &structs.Node{TaggedAddresses: map[string]string{"wan": "1.2.3.4"}},
@ -505,6 +512,7 @@ func TestReadyServerEventsSnapshotHandler(t *testing.T) {
store.On("GetNodeID", store.On("GetNodeID",
types.NodeID("db11f0ac-0cbe-4215-80cc-b4e843f4df1e"), types.NodeID("db11f0ac-0cbe-4215-80cc-b4e843f4df1e"),
structs.NodeEnterpriseMetaInDefaultPartition(), structs.NodeEnterpriseMetaInDefaultPartition(),
structs.DefaultPeerKeyword,
).Once().Return( ).Once().Return(
uint64(0), uint64(0),
&structs.Node{TaggedAddresses: map[string]string{"wan": "9.8.7.6"}}, &structs.Node{TaggedAddresses: map[string]string{"wan": "9.8.7.6"}},
@ -547,6 +555,10 @@ func (e fakePayload) HasReadPermission(authz acl.Authorizer) bool {
return false return false
} }
func (e fakePayload) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
panic("fakePayload does not implement ToSubscriptionEvent")
}
func TestExtractEventPayload(t *testing.T) { func TestExtractEventPayload(t *testing.T) {
t.Run("wrong-topic", func(t *testing.T) { t.Run("wrong-topic", func(t *testing.T) {
payload, err := ExtractEventPayload(stream.NewCloseSubscriptionEvent([]string{"foo"})) payload, err := ExtractEventPayload(stream.NewCloseSubscriptionEvent([]string{"foo"}))

View File

@ -133,7 +133,7 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
} }
// Check the complete register request against the given ACL policy. // Check the complete register request against the given ACL policy.
_, ns, err := state.NodeServices(nil, args.Node, entMeta) _, ns, err := state.NodeServices(nil, args.Node, entMeta, args.PeerName)
if err != nil { if err != nil {
return fmt.Errorf("Node lookup failed: %v", err) return fmt.Errorf("Node lookup failed: %v", err)
} }
@ -367,7 +367,7 @@ func (c *Catalog) Deregister(args *structs.DeregisterRequest, reply *struct{}) e
var ns *structs.NodeService var ns *structs.NodeService
if args.ServiceID != "" { if args.ServiceID != "" {
_, ns, err = state.NodeService(args.Node, args.ServiceID, &args.EnterpriseMeta) _, ns, err = state.NodeService(args.Node, args.ServiceID, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return fmt.Errorf("Service lookup failed: %v", err) return fmt.Errorf("Service lookup failed: %v", err)
} }
@ -375,7 +375,7 @@ func (c *Catalog) Deregister(args *structs.DeregisterRequest, reply *struct{}) e
var nc *structs.HealthCheck var nc *structs.HealthCheck
if args.CheckID != "" { if args.CheckID != "" {
_, nc, err = state.NodeCheck(args.Node, args.CheckID, &args.EnterpriseMeta) _, nc, err = state.NodeCheck(args.Node, args.CheckID, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return fmt.Errorf("Check lookup failed: %v", err) return fmt.Errorf("Check lookup failed: %v", err)
} }
@ -486,9 +486,9 @@ func (c *Catalog) ListNodes(args *structs.DCSpecificRequest, reply *structs.Inde
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
var err error var err error
if len(args.NodeMetaFilters) > 0 { if len(args.NodeMetaFilters) > 0 {
reply.Index, reply.Nodes, err = state.NodesByMeta(ws, args.NodeMetaFilters, &args.EnterpriseMeta) reply.Index, reply.Nodes, err = state.NodesByMeta(ws, args.NodeMetaFilters, &args.EnterpriseMeta, args.PeerName)
} else { } else {
reply.Index, reply.Nodes, err = state.Nodes(ws, &args.EnterpriseMeta) reply.Index, reply.Nodes, err = state.Nodes(ws, &args.EnterpriseMeta, args.PeerName)
} }
if err != nil { if err != nil {
return err return err
@ -546,9 +546,9 @@ func (c *Catalog) ListServices(args *structs.DCSpecificRequest, reply *structs.I
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
var err error var err error
if len(args.NodeMetaFilters) > 0 { if len(args.NodeMetaFilters) > 0 {
reply.Index, reply.Services, err = state.ServicesByNodeMeta(ws, args.NodeMetaFilters, &args.EnterpriseMeta) reply.Index, reply.Services, err = state.ServicesByNodeMeta(ws, args.NodeMetaFilters, &args.EnterpriseMeta, args.PeerName)
} else { } else {
reply.Index, reply.Services, err = state.Services(ws, &args.EnterpriseMeta) reply.Index, reply.Services, err = state.Services(ws, &args.EnterpriseMeta, args.PeerName)
} }
if err != nil { if err != nil {
return err return err
@ -584,7 +584,7 @@ func (c *Catalog) ServiceList(args *structs.DCSpecificRequest, reply *structs.In
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
index, services, err := state.ServiceList(ws, &args.EnterpriseMeta) index, services, err := state.ServiceList(ws, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return err return err
} }
@ -611,13 +611,13 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
switch { switch {
case args.Connect: case args.Connect:
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) { f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
return s.ConnectServiceNodes(ws, args.ServiceName, &args.EnterpriseMeta) return s.ConnectServiceNodes(ws, args.ServiceName, &args.EnterpriseMeta, args.PeerName)
} }
default: default:
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) { f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
if args.ServiceAddress != "" { if args.ServiceAddress != "" {
return s.ServiceAddressNodes(ws, args.ServiceAddress, &args.EnterpriseMeta) return s.ServiceAddressNodes(ws, args.ServiceAddress, &args.EnterpriseMeta, args.PeerName)
} }
if args.TagFilter { if args.TagFilter {
@ -630,10 +630,10 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
tags = []string{args.ServiceTag} tags = []string{args.ServiceTag}
} }
return s.ServiceTagNodes(ws, args.ServiceName, tags, &args.EnterpriseMeta) return s.ServiceTagNodes(ws, args.ServiceName, tags, &args.EnterpriseMeta, args.PeerName)
} }
return s.ServiceNodes(ws, args.ServiceName, &args.EnterpriseMeta) return s.ServiceNodes(ws, args.ServiceName, &args.EnterpriseMeta, args.PeerName)
} }
} }
@ -768,7 +768,7 @@ func (c *Catalog) NodeServices(args *structs.NodeSpecificRequest, reply *structs
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
index, services, err := state.NodeServices(ws, args.Node, &args.EnterpriseMeta) index, services, err := state.NodeServices(ws, args.Node, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return err return err
} }
@ -824,7 +824,7 @@ func (c *Catalog) NodeServiceList(args *structs.NodeSpecificRequest, reply *stru
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
index, services, err := state.NodeServiceList(ws, args.Node, &args.EnterpriseMeta) index, services, err := state.NodeServiceList(ws, args.Node, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return err return err
} }

View File

@ -510,7 +510,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{ logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
Name: c.NodeName, Name: c.NodeName,
Level: testutil.TestLogLevel, Level: hclog.Trace,
Output: testutil.NewLogBuffer(t), Output: testutil.NewLogBuffer(t),
}) })

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/pbpeering"
) )
var CommandsSummaries = []prometheus.SummaryDefinition{ var CommandsSummaries = []prometheus.SummaryDefinition{
@ -93,6 +94,10 @@ var CommandsSummaries = []prometheus.SummaryDefinition{
Name: []string{"fsm", "system_metadata"}, Name: []string{"fsm", "system_metadata"},
Help: "Measures the time it takes to apply a system metadata operation to the FSM.", Help: "Measures the time it takes to apply a system metadata operation to the FSM.",
}, },
{
Name: []string{"fsm", "peering"},
Help: "Measures the time it takes to apply a peering operation to the FSM.",
},
// TODO(kit): We generate the config-entry fsm summaries by reading off of the request. It is // TODO(kit): We generate the config-entry fsm summaries by reading off of the request. It is
// possible to statically declare these when we know all of the names, but I didn't get to it // possible to statically declare these when we know all of the names, but I didn't get to it
// in this patch. Config-entries are known though and we should add these in the future. // in this patch. Config-entries are known though and we should add these in the future.
@ -131,6 +136,11 @@ func init() {
registerCommand(structs.ACLAuthMethodDeleteRequestType, (*FSM).applyACLAuthMethodDeleteOperation) registerCommand(structs.ACLAuthMethodDeleteRequestType, (*FSM).applyACLAuthMethodDeleteOperation)
registerCommand(structs.FederationStateRequestType, (*FSM).applyFederationStateOperation) registerCommand(structs.FederationStateRequestType, (*FSM).applyFederationStateOperation)
registerCommand(structs.SystemMetadataRequestType, (*FSM).applySystemMetadataOperation) registerCommand(structs.SystemMetadataRequestType, (*FSM).applySystemMetadataOperation)
registerCommand(structs.PeeringWriteType, (*FSM).applyPeeringWrite)
registerCommand(structs.PeeringDeleteType, (*FSM).applyPeeringDelete)
registerCommand(structs.PeeringTerminateByIDType, (*FSM).applyPeeringTerminate)
registerCommand(structs.PeeringTrustBundleWriteType, (*FSM).applyPeeringTrustBundleWrite)
registerCommand(structs.PeeringTrustBundleDeleteType, (*FSM).applyPeeringTrustBundleDelete)
} }
func (c *FSM) applyRegister(buf []byte, index uint64) interface{} { func (c *FSM) applyRegister(buf []byte, index uint64) interface{} {
@ -159,17 +169,17 @@ func (c *FSM) applyDeregister(buf []byte, index uint64) interface{} {
// here is also baked into vetDeregisterWithACL() in acl.go, so if you // here is also baked into vetDeregisterWithACL() in acl.go, so if you
// make changes here, be sure to also adjust the code over there. // make changes here, be sure to also adjust the code over there.
if req.ServiceID != "" { if req.ServiceID != "" {
if err := c.state.DeleteService(index, req.Node, req.ServiceID, &req.EnterpriseMeta); err != nil { if err := c.state.DeleteService(index, req.Node, req.ServiceID, &req.EnterpriseMeta, req.PeerName); err != nil {
c.logger.Warn("DeleteNodeService failed", "error", err) c.logger.Warn("DeleteNodeService failed", "error", err)
return err return err
} }
} else if req.CheckID != "" { } else if req.CheckID != "" {
if err := c.state.DeleteCheck(index, req.Node, req.CheckID, &req.EnterpriseMeta); err != nil { if err := c.state.DeleteCheck(index, req.Node, req.CheckID, &req.EnterpriseMeta, req.PeerName); err != nil {
c.logger.Warn("DeleteNodeCheck failed", "error", err) c.logger.Warn("DeleteNodeCheck failed", "error", err)
return err return err
} }
} else { } else {
if err := c.state.DeleteNode(index, req.Node, &req.EnterpriseMeta); err != nil { if err := c.state.DeleteNode(index, req.Node, &req.EnterpriseMeta, req.PeerName); err != nil {
c.logger.Warn("DeleteNode failed", "error", err) c.logger.Warn("DeleteNode failed", "error", err)
return err return err
} }
@ -679,3 +689,73 @@ func (c *FSM) applySystemMetadataOperation(buf []byte, index uint64) interface{}
return fmt.Errorf("invalid system metadata operation type: %v", req.Op) return fmt.Errorf("invalid system metadata operation type: %v", req.Op)
} }
} }
func (c *FSM) applyPeeringWrite(buf []byte, index uint64) interface{} {
var req pbpeering.PeeringWriteRequest
if err := structs.DecodeProto(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode peering write request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"fsm", "peering"}, time.Now(),
[]metrics.Label{{Name: "op", Value: "write"}})
return c.state.PeeringWrite(index, req.Peering)
}
// TODO(peering): replace with deferred deletion since this operation
// should involve cleanup of data associated with the peering.
func (c *FSM) applyPeeringDelete(buf []byte, index uint64) interface{} {
var req pbpeering.PeeringDeleteRequest
if err := structs.DecodeProto(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode peering delete request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"fsm", "peering"}, time.Now(),
[]metrics.Label{{Name: "op", Value: "delete"}})
q := state.Query{
Value: req.Name,
EnterpriseMeta: *structs.NodeEnterpriseMetaInPartition(req.Partition),
}
return c.state.PeeringDelete(index, q)
}
func (c *FSM) applyPeeringTerminate(buf []byte, index uint64) interface{} {
var req pbpeering.PeeringTerminateByIDRequest
if err := structs.DecodeProto(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode peering delete request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"fsm", "peering"}, time.Now(),
[]metrics.Label{{Name: "op", Value: "terminate"}})
return c.state.PeeringTerminateByID(index, req.ID)
}
func (c *FSM) applyPeeringTrustBundleWrite(buf []byte, index uint64) interface{} {
var req pbpeering.PeeringTrustBundleWriteRequest
if err := structs.DecodeProto(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode peering trust bundle write request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"fsm", "peering_trust_bundle"}, time.Now(),
[]metrics.Label{{Name: "op", Value: "write"}})
return c.state.PeeringTrustBundleWrite(index, req.PeeringTrustBundle)
}
func (c *FSM) applyPeeringTrustBundleDelete(buf []byte, index uint64) interface{} {
var req pbpeering.PeeringTrustBundleDeleteRequest
if err := structs.DecodeProto(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode peering trust bundle delete request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"fsm", "peering_trust_bundle"}, time.Now(),
[]metrics.Label{{Name: "op", Value: "delete"}})
q := state.Query{
Value: req.Name,
EnterpriseMeta: *structs.NodeEnterpriseMetaInPartition(req.Partition),
}
return c.state.PeeringTrustBundleDelete(index, q)
}

View File

@ -69,7 +69,7 @@ func TestFSM_RegisterNode(t *testing.T) {
} }
// Verify we are registered // Verify we are registered
_, node, err := fsm.state.GetNode("foo", nil) _, node, err := fsm.state.GetNode("foo", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -81,7 +81,7 @@ func TestFSM_RegisterNode(t *testing.T) {
} }
// Verify service registered // Verify service registered
_, services, err := fsm.state.NodeServices(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition()) _, services, err := fsm.state.NodeServices(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition(), "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -128,7 +128,7 @@ func TestFSM_RegisterNode_Service(t *testing.T) {
} }
// Verify we are registered // Verify we are registered
_, node, err := fsm.state.GetNode("foo", nil) _, node, err := fsm.state.GetNode("foo", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -137,7 +137,7 @@ func TestFSM_RegisterNode_Service(t *testing.T) {
} }
// Verify service registered // Verify service registered
_, services, err := fsm.state.NodeServices(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition()) _, services, err := fsm.state.NodeServices(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition(), "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -146,7 +146,7 @@ func TestFSM_RegisterNode_Service(t *testing.T) {
} }
// Verify check // Verify check
_, checks, err := fsm.state.NodeChecks(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition()) _, checks, err := fsm.state.NodeChecks(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition(), "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -200,7 +200,7 @@ func TestFSM_DeregisterService(t *testing.T) {
} }
// Verify we are registered // Verify we are registered
_, node, err := fsm.state.GetNode("foo", nil) _, node, err := fsm.state.GetNode("foo", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -209,7 +209,7 @@ func TestFSM_DeregisterService(t *testing.T) {
} }
// Verify service not registered // Verify service not registered
_, services, err := fsm.state.NodeServices(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition()) _, services, err := fsm.state.NodeServices(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition(), "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -263,7 +263,7 @@ func TestFSM_DeregisterCheck(t *testing.T) {
} }
// Verify we are registered // Verify we are registered
_, node, err := fsm.state.GetNode("foo", nil) _, node, err := fsm.state.GetNode("foo", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -272,7 +272,7 @@ func TestFSM_DeregisterCheck(t *testing.T) {
} }
// Verify check not registered // Verify check not registered
_, checks, err := fsm.state.NodeChecks(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition()) _, checks, err := fsm.state.NodeChecks(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition(), "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -332,7 +332,7 @@ func TestFSM_DeregisterNode(t *testing.T) {
} }
// Verify we are not registered // Verify we are not registered
_, node, err := fsm.state.GetNode("foo", nil) _, node, err := fsm.state.GetNode("foo", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -341,7 +341,7 @@ func TestFSM_DeregisterNode(t *testing.T) {
} }
// Verify service not registered // Verify service not registered
_, services, err := fsm.state.NodeServices(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition()) _, services, err := fsm.state.NodeServices(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition(), "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -350,7 +350,7 @@ func TestFSM_DeregisterNode(t *testing.T) {
} }
// Verify checks not registered // Verify checks not registered
_, checks, err := fsm.state.NodeChecks(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition()) _, checks, err := fsm.state.NodeChecks(nil, "foo", structs.DefaultEnterpriseMetaInDefaultPartition(), "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -1468,7 +1468,7 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
// Verify we are not registered // Verify we are not registered
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, node, err := fsm.state.GetNode(fmt.Sprintf("foo%d", i), nil) _, node, err := fsm.state.GetNode(fmt.Sprintf("foo%d", i), nil, "")
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, node) assert.Nil(t, node)
} }
@ -1491,7 +1491,7 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
// Verify we are still not registered // Verify we are still not registered
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil) _, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil, "")
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, node) assert.Nil(t, node)
} }
@ -1515,19 +1515,19 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
// Verify we are registered // Verify we are registered
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil) _, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil, "")
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, node) assert.NotNil(t, node)
// Verify service registered // Verify service registered
_, services, err := fsm2.state.NodeServices(nil, fmt.Sprintf("foo%d", i), structs.DefaultEnterpriseMetaInDefaultPartition()) _, services, err := fsm2.state.NodeServices(nil, fmt.Sprintf("foo%d", i), structs.DefaultEnterpriseMetaInDefaultPartition(), "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, services) require.NotNil(t, services)
_, ok := services.Services["db"] _, ok := services.Services["db"]
assert.True(t, ok) assert.True(t, ok)
// Verify check // Verify check
_, checks, err := fsm2.state.NodeChecks(nil, fmt.Sprintf("foo%d", i), nil) _, checks, err := fsm2.state.NodeChecks(nil, fmt.Sprintf("foo%d", i), nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, checks) require.NotNil(t, checks)
assert.Equal(t, string(checks[0].CheckID), "db") assert.Equal(t, string(checks[0].CheckID), "db")

View File

@ -6,6 +6,7 @@ import (
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
) )
func init() { func init() {
@ -35,6 +36,8 @@ func init() {
registerRestorer(structs.SystemMetadataRequestType, restoreSystemMetadata) registerRestorer(structs.SystemMetadataRequestType, restoreSystemMetadata)
registerRestorer(structs.ServiceVirtualIPRequestType, restoreServiceVirtualIP) registerRestorer(structs.ServiceVirtualIPRequestType, restoreServiceVirtualIP)
registerRestorer(structs.FreeVirtualIPRequestType, restoreFreeVirtualIP) registerRestorer(structs.FreeVirtualIPRequestType, restoreFreeVirtualIP)
registerRestorer(structs.PeeringWriteType, restorePeering)
registerRestorer(structs.PeeringTrustBundleWriteType, restorePeeringTrustBundle)
} }
func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error { func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error {
@ -86,6 +89,12 @@ func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) err
if err := s.persistIndex(sink, encoder); err != nil { if err := s.persistIndex(sink, encoder); err != nil {
return err return err
} }
if err := s.persistPeerings(sink, encoder); err != nil {
return err
}
if err := s.persistPeeringTrustBundles(sink, encoder); err != nil {
return err
}
return nil return nil
} }
@ -112,6 +121,7 @@ func (s *snapshot) persistNodes(sink raft.SnapshotSink,
NodeMeta: n.Meta, NodeMeta: n.Meta,
RaftIndex: n.RaftIndex, RaftIndex: n.RaftIndex,
EnterpriseMeta: *nodeEntMeta, EnterpriseMeta: *nodeEntMeta,
PeerName: n.PeerName,
} }
// Register the node itself // Register the node itself
@ -123,7 +133,7 @@ func (s *snapshot) persistNodes(sink raft.SnapshotSink,
} }
// Register each service this node has // Register each service this node has
services, err := s.state.Services(n.Node, nodeEntMeta) services, err := s.state.Services(n.Node, nodeEntMeta, n.PeerName)
if err != nil { if err != nil {
return err return err
} }
@ -139,7 +149,7 @@ func (s *snapshot) persistNodes(sink raft.SnapshotSink,
// Register each check this node has // Register each check this node has
req.Service = nil req.Service = nil
checks, err := s.state.Checks(n.Node, nodeEntMeta) checks, err := s.state.Checks(n.Node, nodeEntMeta, n.PeerName)
if err != nil { if err != nil {
return err return err
} }
@ -161,7 +171,6 @@ func (s *snapshot) persistNodes(sink raft.SnapshotSink,
if err != nil { if err != nil {
return err return err
} }
// TODO(partitions)
for coord := coords.Next(); coord != nil; coord = coords.Next() { for coord := coords.Next(); coord != nil; coord = coords.Next() {
if _, err := sink.Write([]byte{byte(structs.CoordinateBatchUpdateType)}); err != nil { if _, err := sink.Write([]byte{byte(structs.CoordinateBatchUpdateType)}); err != nil {
return err return err
@ -547,6 +556,42 @@ func (s *snapshot) persistVirtualIPs(sink raft.SnapshotSink, encoder *codec.Enco
return nil return nil
} }
func (s *snapshot) persistPeerings(sink raft.SnapshotSink, encoder *codec.Encoder) error {
peerings, err := s.state.Peerings()
if err != nil {
return err
}
for entry := peerings.Next(); entry != nil; entry = peerings.Next() {
if _, err := sink.Write([]byte{byte(structs.PeeringWriteType)}); err != nil {
return err
}
if err := encoder.Encode(entry.(*pbpeering.Peering)); err != nil {
return err
}
}
return nil
}
func (s *snapshot) persistPeeringTrustBundles(sink raft.SnapshotSink, encoder *codec.Encoder) error {
ptbs, err := s.state.PeeringTrustBundles()
if err != nil {
return err
}
for entry := ptbs.Next(); entry != nil; entry = ptbs.Next() {
if _, err := sink.Write([]byte{byte(structs.PeeringTrustBundleWriteType)}); err != nil {
return err
}
if err := encoder.Encode(entry.(*pbpeering.PeeringTrustBundle)); err != nil {
return err
}
}
return nil
}
func restoreRegistration(header *SnapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { func restoreRegistration(header *SnapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req structs.RegisterRequest var req structs.RegisterRequest
if err := decoder.Decode(&req); err != nil { if err := decoder.Decode(&req); err != nil {
@ -849,3 +894,25 @@ func restoreFreeVirtualIP(header *SnapshotHeader, restore *state.Restore, decode
} }
return nil return nil
} }
func restorePeering(header *SnapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req pbpeering.Peering
if err := decoder.Decode(&req); err != nil {
return err
}
if err := restore.Peering(&req); err != nil {
return err
}
return nil
}
func restorePeeringTrustBundle(header *SnapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req pbpeering.PeeringTrustBundle
if err := decoder.Decode(&req); err != nil {
return err
}
if err := restore.PeeringTrustBundle(&req); err != nil {
return err
}
return nil
}

View File

@ -17,6 +17,7 @@ import (
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib/stringslice" "github.com/hashicorp/consul/lib/stringslice"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
) )
@ -473,6 +474,18 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
require.Equal(t, expect[i], sn.Service.Name) require.Equal(t, expect[i], sn.Service.Name)
} }
// Peerings
require.NoError(t, fsm.state.PeeringWrite(31, &pbpeering.Peering{
Name: "baz",
}))
// Peering Trust Bundles
require.NoError(t, fsm.state.PeeringTrustBundleWrite(32, &pbpeering.PeeringTrustBundle{
TrustDomain: "qux.com",
PeerName: "qux",
RootPEMs: []string{"qux certificate bundle"},
}))
// Snapshot // Snapshot
snap, err := fsm.Snapshot() snap, err := fsm.Snapshot()
require.NoError(t, err) require.NoError(t, err)
@ -528,7 +541,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
require.NoError(t, fsm2.Restore(sink)) require.NoError(t, fsm2.Restore(sink))
// Verify the contents // Verify the contents
_, nodes, err := fsm2.state.Nodes(nil, nil) _, nodes, err := fsm2.state.Nodes(nil, nil, "")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, nodes, 2, "incorect number of nodes: %v", nodes) require.Len(t, nodes, 2, "incorect number of nodes: %v", nodes)
@ -556,7 +569,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
require.Equal(t, uint64(1), nodes[1].CreateIndex) require.Equal(t, uint64(1), nodes[1].CreateIndex)
require.Equal(t, uint64(23), nodes[1].ModifyIndex) require.Equal(t, uint64(23), nodes[1].ModifyIndex)
_, fooSrv, err := fsm2.state.NodeServices(nil, "foo", nil) _, fooSrv, err := fsm2.state.NodeServices(nil, "foo", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, fooSrv.Services, 4) require.Len(t, fooSrv.Services, 4)
require.Contains(t, fooSrv.Services["db"].Tags, "primary") require.Contains(t, fooSrv.Services["db"].Tags, "primary")
@ -569,7 +582,7 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
require.Equal(t, uint64(3), fooSrv.Services["web"].CreateIndex) require.Equal(t, uint64(3), fooSrv.Services["web"].CreateIndex)
require.Equal(t, uint64(3), fooSrv.Services["web"].ModifyIndex) require.Equal(t, uint64(3), fooSrv.Services["web"].ModifyIndex)
_, checks, err := fsm2.state.NodeChecks(nil, "foo", nil) _, checks, err := fsm2.state.NodeChecks(nil, "foo", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, checks, 1) require.Len(t, checks, 1)
require.Equal(t, "foo", checks[0].Node) require.Equal(t, "foo", checks[0].Node)
@ -768,6 +781,27 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
require.Equal(t, expect[i], sn.Service.Name) require.Equal(t, expect[i], sn.Service.Name)
} }
// Verify peering is restored
idx, prngRestored, err := fsm2.state.PeeringRead(nil, state.Query{
Value: "baz",
})
require.NoError(t, err)
require.Equal(t, uint64(31), idx)
require.NotNil(t, prngRestored)
require.Equal(t, "baz", prngRestored.Name)
// Verify peering trust bundle is restored
idx, ptbRestored, err := fsm2.state.PeeringTrustBundleRead(nil, state.Query{
Value: "qux",
})
require.NoError(t, err)
require.Equal(t, uint64(32), idx)
require.NotNil(t, ptbRestored)
require.Equal(t, "qux.com", ptbRestored.TrustDomain)
require.Equal(t, "qux", ptbRestored.PeerName)
require.Len(t, ptbRestored.RootPEMs, 1)
require.Equal(t, "qux certificate bundle", ptbRestored.RootPEMs[0])
// Snapshot // Snapshot
snap, err = fsm2.Snapshot() snap, err = fsm2.Snapshot()
require.NoError(t, err) require.NoError(t, err)
@ -821,7 +855,7 @@ func TestFSM_BadRestore_OSS(t *testing.T) {
require.Error(t, fsm.Restore(sink)) require.Error(t, fsm.Restore(sink))
// Verify the contents didn't get corrupted. // Verify the contents didn't get corrupted.
_, nodes, err := fsm.state.Nodes(nil, nil) _, nodes, err := fsm.state.Nodes(nil, nil, "")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, nodes, 1) require.Len(t, nodes, 1)
require.Equal(t, "foo", nodes[0].Node) require.Equal(t, "foo", nodes[0].Node)

View File

@ -47,9 +47,9 @@ func (h *Health) ChecksInState(args *structs.ChecksInStateRequest,
var checks structs.HealthChecks var checks structs.HealthChecks
var err error var err error
if len(args.NodeMetaFilters) > 0 { if len(args.NodeMetaFilters) > 0 {
index, checks, err = state.ChecksInStateByNodeMeta(ws, args.State, args.NodeMetaFilters, &args.EnterpriseMeta) index, checks, err = state.ChecksInStateByNodeMeta(ws, args.State, args.NodeMetaFilters, &args.EnterpriseMeta, args.PeerName)
} else { } else {
index, checks, err = state.ChecksInState(ws, args.State, &args.EnterpriseMeta) index, checks, err = state.ChecksInState(ws, args.State, &args.EnterpriseMeta, args.PeerName)
} }
if err != nil { if err != nil {
return err return err
@ -98,7 +98,7 @@ func (h *Health) NodeChecks(args *structs.NodeSpecificRequest,
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
index, checks, err := state.NodeChecks(ws, args.Node, &args.EnterpriseMeta) index, checks, err := state.NodeChecks(ws, args.Node, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return err return err
} }
@ -157,9 +157,9 @@ func (h *Health) ServiceChecks(args *structs.ServiceSpecificRequest,
var checks structs.HealthChecks var checks structs.HealthChecks
var err error var err error
if len(args.NodeMetaFilters) > 0 { if len(args.NodeMetaFilters) > 0 {
index, checks, err = state.ServiceChecksByNodeMeta(ws, args.ServiceName, args.NodeMetaFilters, &args.EnterpriseMeta) index, checks, err = state.ServiceChecksByNodeMeta(ws, args.ServiceName, args.NodeMetaFilters, &args.EnterpriseMeta, args.PeerName)
} else { } else {
index, checks, err = state.ServiceChecks(ws, args.ServiceName, &args.EnterpriseMeta) index, checks, err = state.ServiceChecks(ws, args.ServiceName, &args.EnterpriseMeta, args.PeerName)
} }
if err != nil { if err != nil {
return err return err
@ -304,7 +304,7 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
// can be used by the ServiceNodes endpoint. // can be used by the ServiceNodes endpoint.
func (h *Health) serviceNodesConnect(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) { func (h *Health) serviceNodesConnect(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
return s.CheckConnectServiceNodes(ws, args.ServiceName, &args.EnterpriseMeta) return s.CheckConnectServiceNodes(ws, args.ServiceName, &args.EnterpriseMeta, args.PeerName)
} }
func (h *Health) serviceNodesIngress(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) { func (h *Health) serviceNodesIngress(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
@ -317,11 +317,11 @@ func (h *Health) serviceNodesTagFilter(ws memdb.WatchSet, s *state.Store, args *
// Agents < v1.3.0 populate the ServiceTag field. In this case, // Agents < v1.3.0 populate the ServiceTag field. In this case,
// use ServiceTag instead of the ServiceTags field. // use ServiceTag instead of the ServiceTags field.
if args.ServiceTag != "" { if args.ServiceTag != "" {
return s.CheckServiceTagNodes(ws, args.ServiceName, []string{args.ServiceTag}, &args.EnterpriseMeta) return s.CheckServiceTagNodes(ws, args.ServiceName, []string{args.ServiceTag}, &args.EnterpriseMeta, args.PeerName)
} }
return s.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTags, &args.EnterpriseMeta) return s.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTags, &args.EnterpriseMeta, args.PeerName)
} }
func (h *Health) serviceNodesDefault(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) { func (h *Health) serviceNodesDefault(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
return s.CheckServiceNodes(ws, args.ServiceName, &args.EnterpriseMeta) return s.CheckServiceNodes(ws, args.ServiceName, &args.EnterpriseMeta, args.PeerName)
} }

View File

@ -13,7 +13,6 @@ import (
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/lib/stringslice"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/types" "github.com/hashicorp/consul/types"
@ -558,124 +557,109 @@ func TestHealth_ServiceNodes(t *testing.T) {
} }
t.Parallel() t.Parallel()
dir1, s1 := testServer(t) _, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1) codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1") waitForLeaderEstablishment(t, s1)
arg := structs.RegisterRequest{ testingPeerNames := []string{"", "my-peer"}
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
ID: "db",
Service: "db",
Tags: []string{"primary"},
},
Check: &structs.HealthCheck{
Name: "db connect",
Status: api.HealthPassing,
ServiceID: "db",
},
}
var out struct{}
if err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out); err != nil {
t.Fatalf("err: %v", err)
}
arg = structs.RegisterRequest{ // TODO(peering): will have to seed this data differently in the future
Datacenter: "dc1", for _, peerName := range testingPeerNames {
Node: "bar", arg := structs.RegisterRequest{
Address: "127.0.0.2", Datacenter: "dc1",
Service: &structs.NodeService{ Node: "foo",
ID: "db", Address: "127.0.0.1",
Service: "db", PeerName: peerName,
Tags: []string{"replica"}, Service: &structs.NodeService{
}, ID: "db",
Check: &structs.HealthCheck{ Service: "db",
Name: "db connect", Tags: []string{"primary"},
Status: api.HealthWarning, PeerName: peerName,
ServiceID: "db", },
}, Check: &structs.HealthCheck{
} Name: "db connect",
if err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out); err != nil { Status: api.HealthPassing,
t.Fatalf("err: %v", err) ServiceID: "db",
} PeerName: peerName,
},
var out2 structs.IndexedCheckServiceNodes
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "db",
ServiceTags: []string{"primary"},
TagFilter: false,
}
if err := msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &out2); err != nil {
t.Fatalf("err: %v", err)
}
nodes := out2.Nodes
if len(nodes) != 2 {
t.Fatalf("Bad: %v", nodes)
}
if nodes[0].Node.Node != "bar" {
t.Fatalf("Bad: %v", nodes[0])
}
if nodes[1].Node.Node != "foo" {
t.Fatalf("Bad: %v", nodes[1])
}
if !stringslice.Contains(nodes[0].Service.Tags, "replica") {
t.Fatalf("Bad: %v", nodes[0])
}
if !stringslice.Contains(nodes[1].Service.Tags, "primary") {
t.Fatalf("Bad: %v", nodes[1])
}
if nodes[0].Checks[0].Status != api.HealthWarning {
t.Fatalf("Bad: %v", nodes[0])
}
if nodes[1].Checks[0].Status != api.HealthPassing {
t.Fatalf("Bad: %v", nodes[1])
}
// Same should still work for <1.3 RPCs with singular tags
// DEPRECATED (singular-service-tag) - remove this when backwards RPC compat
// with 1.2.x is not required.
{
var out2 structs.IndexedCheckServiceNodes
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "db",
ServiceTag: "primary",
TagFilter: false,
}
if err := msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &out2); err != nil {
t.Fatalf("err: %v", err)
} }
var out struct{}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
arg = structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "127.0.0.2",
PeerName: peerName,
Service: &structs.NodeService{
ID: "db",
Service: "db",
Tags: []string{"replica"},
PeerName: peerName,
},
Check: &structs.HealthCheck{
Name: "db connect",
Status: api.HealthWarning,
ServiceID: "db",
PeerName: peerName,
},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
}
verify := func(t *testing.T, out2 structs.IndexedCheckServiceNodes, peerName string) {
nodes := out2.Nodes nodes := out2.Nodes
if len(nodes) != 2 { require.Len(t, nodes, 2)
t.Fatalf("Bad: %v", nodes) require.Equal(t, peerName, nodes[0].Node.PeerName)
} require.Equal(t, peerName, nodes[1].Node.PeerName)
if nodes[0].Node.Node != "bar" { require.Equal(t, "bar", nodes[0].Node.Node)
t.Fatalf("Bad: %v", nodes[0]) require.Equal(t, "foo", nodes[1].Node.Node)
} require.Equal(t, peerName, nodes[0].Service.PeerName)
if nodes[1].Node.Node != "foo" { require.Equal(t, peerName, nodes[1].Service.PeerName)
t.Fatalf("Bad: %v", nodes[1]) require.Contains(t, nodes[0].Service.Tags, "replica")
} require.Contains(t, nodes[1].Service.Tags, "primary")
if !stringslice.Contains(nodes[0].Service.Tags, "replica") { require.Equal(t, peerName, nodes[0].Checks[0].PeerName)
t.Fatalf("Bad: %v", nodes[0]) require.Equal(t, peerName, nodes[1].Checks[0].PeerName)
} require.Equal(t, api.HealthWarning, nodes[0].Checks[0].Status)
if !stringslice.Contains(nodes[1].Service.Tags, "primary") { require.Equal(t, api.HealthPassing, nodes[1].Checks[0].Status)
t.Fatalf("Bad: %v", nodes[1]) }
}
if nodes[0].Checks[0].Status != api.HealthWarning { for _, peerName := range testingPeerNames {
t.Fatalf("Bad: %v", nodes[0]) testName := "peer named " + peerName
} if peerName == "" {
if nodes[1].Checks[0].Status != api.HealthPassing { testName = "local peer"
t.Fatalf("Bad: %v", nodes[1])
} }
t.Run(testName, func(t *testing.T) {
t.Run("with service tags", func(t *testing.T) {
var out2 structs.IndexedCheckServiceNodes
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "db",
ServiceTags: []string{"primary"},
TagFilter: false,
PeerName: peerName,
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &out2))
verify(t, out2, peerName)
})
// Same should still work for <1.3 RPCs with singular tags
// DEPRECATED (singular-service-tag) - remove this when backwards RPC compat
// with 1.2.x is not required.
t.Run("with legacy service tag", func(t *testing.T) {
var out2 structs.IndexedCheckServiceNodes
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "db",
ServiceTag: "primary",
TagFilter: false,
PeerName: peerName,
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &out2))
verify(t, out2, peerName)
})
})
} }
} }

View File

@ -38,7 +38,7 @@ func (m *Internal) NodeInfo(args *structs.NodeSpecificRequest,
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
index, dump, err := state.NodeInfo(ws, args.Node, &args.EnterpriseMeta) index, dump, err := state.NodeInfo(ws, args.Node, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return err return err
} }
@ -69,7 +69,7 @@ func (m *Internal) NodeDump(args *structs.DCSpecificRequest,
&args.QueryOptions, &args.QueryOptions,
&reply.QueryMeta, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
index, dump, err := state.NodeDump(ws, &args.EnterpriseMeta) index, dump, err := state.NodeDump(ws, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return err return err
} }
@ -112,7 +112,7 @@ func (m *Internal) ServiceDump(args *structs.ServiceDumpRequest, reply *structs.
&reply.QueryMeta, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error { func(ws memdb.WatchSet, state *state.Store) error {
// Get, store, and filter nodes // Get, store, and filter nodes
maxIdx, nodes, err := state.ServiceDump(ws, args.ServiceKind, args.UseServiceKind, &args.EnterpriseMeta) maxIdx, nodes, err := state.ServiceDump(ws, args.ServiceKind, args.UseServiceKind, &args.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return err return err
} }
@ -314,7 +314,7 @@ func (m *Internal) GatewayServiceDump(args *structs.ServiceSpecificRequest, repl
// Loop over the gateway <-> serviceName mappings and fetch all service instances for each // Loop over the gateway <-> serviceName mappings and fetch all service instances for each
var result structs.ServiceDump var result structs.ServiceDump
for _, gs := range gatewayServices { for _, gs := range gatewayServices {
idx, instances, err := state.CheckServiceNodes(ws, gs.Service.Name, &gs.Service.EnterpriseMeta) idx, instances, err := state.CheckServiceNodes(ws, gs.Service.Name, &gs.Service.EnterpriseMeta, args.PeerName)
if err != nil { if err != nil {
return err return err
} }

View File

@ -62,7 +62,7 @@ func TestHealthCheckRace(t *testing.T) {
} }
// Verify the index // Verify the index
idx, out1, err := state.CheckServiceNodes(nil, "db", nil) idx, out1, err := state.CheckServiceNodes(nil, "db", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -85,7 +85,7 @@ func TestHealthCheckRace(t *testing.T) {
} }
// Verify the index changed // Verify the index changed
idx, out2, err := state.CheckServiceNodes(nil, "db", nil) idx, out2, err := state.CheckServiceNodes(nil, "db", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View File

@ -305,6 +305,8 @@ func (s *Server) establishLeadership(ctx context.Context) error {
s.startFederationStateAntiEntropy(ctx) s.startFederationStateAntiEntropy(ctx)
s.startPeeringStreamSync(ctx)
if err := s.startConnectLeader(ctx); err != nil { if err := s.startConnectLeader(ctx); err != nil {
return err return err
} }
@ -342,6 +344,8 @@ func (s *Server) revokeLeadership() {
s.stopACLReplication() s.stopACLReplication()
s.stopPeeringStreamSync()
s.stopConnectLeader() s.stopConnectLeader()
s.stopACLTokenReaping() s.stopACLTokenReaping()
@ -887,7 +891,7 @@ func (s *Server) reconcileReaped(known map[string]struct{}, nodeEntMeta *acl.Ent
} }
state := s.fsm.State() state := s.fsm.State()
_, checks, err := state.ChecksInState(nil, api.HealthAny, nodeEntMeta) _, checks, err := state.ChecksInState(nil, api.HealthAny, nodeEntMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }
@ -903,7 +907,7 @@ func (s *Server) reconcileReaped(known map[string]struct{}, nodeEntMeta *acl.Ent
} }
// Get the node services, look for ConsulServiceID // Get the node services, look for ConsulServiceID
_, services, err := state.NodeServices(nil, check.Node, nodeEntMeta) _, services, err := state.NodeServices(nil, check.Node, nodeEntMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }
@ -914,7 +918,7 @@ func (s *Server) reconcileReaped(known map[string]struct{}, nodeEntMeta *acl.Ent
CHECKS: CHECKS:
for _, service := range services.Services { for _, service := range services.Services {
if service.ID == structs.ConsulServiceID { if service.ID == structs.ConsulServiceID {
_, node, err := state.GetNode(check.Node, nodeEntMeta) _, node, err := state.GetNode(check.Node, nodeEntMeta, check.PeerName)
if err != nil { if err != nil {
s.logger.Error("Unable to look up node with name", "name", check.Node, "error", err) s.logger.Error("Unable to look up node with name", "name", check.Node, "error", err)
continue CHECKS continue CHECKS
@ -1051,7 +1055,7 @@ func (s *Server) handleAliveMember(member serf.Member, nodeEntMeta *acl.Enterpri
// Check if the node exists // Check if the node exists
state := s.fsm.State() state := s.fsm.State()
_, node, err := state.GetNode(member.Name, nodeEntMeta) _, node, err := state.GetNode(member.Name, nodeEntMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }
@ -1059,7 +1063,7 @@ func (s *Server) handleAliveMember(member serf.Member, nodeEntMeta *acl.Enterpri
// Check if the associated service is available // Check if the associated service is available
if service != nil { if service != nil {
match := false match := false
_, services, err := state.NodeServices(nil, member.Name, nodeEntMeta) _, services, err := state.NodeServices(nil, member.Name, nodeEntMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }
@ -1077,7 +1081,7 @@ func (s *Server) handleAliveMember(member serf.Member, nodeEntMeta *acl.Enterpri
} }
// Check if the serfCheck is in the passing state // Check if the serfCheck is in the passing state
_, checks, err := state.NodeChecks(nil, member.Name, nodeEntMeta) _, checks, err := state.NodeChecks(nil, member.Name, nodeEntMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }
@ -1127,7 +1131,7 @@ func (s *Server) handleFailedMember(member serf.Member, nodeEntMeta *acl.Enterpr
// Check if the node exists // Check if the node exists
state := s.fsm.State() state := s.fsm.State()
_, node, err := state.GetNode(member.Name, nodeEntMeta) _, node, err := state.GetNode(member.Name, nodeEntMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }
@ -1142,7 +1146,7 @@ func (s *Server) handleFailedMember(member serf.Member, nodeEntMeta *acl.Enterpr
if node.Address == member.Addr.String() { if node.Address == member.Addr.String() {
// Check if the serfCheck is in the critical state // Check if the serfCheck is in the critical state
_, checks, err := state.NodeChecks(nil, member.Name, nodeEntMeta) _, checks, err := state.NodeChecks(nil, member.Name, nodeEntMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }
@ -1220,7 +1224,7 @@ func (s *Server) handleDeregisterMember(reason string, member serf.Member, nodeE
// Check if the node does not exist // Check if the node does not exist
state := s.fsm.State() state := s.fsm.State()
_, node, err := state.GetNode(member.Name, nodeEntMeta) _, node, err := state.GetNode(member.Name, nodeEntMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }

View File

@ -157,7 +157,7 @@ func (s *Server) fetchFederationStateAntiEntropyDetails(
// Fetch our current list of all mesh gateways. // Fetch our current list of all mesh gateways.
entMeta := structs.WildcardEnterpriseMetaInDefaultPartition() entMeta := structs.WildcardEnterpriseMetaInDefaultPartition()
idx2, raw, err := state.ServiceDump(ws, structs.ServiceKindMeshGateway, true, entMeta) idx2, raw, err := state.ServiceDump(ws, structs.ServiceKindMeshGateway, true, entMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
return err return err
} }

View File

@ -0,0 +1,244 @@
package consul
import (
"container/ring"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
)
func (s *Server) startPeeringStreamSync(ctx context.Context) {
s.leaderRoutineManager.Start(ctx, peeringStreamsRoutineName, s.runPeeringSync)
}
func (s *Server) runPeeringSync(ctx context.Context) error {
logger := s.logger.Named("peering-syncer")
cancelFns := make(map[string]context.CancelFunc)
retryLoopBackoff(ctx, func() error {
if err := s.syncPeeringsAndBlock(ctx, logger, cancelFns); err != nil {
return err
}
return nil
}, func(err error) {
s.logger.Error("error syncing peering streams from state store", "error", err)
})
return nil
}
func (s *Server) stopPeeringStreamSync() {
// will be a no-op when not started
s.leaderRoutineManager.Stop(peeringStreamsRoutineName)
}
// syncPeeringsAndBlock is a long-running goroutine that is responsible for watching
// changes to peerings in the state store and managing streams to those peers.
func (s *Server) syncPeeringsAndBlock(ctx context.Context, logger hclog.Logger, cancelFns map[string]context.CancelFunc) error {
state := s.fsm.State()
// Pull the state store contents and set up to block for changes.
ws := memdb.NewWatchSet()
ws.Add(state.AbandonCh())
ws.Add(ctx.Done())
_, peers, err := state.PeeringList(ws, *structs.NodeEnterpriseMetaInPartition(structs.WildcardSpecifier))
if err != nil {
return err
}
// TODO(peering) Adjust this debug info.
// Generate a UUID to trace different passes through this function.
seq, err := uuid.GenerateUUID()
if err != nil {
s.logger.Debug("failed to generate sequence uuid while syncing peerings")
}
logger.Trace("syncing new list of peers", "num_peers", len(peers), "sequence_id", seq)
// Stored tracks the unique set of peers that should be dialed.
// It is used to reconcile the list of active streams.
stored := make(map[string]struct{})
var merr *multierror.Error
// Create connections and streams to peers in the state store that do not have an active stream.
for _, peer := range peers {
logger.Trace("evaluating stored peer", "peer", peer.Name, "should_dial", peer.ShouldDial(), "sequence_id", seq)
if !peer.ShouldDial() {
continue
}
// TODO(peering) Account for deleted peers that are still in the state store
stored[peer.ID] = struct{}{}
status, found := s.peeringService.StreamStatus(peer.ID)
// TODO(peering): If there is new peering data and a connected stream, should we tear down the stream?
// If the data in the updated token is bad, the user wouldn't know until the old servers/certs become invalid.
// Alternatively we could do a basic Ping from the initiate peering endpoint to avoid dealing with that here.
if found && status.Connected {
// Nothing to do when we already have an active stream to the peer.
continue
}
logger.Trace("ensuring stream to peer", "peer_id", peer.ID, "sequence_id", seq)
if cancel, ok := cancelFns[peer.ID]; ok {
// If the peer is known but we're not connected, clean up the retry-er and start over.
// There may be new data in the state store that would enable us to get out of an error state.
logger.Trace("cancelling context to re-establish stream", "peer_id", peer.ID, "sequence_id", seq)
cancel()
}
if err := s.establishStream(ctx, logger, peer, cancelFns); err != nil {
// TODO(peering): These errors should be reported in the peer status, otherwise they're only in the logs.
// Lockable status isn't available here though. Could report it via the peering.Service?
logger.Error("error establishing peering stream", "peer_id", peer.ID, "error", err)
merr = multierror.Append(merr, err)
// Continue on errors to avoid one bad peering from blocking the establishment and cleanup of others.
continue
}
}
logger.Trace("checking connected streams", "streams", s.peeringService.ConnectedStreams(), "sequence_id", seq)
// Clean up active streams of peerings that were deleted from the state store.
// TODO(peering): This is going to trigger shutting down peerings we generated a token for. Is that OK?
for stream, doneCh := range s.peeringService.ConnectedStreams() {
if _, ok := stored[stream]; ok {
// Active stream is in the state store, nothing to do.
continue
}
select {
case <-doneCh:
// channel is closed, do nothing to avoid a panic
default:
logger.Trace("tearing down stream for deleted peer", "peer_id", stream, "sequence_id", seq)
close(doneCh)
}
}
logger.Trace("blocking for changes", "sequence_id", seq)
// Block for any changes to the state store.
ws.WatchCtx(ctx)
logger.Trace("unblocked", "sequence_id", seq)
return merr.ErrorOrNil()
}
func (s *Server) establishStream(ctx context.Context, logger hclog.Logger, peer *pbpeering.Peering, cancelFns map[string]context.CancelFunc) error {
tlsOption := grpc.WithInsecure()
if len(peer.PeerCAPems) > 0 {
var haveCerts bool
pool := x509.NewCertPool()
for _, pem := range peer.PeerCAPems {
if !pool.AppendCertsFromPEM([]byte(pem)) {
return fmt.Errorf("failed to parse PEM %s", pem)
}
if len(pem) > 0 {
haveCerts = true
}
}
if !haveCerts {
return fmt.Errorf("failed to build cert pool from peer CA pems")
}
cfg := tls.Config{
ServerName: peer.PeerServerName,
RootCAs: pool,
}
tlsOption = grpc.WithTransportCredentials(credentials.NewTLS(&cfg))
}
// Create a ring buffer to cycle through peer addresses in the retry loop below.
buffer := ring.New(len(peer.PeerServerAddresses))
for _, addr := range peer.PeerServerAddresses {
buffer.Value = addr
buffer = buffer.Next()
}
logger.Trace("establishing stream to peer", "peer_id", peer.ID)
retryCtx, cancel := context.WithCancel(ctx)
cancelFns[peer.ID] = cancel
// Establish a stream-specific retry so that retrying stream/conn errors isn't dependent on state store changes.
go retryLoopBackoff(retryCtx, func() error {
// Try a new address on each iteration by advancing the ring buffer on errors.
defer func() {
buffer = buffer.Next()
}()
addr, ok := buffer.Value.(string)
if !ok {
return fmt.Errorf("peer server address type %T is not a string", buffer.Value)
}
logger.Trace("dialing peer", "peer_id", peer.ID, "addr", addr)
conn, err := grpc.DialContext(retryCtx, addr,
grpc.WithContextDialer(newPeerDialer(addr)),
grpc.WithBlock(),
tlsOption,
)
if err != nil {
return fmt.Errorf("failed to dial: %w", err)
}
defer conn.Close()
client := pbpeering.NewPeeringServiceClient(conn)
stream, err := client.StreamResources(retryCtx)
if err != nil {
return err
}
err = s.peeringService.HandleStream(peer.ID, peer.PeerID, stream)
if err == nil {
// This will cancel the retry-er context, letting us break out of this loop when we want to shut down the stream.
cancel()
}
return err
}, func(err error) {
// TODO(peering): These errors should be reported in the peer status, otherwise they're only in the logs.
// Lockable status isn't available here though. Could report it via the peering.Service?
logger.Error("error managing peering stream", "peer_id", peer.ID, "error", err)
})
return nil
}
func newPeerDialer(peerAddr string) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
d := net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", peerAddr)
if err != nil {
return nil, err
}
// TODO(peering): This is going to need to be revisited. This type uses the TLS settings configured on the agent, but
// for peering we never want mutual TLS because the client peer doesn't share its CA cert.
_, err = conn.Write([]byte{byte(pool.RPCGRPC)})
if err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
}

View File

@ -0,0 +1,197 @@
package consul
import (
"context"
"encoding/base64"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
)
func TestLeader_PeeringSync_Lifecycle_ClientDeletion(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
// TODO(peering): Configure with TLS
_, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s1.dc1"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
})
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create a peering by generating a token
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
conn, err := grpc.DialContext(ctx, s1.config.RPCAddr.String(),
grpc.WithContextDialer(newServerDialer(s1.config.RPCAddr.String())),
grpc.WithInsecure(),
grpc.WithBlock())
require.NoError(t, err)
defer conn.Close()
peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{
PeerName: "my-peer-s2",
}
resp, err := peeringClient.GenerateToken(ctx, &req)
require.NoError(t, err)
tokenJSON, err := base64.StdEncoding.DecodeString(resp.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
// S1 should not have a stream tracked for dc2 because s1 generated a token for baz, and therefore needs to wait to be dialed.
time.Sleep(1 * time.Second)
_, found := s1.peeringService.StreamStatus(token.PeerID)
require.False(t, found)
// Bring up s2 and store s1's token so that it attempts to dial.
_, s2 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s2.dc2"
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc2"
})
testrpc.WaitForLeader(t, s2.RPC, "dc2")
// Simulate a peering initiation event by writing a peering with data from a peering token.
// Eventually the leader in dc2 should dial and connect to the leader in dc1.
p := &pbpeering.Peering{
Name: "my-peer-s1",
PeerID: token.PeerID,
PeerCAPems: token.CA,
PeerServerName: token.ServerName,
PeerServerAddresses: token.ServerAddresses,
}
require.True(t, p.ShouldDial())
// We maintain a pointer to the peering on the write so that we can get the ID without needing to re-query the state store.
require.NoError(t, s2.fsm.State().PeeringWrite(1000, p))
retry.Run(t, func(r *retry.R) {
status, found := s2.peeringService.StreamStatus(p.ID)
require.True(r, found)
require.True(r, status.Connected)
})
// Delete the peering to trigger the termination sequence
require.NoError(t, s2.fsm.State().PeeringDelete(2000, state.Query{
Value: "my-peer-s1",
}))
s2.logger.Trace("deleted peering for my-peer-s1")
retry.Run(t, func(r *retry.R) {
_, found := s2.peeringService.StreamStatus(p.ID)
require.False(r, found)
})
// s1 should have also marked the peering as terminated.
retry.Run(t, func(r *retry.R) {
_, peering, err := s1.fsm.State().PeeringRead(nil, state.Query{
Value: "my-peer-s2",
})
require.NoError(r, err)
require.Equal(r, pbpeering.PeeringState_TERMINATED, peering.State)
})
}
func TestLeader_PeeringSync_Lifecycle_ServerDeletion(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
// TODO(peering): Configure with TLS
_, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s1.dc1"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
})
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create a peering by generating a token
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
conn, err := grpc.DialContext(ctx, s1.config.RPCAddr.String(),
grpc.WithContextDialer(newServerDialer(s1.config.RPCAddr.String())),
grpc.WithInsecure(),
grpc.WithBlock())
require.NoError(t, err)
defer conn.Close()
peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{
PeerName: "my-peer-s2",
}
resp, err := peeringClient.GenerateToken(ctx, &req)
require.NoError(t, err)
tokenJSON, err := base64.StdEncoding.DecodeString(resp.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
// Bring up s2 and store s1's token so that it attempts to dial.
_, s2 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s2.dc2"
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc2"
})
testrpc.WaitForLeader(t, s2.RPC, "dc2")
// Simulate a peering initiation event by writing a peering with data from a peering token.
// Eventually the leader in dc2 should dial and connect to the leader in dc1.
p := &pbpeering.Peering{
Name: "my-peer-s1",
PeerID: token.PeerID,
PeerCAPems: token.CA,
PeerServerName: token.ServerName,
PeerServerAddresses: token.ServerAddresses,
}
require.True(t, p.ShouldDial())
// We maintain a pointer to the peering on the write so that we can get the ID without needing to re-query the state store.
require.NoError(t, s2.fsm.State().PeeringWrite(1000, p))
retry.Run(t, func(r *retry.R) {
status, found := s2.peeringService.StreamStatus(p.ID)
require.True(r, found)
require.True(r, status.Connected)
})
// Delete the peering from the server peer to trigger the termination sequence
require.NoError(t, s1.fsm.State().PeeringDelete(2000, state.Query{
Value: "my-peer-s2",
}))
s2.logger.Trace("deleted peering for my-peer-s1")
retry.Run(t, func(r *retry.R) {
_, found := s1.peeringService.StreamStatus(p.PeerID)
require.False(r, found)
})
// s2 should have received the termination message and updated the peering state
retry.Run(t, func(r *retry.R) {
_, peering, err := s2.fsm.State().PeeringRead(nil, state.Query{
Value: "my-peer-s1",
})
require.NoError(r, err)
require.Equal(r, pbpeering.PeeringState_TERMINATED, peering.State)
})
}

View File

@ -51,7 +51,7 @@ func TestLeader_RegisterMember(t *testing.T) {
// Client should be registered // Client should be registered
state := s1.fsm.State() state := s1.fsm.State()
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -61,7 +61,7 @@ func TestLeader_RegisterMember(t *testing.T) {
}) })
// Should have a check // Should have a check
_, checks, err := state.NodeChecks(nil, c1.config.NodeName, nil) _, checks, err := state.NodeChecks(nil, c1.config.NodeName, nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -80,7 +80,7 @@ func TestLeader_RegisterMember(t *testing.T) {
// Server should be registered // Server should be registered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(s1.config.NodeName, nil) _, node, err := state.GetNode(s1.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -90,7 +90,7 @@ func TestLeader_RegisterMember(t *testing.T) {
}) })
// Service should be registered // Service should be registered
_, services, err := state.NodeServices(nil, s1.config.NodeName, nil) _, services, err := state.NodeServices(nil, s1.config.NodeName, nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -129,7 +129,7 @@ func TestLeader_FailedMember(t *testing.T) {
// Should be registered // Should be registered
state := s1.fsm.State() state := s1.fsm.State()
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -139,7 +139,7 @@ func TestLeader_FailedMember(t *testing.T) {
}) })
// Should have a check // Should have a check
_, checks, err := state.NodeChecks(nil, c1.config.NodeName, nil) _, checks, err := state.NodeChecks(nil, c1.config.NodeName, nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -154,7 +154,7 @@ func TestLeader_FailedMember(t *testing.T) {
} }
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, checks, err = state.NodeChecks(nil, c1.config.NodeName, nil) _, checks, err = state.NodeChecks(nil, c1.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -193,7 +193,7 @@ func TestLeader_LeftMember(t *testing.T) {
// Should be registered // Should be registered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
require.NoError(r, err) require.NoError(r, err)
require.NotNil(r, node, "client not registered") require.NotNil(r, node, "client not registered")
}) })
@ -204,7 +204,7 @@ func TestLeader_LeftMember(t *testing.T) {
// Should be deregistered // Should be deregistered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
require.NoError(r, err) require.NoError(r, err)
require.Nil(r, node, "client still registered") require.Nil(r, node, "client still registered")
}) })
@ -236,7 +236,7 @@ func TestLeader_ReapMember(t *testing.T) {
// Should be registered // Should be registered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
require.NoError(r, err) require.NoError(r, err)
require.NotNil(r, node, "client not registered") require.NotNil(r, node, "client not registered")
}) })
@ -257,7 +257,7 @@ func TestLeader_ReapMember(t *testing.T) {
// anti-entropy will put it back. // anti-entropy will put it back.
reaped := false reaped := false
for start := time.Now(); time.Since(start) < 5*time.Second; { for start := time.Now(); time.Since(start) < 5*time.Second; {
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
require.NoError(t, err) require.NoError(t, err)
if node == nil { if node == nil {
reaped = true reaped = true
@ -296,7 +296,7 @@ func TestLeader_ReapOrLeftMember_IgnoreSelf(t *testing.T) {
// Should be registered // Should be registered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(nodeName, nil) _, node, err := state.GetNode(nodeName, nil, "")
require.NoError(r, err) require.NoError(r, err)
require.NotNil(r, node, "server not registered") require.NotNil(r, node, "server not registered")
}) })
@ -318,7 +318,7 @@ func TestLeader_ReapOrLeftMember_IgnoreSelf(t *testing.T) {
// anti-entropy will put it back if it did get deleted. // anti-entropy will put it back if it did get deleted.
reaped := false reaped := false
for start := time.Now(); time.Since(start) < 5*time.Second; { for start := time.Now(); time.Since(start) < 5*time.Second; {
_, node, err := state.GetNode(nodeName, nil) _, node, err := state.GetNode(nodeName, nil, "")
require.NoError(t, err) require.NoError(t, err)
if node == nil { if node == nil {
reaped = true reaped = true
@ -402,7 +402,7 @@ func TestLeader_CheckServersMeta(t *testing.T) {
} }
// s3 should be registered // s3 should be registered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, service, err := state.NodeService(s3.config.NodeName, "consul", &consulService.EnterpriseMeta) _, service, err := state.NodeService(s3.config.NodeName, "consul", &consulService.EnterpriseMeta, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -438,7 +438,7 @@ func TestLeader_CheckServersMeta(t *testing.T) {
if err != nil { if err != nil {
r.Fatalf("Unexpected error :%v", err) r.Fatalf("Unexpected error :%v", err)
} }
_, service, err := state.NodeService(s3.config.NodeName, "consul", &consulService.EnterpriseMeta) _, service, err := state.NodeService(s3.config.NodeName, "consul", &consulService.EnterpriseMeta, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -506,7 +506,7 @@ func TestLeader_ReapServer(t *testing.T) {
// s3 should be registered // s3 should be registered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(s3.config.NodeName, nil) _, node, err := state.GetNode(s3.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -527,7 +527,7 @@ func TestLeader_ReapServer(t *testing.T) {
} }
// s3 should be deregistered // s3 should be deregistered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(s3.config.NodeName, nil) _, node, err := state.GetNode(s3.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -582,7 +582,7 @@ func TestLeader_Reconcile_ReapMember(t *testing.T) {
// Node should be gone // Node should be gone
state := s1.fsm.State() state := s1.fsm.State()
_, node, err := state.GetNode("no-longer-around", nil) _, node, err := state.GetNode("no-longer-around", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -615,7 +615,7 @@ func TestLeader_Reconcile(t *testing.T) {
// Should not be registered // Should not be registered
state := s1.fsm.State() state := s1.fsm.State()
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -625,7 +625,7 @@ func TestLeader_Reconcile(t *testing.T) {
// Should be registered // Should be registered
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -657,7 +657,7 @@ func TestLeader_Reconcile_Races(t *testing.T) {
state := s1.fsm.State() state := s1.fsm.State()
var nodeAddr string var nodeAddr string
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -693,7 +693,7 @@ func TestLeader_Reconcile_Races(t *testing.T) {
if err := s1.reconcile(); err != nil { if err := s1.reconcile(); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
_, node, err := state.GetNode(c1.config.NodeName, nil) _, node, err := state.GetNode(c1.config.NodeName, nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -707,7 +707,7 @@ func TestLeader_Reconcile_Races(t *testing.T) {
// Fail the member and wait for the health to go critical. // Fail the member and wait for the health to go critical.
c1.Shutdown() c1.Shutdown()
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, checks, err := state.NodeChecks(nil, c1.config.NodeName, nil) _, checks, err := state.NodeChecks(nil, c1.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -720,7 +720,7 @@ func TestLeader_Reconcile_Races(t *testing.T) {
}) })
// Make sure the metadata didn't get clobbered. // Make sure the metadata didn't get clobbered.
_, node, err = state.GetNode(c1.config.NodeName, nil) _, node, err = state.GetNode(c1.config.NodeName, nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -835,7 +835,7 @@ func TestLeader_LeftLeader(t *testing.T) {
// Verify the old leader is deregistered // Verify the old leader is deregistered
state := remain.fsm.State() state := remain.fsm.State()
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
_, node, err := state.GetNode(leader.config.NodeName, nil) _, node, err := state.GetNode(leader.config.NodeName, nil, "")
if err != nil { if err != nil {
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
@ -2336,7 +2336,7 @@ func TestLeader_EnableVirtualIPs(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
_, node, err := state.NodeService("bar", "tgate1", nil) _, node, err := state.NodeService("bar", "tgate1", nil, "")
require.NoError(t, err) require.NoError(t, err)
sn := structs.ServiceName{Name: "api"} sn := structs.ServiceName{Name: "api"}
key := structs.ServiceGatewayVirtualIPTag(sn) key := structs.ServiceGatewayVirtualIPTag(sn)

View File

@ -0,0 +1,126 @@
package consul
import (
"encoding/base64"
"encoding/json"
"fmt"
"strconv"
"google.golang.org/grpc"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/rpc/peering"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
)
type peeringBackend struct {
srv *Server
connPool GRPCClientConner
apply *peeringApply
}
var _ peering.Backend = (*peeringBackend)(nil)
// NewPeeringBackend returns a peering.Backend implementation that is bound to the given server.
func NewPeeringBackend(srv *Server, connPool GRPCClientConner) peering.Backend {
return &peeringBackend{
srv: srv,
connPool: connPool,
apply: &peeringApply{srv: srv},
}
}
func (b *peeringBackend) Forward(info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error) {
// Only forward the request if the dc in the request matches the server's datacenter.
if info.RequestDatacenter() != "" && info.RequestDatacenter() != b.srv.config.Datacenter {
return false, fmt.Errorf("requests to generate peering tokens cannot be forwarded to remote datacenters")
}
return b.srv.ForwardGRPC(b.connPool, info, f)
}
// GetAgentCACertificates gets the server's raw CA data from its TLS Configurator.
func (b *peeringBackend) GetAgentCACertificates() ([]string, error) {
// TODO(peering): handle empty CA pems
return b.srv.tlsConfigurator.ManualCAPems(), nil
}
// GetServerAddresses looks up server node addresses from the state store.
func (b *peeringBackend) GetServerAddresses() ([]string, error) {
state := b.srv.fsm.State()
_, nodes, err := state.ServiceNodes(nil, "consul", structs.DefaultEnterpriseMetaInDefaultPartition(), structs.DefaultPeerKeyword)
if err != nil {
return nil, err
}
var addrs []string
for _, node := range nodes {
addrs = append(addrs, node.Address+":"+strconv.Itoa(node.ServicePort))
}
return addrs, nil
}
// GetServerName returns the SNI to be returned in the peering token data which
// will be used by peers when establishing peering connections over TLS.
func (b *peeringBackend) GetServerName() string {
return b.srv.tlsConfigurator.ServerSNI(b.srv.config.Datacenter, "")
}
// EncodeToken encodes a peering token as a bas64-encoded representation of JSON (for now).
func (b *peeringBackend) EncodeToken(tok *structs.PeeringToken) ([]byte, error) {
jsonToken, err := json.Marshal(tok)
if err != nil {
return nil, fmt.Errorf("failed to marshal token: %w", err)
}
return []byte(base64.StdEncoding.EncodeToString(jsonToken)), nil
}
// DecodeToken decodes a peering token from a base64-encoded JSON byte array (for now).
func (b *peeringBackend) DecodeToken(tokRaw []byte) (*structs.PeeringToken, error) {
tokJSONRaw, err := base64.StdEncoding.DecodeString(string(tokRaw))
if err != nil {
return nil, fmt.Errorf("failed to decode token: %w", err)
}
var tok structs.PeeringToken
if err := json.Unmarshal(tokJSONRaw, &tok); err != nil {
return nil, err
}
return &tok, nil
}
func (s peeringBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) {
return s.srv.publisher.Subscribe(req)
}
func (b *peeringBackend) Store() peering.Store {
return b.srv.fsm.State()
}
func (b *peeringBackend) Apply() peering.Apply {
return b.apply
}
func (b *peeringBackend) EnterpriseCheckPartitions(partition string) error {
return b.enterpriseCheckPartitions(partition)
}
type peeringApply struct {
srv *Server
}
func (a *peeringApply) PeeringWrite(req *pbpeering.PeeringWriteRequest) error {
_, err := a.srv.raftApplyProtobuf(structs.PeeringWriteType, req)
return err
}
func (a *peeringApply) PeeringDelete(req *pbpeering.PeeringDeleteRequest) error {
_, err := a.srv.raftApplyProtobuf(structs.PeeringDeleteType, req)
return err
}
// TODO(peering): This needs RPC metrics interceptor since it's not triggered by an RPC.
func (a *peeringApply) PeeringTerminateByID(req *pbpeering.PeeringTerminateByIDRequest) error {
_, err := a.srv.raftApplyProtobuf(structs.PeeringTerminateByIDType, req)
return err
}
var _ peering.Apply = (*peeringApply)(nil)

View File

@ -0,0 +1,15 @@
//go:build !consulent
// +build !consulent
package consul
import (
"fmt"
)
func (b *peeringBackend) enterpriseCheckPartitions(partition string) error {
if partition != "" {
return fmt.Errorf("Partitions are a Consul Enterprise feature")
}
return nil
}

View File

@ -0,0 +1,51 @@
//go:build !consulent
// +build !consulent
package consul
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
gogrpc "google.golang.org/grpc"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/testrpc"
)
func TestPeeringBackend_RejectsPartition(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
_, s1 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc1"
c.Bootstrap = true
})
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// make a grpc client to dial s1 directly
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(cancel)
conn, err := gogrpc.DialContext(ctx, s1.config.RPCAddr.String(),
gogrpc.WithContextDialer(newServerDialer(s1.config.RPCAddr.String())),
gogrpc.WithInsecure(),
gogrpc.WithBlock())
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{
Datacenter: "dc1",
Partition: "test",
}
_, err = peeringClient.GenerateToken(ctx, &req)
require.Error(t, err)
require.Contains(t, err.Error(), "Partitions are a Consul Enterprise feature")
}

View File

@ -0,0 +1,115 @@
package consul
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
gogrpc "google.golang.org/grpc"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/testrpc"
)
func TestPeeringBackend_DoesNotForwardToDifferentDC(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
_, s1 := testServerDC(t, "dc1")
_, s2 := testServerDC(t, "dc2")
joinWAN(t, s2, s1)
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForLeader(t, s2.RPC, "dc2")
// make a grpc client to dial s2 directly
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
conn, err := gogrpc.DialContext(ctx, s2.config.RPCAddr.String(),
gogrpc.WithContextDialer(newServerDialer(s2.config.RPCAddr.String())),
gogrpc.WithInsecure(),
gogrpc.WithBlock())
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
peeringClient := pbpeering.NewPeeringServiceClient(conn)
// GenerateToken request should fail against dc1, because we are dialing dc2. The GenerateToken request should never be forwarded across datacenters.
req := pbpeering.GenerateTokenRequest{
PeerName: "peer1-usw1",
Datacenter: "dc1",
}
_, err = peeringClient.GenerateToken(ctx, &req)
require.Error(t, err)
require.Contains(t, err.Error(), "requests to generate peering tokens cannot be forwarded to remote datacenters")
}
func TestPeeringBackend_ForwardToLeader(t *testing.T) {
t.Parallel()
_, conf1 := testServerConfig(t)
server1, err := newServer(t, conf1)
require.NoError(t, err)
_, conf2 := testServerConfig(t)
conf2.Bootstrap = false
server2, err := newServer(t, conf2)
require.NoError(t, err)
// Join a 2nd server (not the leader)
testrpc.WaitForLeader(t, server1.RPC, "dc1")
joinLAN(t, server2, server1)
testrpc.WaitForLeader(t, server2.RPC, "dc1")
// Make a write call to server2 and make sure it gets forwarded to server1
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
// Dial server2 directly
conn, err := gogrpc.DialContext(ctx, server2.config.RPCAddr.String(),
gogrpc.WithContextDialer(newServerDialer(server2.config.RPCAddr.String())),
gogrpc.WithInsecure(),
gogrpc.WithBlock())
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
peeringClient := pbpeering.NewPeeringServiceClient(conn)
runStep(t, "forward a write", func(t *testing.T) {
// Do the grpc Write call to server2
req := pbpeering.GenerateTokenRequest{
Datacenter: "dc1",
PeerName: "foo",
}
_, err := peeringClient.GenerateToken(ctx, &req)
require.NoError(t, err)
// TODO(peering) check that state store is updated on leader, indicating a forwarded request after state store
// is implemented.
})
}
func newServerDialer(serverAddr string) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
d := net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", serverAddr)
if err != nil {
return nil, err
}
_, err = conn.Write([]byte{byte(pool.RPCGRPC)})
if err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
}

View File

@ -3,12 +3,12 @@ package prepared_query
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sort"
"testing" "testing"
"sort" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/require"
) )
func TestWalk_ServiceQuery(t *testing.T) { func TestWalk_ServiceQuery(t *testing.T) {
@ -42,6 +42,7 @@ func TestWalk_ServiceQuery(t *testing.T) {
".Tags[0]:tag1", ".Tags[0]:tag1",
".Tags[1]:tag2", ".Tags[1]:tag2",
".Tags[2]:tag3", ".Tags[2]:tag3",
".PeerName:",
} }
expected = append(expected, entMetaWalkFields...) expected = append(expected, entMetaWalkFields...)
sort.Strings(expected) sort.Strings(expected)

View File

@ -404,7 +404,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
qs.Node = args.Agent.Node qs.Node = args.Agent.Node
} else if qs.Node == "_ip" { } else if qs.Node == "_ip" {
if args.Source.Ip != "" { if args.Source.Ip != "" {
_, nodes, err := state.Nodes(nil, structs.NodeEnterpriseMetaInDefaultPartition()) _, nodes, err := state.Nodes(nil, structs.NodeEnterpriseMetaInDefaultPartition(), structs.TODOPeerKeyword)
if err != nil { if err != nil {
return err return err
} }
@ -534,7 +534,7 @@ func (p *PreparedQuery) execute(query *structs.PreparedQuery,
f = state.CheckConnectServiceNodes f = state.CheckConnectServiceNodes
} }
_, nodes, err := f(nil, query.Service.Service, &query.Service.EnterpriseMeta) _, nodes, err := f(nil, query.Service.Service, &query.Service.EnterpriseMeta, query.Service.PeerName)
if err != nil { if err != nil {
return err return err
} }

View File

@ -16,24 +16,20 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/hashicorp/consul/agent/rpc/middleware"
"github.com/hashicorp/go-version"
"go.etcd.io/bbolt"
"github.com/armon/go-metrics" "github.com/armon/go-metrics"
"github.com/hashicorp/consul-net-rpc/net/rpc"
connlimit "github.com/hashicorp/go-connlimit" connlimit "github.com/hashicorp/go-connlimit"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-version"
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
autopilot "github.com/hashicorp/raft-autopilot" autopilot "github.com/hashicorp/raft-autopilot"
raftboltdb "github.com/hashicorp/raft-boltdb/v2" raftboltdb "github.com/hashicorp/raft-boltdb/v2"
"github.com/hashicorp/serf/serf" "github.com/hashicorp/serf/serf"
"go.etcd.io/bbolt"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/hashicorp/consul-net-rpc/net/rpc"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/authmethod"
"github.com/hashicorp/consul/agent/consul/authmethod/ssoauth" "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
@ -50,11 +46,14 @@ import (
"github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/router"
"github.com/hashicorp/consul/agent/rpc/middleware"
"github.com/hashicorp/consul/agent/rpc/peering"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/lib/routine" "github.com/hashicorp/consul/lib/routine"
"github.com/hashicorp/consul/logging" "github.com/hashicorp/consul/logging"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/proto/pbsubscribe" "github.com/hashicorp/consul/proto/pbsubscribe"
"github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types" "github.com/hashicorp/consul/types"
@ -124,6 +123,7 @@ const (
intermediateCertRenewWatchRoutineName = "intermediate cert renew watch" intermediateCertRenewWatchRoutineName = "intermediate cert renew watch"
backgroundCAInitializationRoutineName = "CA initialization" backgroundCAInitializationRoutineName = "CA initialization"
virtualIPCheckRoutineName = "virtual IP version check" virtualIPCheckRoutineName = "virtual IP version check"
peeringStreamsRoutineName = "streaming peering resources"
) )
var ( var (
@ -356,6 +356,9 @@ type Server struct {
// this into the Deps struct and created it much earlier on. // this into the Deps struct and created it much earlier on.
publisher *stream.EventPublisher publisher *stream.EventPublisher
// peering is a service used to handle peering streams.
peeringService *peering.Service
// embedded struct to hold all the enterprise specific data // embedded struct to hold all the enterprise specific data
EnterpriseServer EnterpriseServer
} }
@ -730,12 +733,19 @@ func NewServer(config *Config, flat Deps, publicGRPCServer *grpc.Server) (*Serve
} }
func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler { func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler {
p := peering.NewService(
deps.Logger.Named("grpc-api.peering"),
NewPeeringBackend(s, deps.GRPCConnPool),
)
s.peeringService = p
register := func(srv *grpc.Server) { register := func(srv *grpc.Server) {
if config.RPCConfig.EnableStreaming { if config.RPCConfig.EnableStreaming {
pbsubscribe.RegisterStateChangeSubscriptionServer(srv, subscribe.NewServer( pbsubscribe.RegisterStateChangeSubscriptionServer(srv, subscribe.NewServer(
&subscribeBackend{srv: s, connPool: deps.GRPCConnPool}, &subscribeBackend{srv: s, connPool: deps.GRPCConnPool},
deps.Logger.Named("grpc-api.subscription"))) deps.Logger.Named("grpc-api.subscription")))
} }
pbpeering.RegisterPeeringServiceServer(srv, s.peeringService)
s.registerEnterpriseGRPCServices(deps, srv) s.registerEnterpriseGRPCServices(deps, srv)
// Note: this public gRPC service is also exposed on the private server to // Note: this public gRPC service is also exposed on the private server to
@ -783,7 +793,7 @@ func (s *Server) setupRaft() error {
}() }()
var serverAddressProvider raft.ServerAddressProvider = nil var serverAddressProvider raft.ServerAddressProvider = nil
if s.config.RaftConfig.ProtocolVersion >= 3 { //ServerAddressProvider needs server ids to work correctly, which is only supported in protocol version 3 or higher if s.config.RaftConfig.ProtocolVersion >= 3 { // ServerAddressProvider needs server ids to work correctly, which is only supported in protocol version 3 or higher
serverAddressProvider = s.serverLookup serverAddressProvider = s.serverLookup
} }

View File

@ -237,6 +237,8 @@ func testServerWithConfig(t *testing.T, configOpts ...func(*Config)) (string, *S
r.Fatalf("err: %v", err) r.Fatalf("err: %v", err)
} }
}) })
t.Cleanup(func() { srv.Shutdown() })
return dir, srv return dir, srv
} }

View File

@ -239,6 +239,26 @@ func prefixIndexFromUUIDQuery(arg interface{}) ([]byte, error) {
return nil, fmt.Errorf("unexpected type %T for Query prefix index", arg) return nil, fmt.Errorf("unexpected type %T for Query prefix index", arg)
} }
func prefixIndexFromUUIDWithPeerQuery(arg interface{}) ([]byte, error) {
switch v := arg.(type) {
case Query:
var b indexBuilder
peername := v.PeerOrEmpty()
if peername == "" {
b.String(structs.LocalPeerKeyword)
} else {
b.String(strings.ToLower(peername))
}
uuidBytes, err := variableLengthUUIDStringToBytes(v.Value)
if err != nil {
return nil, err
}
return append(b.Bytes(), uuidBytes...), nil
}
return nil, fmt.Errorf("unexpected type %T for Query prefix index", arg)
}
func multiIndexPolicyFromACLRole(raw interface{}) ([][]byte, error) { func multiIndexPolicyFromACLRole(raw interface{}) ([][]byte, error) {
role, ok := raw.(*structs.ACLRole) role, ok := raw.(*structs.ACLRole)
if !ok { if !ok {

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/pbsubscribe" "github.com/hashicorp/consul/proto/pbsubscribe"
) )
@ -17,33 +18,13 @@ import (
type EventSubjectService struct { type EventSubjectService struct {
Key string Key string
EnterpriseMeta acl.EnterpriseMeta EnterpriseMeta acl.EnterpriseMeta
PeerName string
overrideKey string overrideKey string
overrideNamespace string overrideNamespace string
overridePartition string overridePartition string
} }
// String satisfies the stream.Subject interface.
func (s EventSubjectService) String() string {
partition := s.EnterpriseMeta.PartitionOrDefault()
if v := s.overridePartition; v != "" {
partition = strings.ToLower(v)
}
namespace := s.EnterpriseMeta.NamespaceOrDefault()
if v := s.overrideNamespace; v != "" {
namespace = strings.ToLower(v)
}
key := s.Key
if v := s.overrideKey; v != "" {
key = v
}
key = strings.ToLower(key)
return partition + "/" + namespace + "/" + key
}
// EventPayloadCheckServiceNode is used as the Payload for a stream.Event to // EventPayloadCheckServiceNode is used as the Payload for a stream.Event to
// indicates changes to a CheckServiceNode for service health. // indicates changes to a CheckServiceNode for service health.
// //
@ -62,6 +43,7 @@ type EventPayloadCheckServiceNode struct {
} }
func (e EventPayloadCheckServiceNode) HasReadPermission(authz acl.Authorizer) bool { func (e EventPayloadCheckServiceNode) HasReadPermission(authz acl.Authorizer) bool {
// TODO(peering): figure out how authz works for peered data
return e.Value.CanRead(authz) == acl.Allow return e.Value.CanRead(authz) == acl.Allow
} }
@ -76,6 +58,31 @@ func (e EventPayloadCheckServiceNode) Subject() stream.Subject {
} }
} }
func (e EventPayloadCheckServiceNode) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
return &pbsubscribe.Event{
Index: idx,
Payload: &pbsubscribe.Event_ServiceHealth{
ServiceHealth: &pbsubscribe.ServiceHealthUpdate{
Op: e.Op,
CheckServiceNode: pbservice.NewCheckServiceNodeFromStructs(e.Value),
},
},
}
}
func PBToStreamSubscribeRequest(req *pbsubscribe.SubscribeRequest, entMeta acl.EnterpriseMeta) *stream.SubscribeRequest {
return &stream.SubscribeRequest{
Topic: req.Topic,
Subject: EventSubjectService{
Key: req.Key,
EnterpriseMeta: entMeta,
PeerName: req.PeerName,
},
Token: req.Token,
Index: req.Index,
}
}
// serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot // serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot
// of stream.Events that describe the current state of a service health query. // of stream.Events that describe the current state of a service health query.
func (s *Store) ServiceHealthSnapshot(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) { func (s *Store) ServiceHealthSnapshot(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) {
@ -89,7 +96,7 @@ func (s *Store) ServiceHealthSnapshot(req stream.SubscribeRequest, buf stream.Sn
return 0, fmt.Errorf("expected SubscribeRequest.Subject to be a: state.EventSubjectService, was a: %T", req.Subject) return 0, fmt.Errorf("expected SubscribeRequest.Subject to be a: state.EventSubjectService, was a: %T", req.Subject)
} }
idx, nodes, err := checkServiceNodesTxn(tx, nil, subject.Key, connect, &subject.EnterpriseMeta) idx, nodes, err := checkServiceNodesTxn(tx, nil, subject.Key, connect, &subject.EnterpriseMeta, subject.PeerName)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -127,6 +134,7 @@ type nodeServiceTuple struct {
Node string Node string
ServiceID string ServiceID string
EntMeta acl.EnterpriseMeta EntMeta acl.EnterpriseMeta
PeerName string
} }
func newNodeServiceTupleFromServiceNode(sn *structs.ServiceNode) nodeServiceTuple { func newNodeServiceTupleFromServiceNode(sn *structs.ServiceNode) nodeServiceTuple {
@ -134,6 +142,7 @@ func newNodeServiceTupleFromServiceNode(sn *structs.ServiceNode) nodeServiceTupl
Node: strings.ToLower(sn.Node), Node: strings.ToLower(sn.Node),
ServiceID: sn.ServiceID, ServiceID: sn.ServiceID,
EntMeta: sn.EnterpriseMeta, EntMeta: sn.EnterpriseMeta,
PeerName: sn.PeerName,
} }
} }
@ -142,6 +151,7 @@ func newNodeServiceTupleFromServiceHealthCheck(hc *structs.HealthCheck) nodeServ
Node: strings.ToLower(hc.Node), Node: strings.ToLower(hc.Node),
ServiceID: hc.ServiceID, ServiceID: hc.ServiceID,
EntMeta: hc.EnterpriseMeta, EntMeta: hc.EnterpriseMeta,
PeerName: hc.PeerName,
} }
} }
@ -153,6 +163,7 @@ type serviceChange struct {
type nodeTuple struct { type nodeTuple struct {
Node string Node string
Partition string Partition string
PeerName string
} }
var serviceChangeIndirect = serviceChange{changeType: changeIndirect} var serviceChangeIndirect = serviceChange{changeType: changeIndirect}
@ -286,7 +297,7 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event
} }
// Rebuild events for all services on this node // Rebuild events for all services on this node
es, err := newServiceHealthEventsForNode(tx, changes.Index, node.Node, es, err := newServiceHealthEventsForNode(tx, changes.Index, node.Node,
structs.WildcardEnterpriseMetaInPartition(node.Partition)) structs.WildcardEnterpriseMetaInPartition(node.Partition), node.PeerName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -342,6 +353,7 @@ func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event
q := Query{ q := Query{
Value: gs.Gateway.Name, Value: gs.Gateway.Name,
EnterpriseMeta: gatewayName.EnterpriseMeta, EnterpriseMeta: gatewayName.EnterpriseMeta,
PeerName: structs.TODOPeerKeyword,
} }
_, nodes, err := serviceNodesTxn(tx, nil, indexService, q) _, nodes, err := serviceNodesTxn(tx, nil, indexService, q)
if err != nil { if err != nil {
@ -504,6 +516,8 @@ func connectEventsByServiceKind(tx ReadTxn, origEvent stream.Event) ([]stream.Ev
case structs.ServiceKindTerminatingGateway: case structs.ServiceKindTerminatingGateway:
var result []stream.Event var result []stream.Event
// TODO(peering): handle terminating gateways somehow
sn := structs.ServiceName{ sn := structs.ServiceName{
Name: node.Service.Service, Name: node.Service.Service,
EnterpriseMeta: node.Service.EnterpriseMeta, EnterpriseMeta: node.Service.EnterpriseMeta,
@ -551,16 +565,17 @@ func getPayloadCheckServiceNode(payload stream.Payload) *structs.CheckServiceNod
// given node. This mirrors some of the the logic in the oddly-named // given node. This mirrors some of the the logic in the oddly-named
// parseCheckServiceNodes but is more efficient since we know they are all on // parseCheckServiceNodes but is more efficient since we know they are all on
// the same node. // the same node.
func newServiceHealthEventsForNode(tx ReadTxn, idx uint64, node string, entMeta *acl.EnterpriseMeta) ([]stream.Event, error) { func newServiceHealthEventsForNode(tx ReadTxn, idx uint64, node string, entMeta *acl.EnterpriseMeta, peerName string) ([]stream.Event, error) {
services, err := tx.Get(tableServices, indexNode, Query{ services, err := tx.Get(tableServices, indexNode, Query{
Value: node, Value: node,
EnterpriseMeta: *entMeta, EnterpriseMeta: *entMeta,
PeerName: peerName,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
n, checksFunc, err := getNodeAndChecks(tx, node, entMeta) n, checksFunc, err := getNodeAndChecks(tx, node, entMeta, peerName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -578,11 +593,12 @@ func newServiceHealthEventsForNode(tx ReadTxn, idx uint64, node string, entMeta
// getNodeAndNodeChecks returns a the node structure and a function that returns // getNodeAndNodeChecks returns a the node structure and a function that returns
// the full list of checks for a specific service on that node. // the full list of checks for a specific service on that node.
func getNodeAndChecks(tx ReadTxn, node string, entMeta *acl.EnterpriseMeta) (*structs.Node, serviceChecksFunc, error) { func getNodeAndChecks(tx ReadTxn, node string, entMeta *acl.EnterpriseMeta, peerName string) (*structs.Node, serviceChecksFunc, error) {
// Fetch the node // Fetch the node
nodeRaw, err := tx.First(tableNodes, indexID, Query{ nodeRaw, err := tx.First(tableNodes, indexID, Query{
Value: node, Value: node,
EnterpriseMeta: *entMeta, EnterpriseMeta: *entMeta,
PeerName: peerName,
}) })
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -595,6 +611,7 @@ func getNodeAndChecks(tx ReadTxn, node string, entMeta *acl.EnterpriseMeta) (*st
iter, err := tx.Get(tableChecks, indexNode, Query{ iter, err := tx.Get(tableChecks, indexNode, Query{
Value: node, Value: node,
EnterpriseMeta: *entMeta, EnterpriseMeta: *entMeta,
PeerName: peerName,
}) })
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -629,7 +646,7 @@ func getNodeAndChecks(tx ReadTxn, node string, entMeta *acl.EnterpriseMeta) (*st
type serviceChecksFunc func(serviceID string) structs.HealthChecks type serviceChecksFunc func(serviceID string) structs.HealthChecks
func newServiceHealthEventForService(tx ReadTxn, idx uint64, tuple nodeServiceTuple) (stream.Event, error) { func newServiceHealthEventForService(tx ReadTxn, idx uint64, tuple nodeServiceTuple) (stream.Event, error) {
n, checksFunc, err := getNodeAndChecks(tx, tuple.Node, &tuple.EntMeta) n, checksFunc, err := getNodeAndChecks(tx, tuple.Node, &tuple.EntMeta, tuple.PeerName)
if err != nil { if err != nil {
return stream.Event{}, err return stream.Event{}, err
} }
@ -638,6 +655,7 @@ func newServiceHealthEventForService(tx ReadTxn, idx uint64, tuple nodeServiceTu
EnterpriseMeta: tuple.EntMeta, EnterpriseMeta: tuple.EntMeta,
Node: tuple.Node, Node: tuple.Node,
Service: tuple.ServiceID, Service: tuple.ServiceID,
PeerName: tuple.PeerName,
}) })
if err != nil { if err != nil {
return stream.Event{}, err return stream.Event{}, err
@ -690,6 +708,7 @@ func newServiceHealthEventDeregister(idx uint64, sn *structs.ServiceNode) stream
Node: &structs.Node{ Node: &structs.Node{
Node: sn.Node, Node: sn.Node,
Partition: entMeta.PartitionOrEmpty(), Partition: entMeta.PartitionOrEmpty(),
PeerName: sn.PeerName,
}, },
Service: sn.ToNodeService(), Service: sn.ToNodeService(),
} }

View File

@ -13,6 +13,7 @@ func (nst nodeServiceTuple) nodeTuple() nodeTuple {
return nodeTuple{ return nodeTuple{
Node: strings.ToLower(nst.Node), Node: strings.ToLower(nst.Node),
Partition: "", Partition: "",
PeerName: nst.PeerName,
} }
} }
@ -20,6 +21,7 @@ func newNodeTupleFromNode(node *structs.Node) nodeTuple {
return nodeTuple{ return nodeTuple{
Node: strings.ToLower(node.Node), Node: strings.ToLower(node.Node),
Partition: "", Partition: "",
PeerName: node.PeerName,
} }
} }
@ -27,5 +29,20 @@ func newNodeTupleFromHealthCheck(hc *structs.HealthCheck) nodeTuple {
return nodeTuple{ return nodeTuple{
Node: strings.ToLower(hc.Node), Node: strings.ToLower(hc.Node),
Partition: "", Partition: "",
PeerName: hc.PeerName,
} }
} }
// String satisfies the stream.Subject interface.
func (s EventSubjectService) String() string {
key := s.Key
if v := s.overrideKey; v != "" {
key = v
}
key = strings.ToLower(key)
if s.PeerName == "" {
return key
}
return s.PeerName + "/" + key
}

View File

@ -0,0 +1,45 @@
//go:build !consulent
// +build !consulent
package state
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/structs"
)
func TestEventPayloadCheckServiceNode_Subject_OSS(t *testing.T) {
for desc, tc := range map[string]struct {
evt EventPayloadCheckServiceNode
sub string
}{
"mixed casing": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "FoO",
},
},
},
"foo",
},
"override key": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "foo",
},
},
overrideKey: "bar",
},
"bar",
},
} {
t.Run(desc, func(t *testing.T) {
require.Equal(t, tc.sub, tc.evt.Subject().String())
})
}
}

View File

@ -16,49 +16,6 @@ import (
"github.com/hashicorp/consul/types" "github.com/hashicorp/consul/types"
) )
func TestEventPayloadCheckServiceNode_Subject(t *testing.T) {
for desc, tc := range map[string]struct {
evt EventPayloadCheckServiceNode
sub string
}{
"default partition and namespace": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "foo",
},
},
},
"default/default/foo",
},
"mixed casing": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "FoO",
},
},
},
"default/default/foo",
},
"override key": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "foo",
},
},
overrideKey: "bar",
},
"default/default/bar",
},
} {
t.Run(desc, func(t *testing.T) {
require.Equal(t, tc.sub, tc.evt.Subject().String())
})
}
}
func TestServiceHealthSnapshot(t *testing.T) { func TestServiceHealthSnapshot(t *testing.T) {
store := NewStateStore(nil) store := NewStateStore(nil)
@ -307,7 +264,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
return nil return nil
}, },
Mutate: func(s *Store, tx *txn) error { Mutate: func(s *Store, tx *txn) error {
return s.deleteServiceTxn(tx, tx.Index, "node1", "web", nil) return s.deleteServiceTxn(tx, tx.Index, "node1", "web", nil, "")
}, },
WantEvents: []stream.Event{ WantEvents: []stream.Event{
// Should only publish deregistration for that service // Should only publish deregistration for that service
@ -327,7 +284,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
return nil return nil
}, },
Mutate: func(s *Store, tx *txn) error { Mutate: func(s *Store, tx *txn) error {
return s.deleteNodeTxn(tx, tx.Index, "node1", nil) return s.deleteNodeTxn(tx, tx.Index, "node1", nil, "")
}, },
WantEvents: []stream.Event{ WantEvents: []stream.Event{
// Should publish deregistration events for all services // Should publish deregistration events for all services
@ -380,7 +337,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
return s.ensureRegistrationTxn(tx, tx.Index, false, testServiceRegistration(t, "web", regConnectNative), false) return s.ensureRegistrationTxn(tx, tx.Index, false, testServiceRegistration(t, "web", regConnectNative), false)
}, },
Mutate: func(s *Store, tx *txn) error { Mutate: func(s *Store, tx *txn) error {
return s.deleteServiceTxn(tx, tx.Index, "node1", "web", nil) return s.deleteServiceTxn(tx, tx.Index, "node1", "web", nil, "")
}, },
WantEvents: []stream.Event{ WantEvents: []stream.Event{
// We should see both a regular service dereg event and a connect one // We should see both a regular service dereg event and a connect one
@ -444,7 +401,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
}, },
Mutate: func(s *Store, tx *txn) error { Mutate: func(s *Store, tx *txn) error {
// Delete only the sidecar // Delete only the sidecar
return s.deleteServiceTxn(tx, tx.Index, "node1", "web_sidecar_proxy", nil) return s.deleteServiceTxn(tx, tx.Index, "node1", "web_sidecar_proxy", nil, "")
}, },
WantEvents: []stream.Event{ WantEvents: []stream.Event{
// We should see both a regular service dereg event and a connect one // We should see both a regular service dereg event and a connect one
@ -910,7 +867,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
}, },
Mutate: func(s *Store, tx *txn) error { Mutate: func(s *Store, tx *txn) error {
// Delete only the node-level check // Delete only the node-level check
if err := s.deleteCheckTxn(tx, tx.Index, "node1", "serf-health", nil); err != nil { if err := s.deleteCheckTxn(tx, tx.Index, "node1", "serf-health", nil, ""); err != nil {
return err return err
} }
return nil return nil
@ -964,11 +921,11 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
}, },
Mutate: func(s *Store, tx *txn) error { Mutate: func(s *Store, tx *txn) error {
// Delete the service-level check for the main service // Delete the service-level check for the main service
if err := s.deleteCheckTxn(tx, tx.Index, "node1", "service:web", nil); err != nil { if err := s.deleteCheckTxn(tx, tx.Index, "node1", "service:web", nil, ""); err != nil {
return err return err
} }
// Also delete for a proxy // Also delete for a proxy
if err := s.deleteCheckTxn(tx, tx.Index, "node1", "service:web_sidecar_proxy", nil); err != nil { if err := s.deleteCheckTxn(tx, tx.Index, "node1", "service:web_sidecar_proxy", nil, ""); err != nil {
return err return err
} }
return nil return nil
@ -1029,10 +986,10 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
// In one transaction the operator moves the web service and it's // In one transaction the operator moves the web service and it's
// sidecar from node2 back to node1 and deletes them from node2 // sidecar from node2 back to node1 and deletes them from node2
if err := s.deleteServiceTxn(tx, tx.Index, "node2", "web", nil); err != nil { if err := s.deleteServiceTxn(tx, tx.Index, "node2", "web", nil, ""); err != nil {
return err return err
} }
if err := s.deleteServiceTxn(tx, tx.Index, "node2", "web_sidecar_proxy", nil); err != nil { if err := s.deleteServiceTxn(tx, tx.Index, "node2", "web_sidecar_proxy", nil, ""); err != nil {
return err return err
} }
@ -1544,7 +1501,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
testServiceRegistration(t, "tgate1", regTerminatingGateway), false) testServiceRegistration(t, "tgate1", regTerminatingGateway), false)
}, },
Mutate: func(s *Store, tx *txn) error { Mutate: func(s *Store, tx *txn) error {
return s.deleteServiceTxn(tx, tx.Index, "node1", "srv1", nil) return s.deleteServiceTxn(tx, tx.Index, "node1", "srv1", nil, "")
}, },
WantEvents: []stream.Event{ WantEvents: []stream.Event{
testServiceHealthDeregistrationEvent(t, "srv1"), testServiceHealthDeregistrationEvent(t, "srv1"),
@ -1649,7 +1606,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
testServiceRegistration(t, "tgate1", regTerminatingGateway), false) testServiceRegistration(t, "tgate1", regTerminatingGateway), false)
}, },
Mutate: func(s *Store, tx *txn) error { Mutate: func(s *Store, tx *txn) error {
return s.deleteServiceTxn(tx, tx.Index, "node1", "tgate1", structs.DefaultEnterpriseMetaInDefaultPartition()) return s.deleteServiceTxn(tx, tx.Index, "node1", "tgate1", structs.DefaultEnterpriseMetaInDefaultPartition(), "")
}, },
WantEvents: []stream.Event{ WantEvents: []stream.Event{
testServiceHealthDeregistrationEvent(t, testServiceHealthDeregistrationEvent(t,

View File

@ -15,54 +15,83 @@ import (
func withEnterpriseSchema(_ *memdb.DBSchema) {} func withEnterpriseSchema(_ *memdb.DBSchema) {}
func serviceIndexName(name string, _ *acl.EnterpriseMeta) string { func serviceIndexName(name string, _ *acl.EnterpriseMeta, peerName string) string {
return fmt.Sprintf("service.%s", name) return peeredIndexEntryName(fmt.Sprintf("service.%s", name), peerName)
} }
func serviceKindIndexName(kind structs.ServiceKind, _ *acl.EnterpriseMeta) string { func serviceKindIndexName(kind structs.ServiceKind, _ *acl.EnterpriseMeta, peerName string) string {
return "service_kind." + kind.Normalized() base := "service_kind." + kind.Normalized()
return peeredIndexEntryName(base, peerName)
} }
func catalogUpdateNodesIndexes(tx WriteTxn, idx uint64, entMeta *acl.EnterpriseMeta) error { func catalogUpdateNodesIndexes(tx WriteTxn, idx uint64, _ *acl.EnterpriseMeta, peerName string) error {
// overall nodes index // overall nodes index
if err := indexUpdateMaxTxn(tx, idx, tableNodes); err != nil { if err := indexUpdateMaxTxn(tx, idx, tableNodes); err != nil {
return fmt.Errorf("failed updating index: %s", err) return fmt.Errorf("failed updating index: %s", err)
} }
// peered index
if err := indexUpdateMaxTxn(tx, idx, peeredIndexEntryName(tableNodes, peerName)); err != nil {
return fmt.Errorf("failed updating partitioned+peered index for nodes table: %w", err)
}
return nil return nil
} }
func catalogUpdateServicesIndexes(tx WriteTxn, idx uint64, _ *acl.EnterpriseMeta) error { // catalogUpdateServicesIndexes upserts the max index for the entire services table with varying levels
// of granularity (no-op if `idx` is lower than what exists for that index key):
// - all services
// - all services in a specified peer (including internal)
func catalogUpdateServicesIndexes(tx WriteTxn, idx uint64, _ *acl.EnterpriseMeta, peerName string) error {
// overall services index // overall services index
if err := indexUpdateMaxTxn(tx, idx, tableServices); err != nil { if err := indexUpdateMaxTxn(tx, idx, tableServices); err != nil {
return fmt.Errorf("failed updating index: %s", err) return fmt.Errorf("failed updating index for services table: %w", err)
}
// peered services index
if err := indexUpdateMaxTxn(tx, idx, peeredIndexEntryName(tableServices, peerName)); err != nil {
return fmt.Errorf("failed updating peered index for services table: %w", err)
} }
return nil return nil
} }
func catalogUpdateServiceKindIndexes(tx WriteTxn, kind structs.ServiceKind, idx uint64, _ *acl.EnterpriseMeta) error { // catalogUpdateServiceKindIndexes upserts the max index for the ServiceKind with varying levels
// of granularity (no-op if `idx` is lower than what exists for that index key):
// - all services of ServiceKind
// - all services of ServiceKind in a specified peer (including internal)
func catalogUpdateServiceKindIndexes(tx WriteTxn, idx uint64, kind structs.ServiceKind, _ *acl.EnterpriseMeta, peerName string) error {
base := "service_kind." + kind.Normalized()
// service-kind index // service-kind index
if err := indexUpdateMaxTxn(tx, idx, serviceKindIndexName(kind, nil)); err != nil { if err := indexUpdateMaxTxn(tx, idx, base); err != nil {
return fmt.Errorf("failed updating index: %s", err) return fmt.Errorf("failed updating index for service kind: %w", err)
} }
// peered index
if err := indexUpdateMaxTxn(tx, idx, peeredIndexEntryName(base, peerName)); err != nil {
return fmt.Errorf("failed updating peered index for service kind: %w", err)
}
return nil return nil
} }
func catalogUpdateServiceIndexes(tx WriteTxn, serviceName string, idx uint64, _ *acl.EnterpriseMeta) error { func catalogUpdateServiceIndexes(tx WriteTxn, idx uint64, serviceName string, _ *acl.EnterpriseMeta, peerName string) error {
// per-service index // per-service index
if err := indexUpdateMaxTxn(tx, idx, serviceIndexName(serviceName, nil)); err != nil { if err := indexUpdateMaxTxn(tx, idx, serviceIndexName(serviceName, nil, peerName)); err != nil {
return fmt.Errorf("failed updating index: %s", err) return fmt.Errorf("failed updating index: %s", err)
} }
return nil return nil
} }
func catalogUpdateServiceExtinctionIndex(tx WriteTxn, idx uint64, _ *acl.EnterpriseMeta) error { func catalogUpdateServiceExtinctionIndex(tx WriteTxn, idx uint64, _ *acl.EnterpriseMeta, peerName string) error {
if err := tx.Insert(tableIndex, &IndexEntry{indexServiceExtinction, idx}); err != nil { if err := indexUpdateMaxTxn(tx, idx, indexServiceExtinction); err != nil {
return fmt.Errorf("failed updating missing service extinction index: %s", err) return fmt.Errorf("failed updating missing service extinction index: %w", err)
} }
// update the peer index
if err := indexUpdateMaxTxn(tx, idx, peeredIndexEntryName(indexServiceExtinction, peerName)); err != nil {
return fmt.Errorf("failed updating missing service extinction peered index: %w", err)
}
return nil return nil
} }
@ -75,14 +104,14 @@ func catalogInsertNode(tx WriteTxn, node *structs.Node) error {
return fmt.Errorf("failed inserting node: %s", err) return fmt.Errorf("failed inserting node: %s", err)
} }
if err := catalogUpdateNodesIndexes(tx, node.ModifyIndex, node.GetEnterpriseMeta()); err != nil { if err := catalogUpdateNodesIndexes(tx, node.ModifyIndex, node.GetEnterpriseMeta(), node.PeerName); err != nil {
return err return err
} }
// Update the node's service indexes as the node information is included // Update the node's service indexes as the node information is included
// in health queries and we would otherwise miss node updates in some cases // in health queries and we would otherwise miss node updates in some cases
// for those queries. // for those queries.
if err := updateAllServiceIndexesOfNode(tx, node.ModifyIndex, node.Node, node.GetEnterpriseMeta()); err != nil { if err := updateAllServiceIndexesOfNode(tx, node.ModifyIndex, node.Node, node.GetEnterpriseMeta(), node.PeerName); err != nil {
return fmt.Errorf("failed updating index: %s", err) return fmt.Errorf("failed updating index: %s", err)
} }
@ -95,73 +124,95 @@ func catalogInsertService(tx WriteTxn, svc *structs.ServiceNode) error {
return fmt.Errorf("failed inserting service: %s", err) return fmt.Errorf("failed inserting service: %s", err)
} }
if err := catalogUpdateServicesIndexes(tx, svc.ModifyIndex, &svc.EnterpriseMeta); err != nil { if err := catalogUpdateServicesIndexes(tx, svc.ModifyIndex, &svc.EnterpriseMeta, svc.PeerName); err != nil {
return err return err
} }
if err := catalogUpdateServiceIndexes(tx, svc.ServiceName, svc.ModifyIndex, &svc.EnterpriseMeta); err != nil { if err := catalogUpdateServiceIndexes(tx, svc.ModifyIndex, svc.ServiceName, &svc.EnterpriseMeta, svc.PeerName); err != nil {
return err return err
} }
if err := catalogUpdateServiceKindIndexes(tx, svc.ServiceKind, svc.ModifyIndex, &svc.EnterpriseMeta); err != nil { if err := catalogUpdateServiceKindIndexes(tx, svc.ModifyIndex, svc.ServiceKind, &svc.EnterpriseMeta, svc.PeerName); err != nil {
return err return err
} }
return nil return nil
} }
func catalogNodesMaxIndex(tx ReadTxn, entMeta *acl.EnterpriseMeta) uint64 { func catalogNodesMaxIndex(tx ReadTxn, _ *acl.EnterpriseMeta, peerName string) uint64 {
return maxIndexTxn(tx, tableNodes) return maxIndexTxn(tx, peeredIndexEntryName(tableNodes, peerName))
} }
func catalogServicesMaxIndex(tx ReadTxn, _ *acl.EnterpriseMeta) uint64 { func catalogServicesMaxIndex(tx ReadTxn, _ *acl.EnterpriseMeta, peerName string) uint64 {
return maxIndexTxn(tx, tableServices) return maxIndexTxn(tx, peeredIndexEntryName(tableServices, peerName))
} }
func catalogServiceMaxIndex(tx ReadTxn, serviceName string, _ *acl.EnterpriseMeta) (<-chan struct{}, interface{}, error) { func catalogServiceMaxIndex(tx ReadTxn, serviceName string, _ *acl.EnterpriseMeta, peerName string) (<-chan struct{}, interface{}, error) {
return tx.FirstWatch(tableIndex, "id", serviceIndexName(serviceName, nil)) return tx.FirstWatch(tableIndex, indexID, serviceIndexName(serviceName, nil, peerName))
} }
func catalogServiceKindMaxIndex(tx ReadTxn, ws memdb.WatchSet, kind structs.ServiceKind, entMeta *acl.EnterpriseMeta) uint64 { func catalogServiceKindMaxIndex(tx ReadTxn, ws memdb.WatchSet, kind structs.ServiceKind, _ *acl.EnterpriseMeta, peerName string) uint64 {
return maxIndexWatchTxn(tx, ws, serviceKindIndexName(kind, nil)) return maxIndexWatchTxn(tx, ws, serviceKindIndexName(kind, nil, peerName))
} }
func catalogServiceListNoWildcard(tx ReadTxn, _ *acl.EnterpriseMeta) (memdb.ResultIterator, error) { func catalogServiceListNoWildcard(tx ReadTxn, _ *acl.EnterpriseMeta, peerName string) (memdb.ResultIterator, error) {
return tx.Get(tableServices, indexID) q := Query{
} PeerName: peerName,
func catalogServiceListByNode(tx ReadTxn, node string, _ *acl.EnterpriseMeta, _ bool) (memdb.ResultIterator, error) {
return tx.Get(tableServices, indexNode, Query{Value: node})
}
func catalogServiceLastExtinctionIndex(tx ReadTxn, _ *acl.EnterpriseMeta) (interface{}, error) {
return tx.First(tableIndex, "id", indexServiceExtinction)
}
func catalogMaxIndex(tx ReadTxn, _ *acl.EnterpriseMeta, checks bool) uint64 {
if checks {
return maxIndexTxn(tx, tableNodes, tableServices, tableChecks)
} }
return maxIndexTxn(tx, tableNodes, tableServices) return tx.Get(tableServices, indexID+"_prefix", q)
} }
func catalogMaxIndexWatch(tx ReadTxn, ws memdb.WatchSet, _ *acl.EnterpriseMeta, checks bool) uint64 { func catalogServiceListByNode(tx ReadTxn, node string, _ *acl.EnterpriseMeta, peerName string, _ bool) (memdb.ResultIterator, error) {
return tx.Get(tableServices, indexNode, Query{Value: node, PeerName: peerName})
}
func catalogServiceLastExtinctionIndex(tx ReadTxn, _ *acl.EnterpriseMeta, peerName string) (interface{}, error) {
return tx.First(tableIndex, indexID, peeredIndexEntryName(indexServiceExtinction, peerName))
}
func catalogMaxIndex(tx ReadTxn, _ *acl.EnterpriseMeta, peerName string, checks bool) uint64 {
if checks { if checks {
return maxIndexWatchTxn(tx, ws, tableNodes, tableServices, tableChecks) return maxIndexTxn(tx,
peeredIndexEntryName(tableChecks, peerName),
peeredIndexEntryName(tableServices, peerName),
peeredIndexEntryName(tableNodes, peerName),
)
} }
return maxIndexWatchTxn(tx, ws, tableNodes, tableServices) return maxIndexTxn(tx,
peeredIndexEntryName(tableServices, peerName),
peeredIndexEntryName(tableNodes, peerName),
)
} }
func catalogUpdateCheckIndexes(tx WriteTxn, idx uint64, _ *acl.EnterpriseMeta) error { func catalogMaxIndexWatch(tx ReadTxn, ws memdb.WatchSet, _ *acl.EnterpriseMeta, peerName string, checks bool) uint64 {
// TODO(peering_indexes): pipe peerName here
if checks {
return maxIndexWatchTxn(tx, ws,
peeredIndexEntryName(tableChecks, peerName),
peeredIndexEntryName(tableServices, peerName),
peeredIndexEntryName(tableNodes, peerName),
)
}
return maxIndexWatchTxn(tx, ws,
peeredIndexEntryName(tableServices, peerName),
peeredIndexEntryName(tableNodes, peerName),
)
}
func catalogUpdateCheckIndexes(tx WriteTxn, idx uint64, _ *acl.EnterpriseMeta, peerName string) error {
// update the universal index entry // update the universal index entry
if err := tx.Insert(tableIndex, &IndexEntry{tableChecks, idx}); err != nil { if err := indexUpdateMaxTxn(tx, idx, tableChecks); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
if err := indexUpdateMaxTxn(tx, idx, peeredIndexEntryName(tableChecks, peerName)); err != nil {
return fmt.Errorf("failed updating index: %s", err) return fmt.Errorf("failed updating index: %s", err)
} }
return nil return nil
} }
func catalogChecksMaxIndex(tx ReadTxn, _ *acl.EnterpriseMeta) uint64 { func catalogChecksMaxIndex(tx ReadTxn, _ *acl.EnterpriseMeta, peerName string) uint64 {
return maxIndexTxn(tx, tableChecks) return maxIndexTxn(tx, peeredIndexEntryName(tableChecks, peerName))
} }
func catalogListChecksByNode(tx ReadTxn, q Query) (memdb.ResultIterator, error) { func catalogListChecksByNode(tx ReadTxn, q Query) (memdb.ResultIterator, error) {
@ -174,7 +225,7 @@ func catalogInsertCheck(tx WriteTxn, chk *structs.HealthCheck, idx uint64) error
return fmt.Errorf("failed inserting check: %s", err) return fmt.Errorf("failed inserting check: %s", err)
} }
if err := catalogUpdateCheckIndexes(tx, idx, &chk.EnterpriseMeta); err != nil { if err := catalogUpdateCheckIndexes(tx, idx, &chk.EnterpriseMeta, chk.PeerName); err != nil {
return err return err
} }
@ -207,3 +258,10 @@ func indexFromKindServiceName(arg interface{}) ([]byte, error) {
return nil, fmt.Errorf("type must be KindServiceNameQuery or *KindServiceName: %T", arg) return nil, fmt.Errorf("type must be KindServiceNameQuery or *KindServiceName: %T", arg)
} }
} }
func updateKindServiceNamesIndex(tx WriteTxn, idx uint64, kind structs.ServiceKind, entMeta acl.EnterpriseMeta) error {
if err := indexUpdateMaxTxn(tx, idx, kindServiceNameIndexName(kind.Normalized())); err != nil {
return fmt.Errorf("failed updating %s table index: %v", tableKindServiceNames, err)
}
return nil
}

View File

@ -19,6 +19,14 @@ func testIndexerTableChecks() map[string]indexerTestCase {
CheckID: "CheckID", CheckID: "CheckID",
Status: "PASSING", Status: "PASSING",
} }
objWPeer := &structs.HealthCheck{
Node: "NoDe",
ServiceID: "SeRvIcE",
ServiceName: "ServiceName",
CheckID: "CheckID",
Status: "PASSING",
PeerName: "Peer1",
}
return map[string]indexerTestCase{ return map[string]indexerTestCase{
indexID: { indexID: {
read: indexValue{ read: indexValue{
@ -26,11 +34,11 @@ func testIndexerTableChecks() map[string]indexerTestCase {
Node: "NoDe", Node: "NoDe",
CheckID: "CheckId", CheckID: "CheckId",
}, },
expected: []byte("node\x00checkid\x00"), expected: []byte("internal\x00node\x00checkid\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("node\x00checkid\x00"), expected: []byte("internal\x00node\x00checkid\x00"),
}, },
prefix: []indexValue{ prefix: []indexValue{
{ {
@ -39,28 +47,75 @@ func testIndexerTableChecks() map[string]indexerTestCase {
}, },
{ {
source: Query{Value: "nOdE"}, source: Query{Value: "nOdE"},
expected: []byte("node\x00"), expected: []byte("internal\x00node\x00"),
},
},
extra: []indexerTestCase{
{
read: indexValue{
source: NodeCheckQuery{
Node: "NoDe",
CheckID: "CheckId",
PeerName: "Peer1",
},
expected: []byte("peer1\x00node\x00checkid\x00"),
},
write: indexValue{
source: objWPeer,
expected: []byte("peer1\x00node\x00checkid\x00"),
},
prefix: []indexValue{
{
source: Query{Value: "nOdE",
PeerName: "Peer1"},
expected: []byte("peer1\x00node\x00"),
},
},
}, },
}, },
}, },
indexStatus: { indexStatus: {
read: indexValue{ read: indexValue{
source: Query{Value: "PASSING"}, source: Query{Value: "PASSING"},
expected: []byte("passing\x00"), expected: []byte("internal\x00passing\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("passing\x00"), expected: []byte("internal\x00passing\x00"),
},
extra: []indexerTestCase{
{
read: indexValue{
source: Query{Value: "PASSING", PeerName: "Peer1"},
expected: []byte("peer1\x00passing\x00"),
},
write: indexValue{
source: objWPeer,
expected: []byte("peer1\x00passing\x00"),
},
},
}, },
}, },
indexService: { indexService: {
read: indexValue{ read: indexValue{
source: Query{Value: "ServiceName"}, source: Query{Value: "ServiceName"},
expected: []byte("servicename\x00"), expected: []byte("internal\x00servicename\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("servicename\x00"), expected: []byte("internal\x00servicename\x00"),
},
extra: []indexerTestCase{
{
read: indexValue{
source: Query{Value: "ServiceName", PeerName: "Peer1"},
expected: []byte("peer1\x00servicename\x00"),
},
write: indexValue{
source: objWPeer,
expected: []byte("peer1\x00servicename\x00"),
},
},
}, },
}, },
indexNodeService: { indexNodeService: {
@ -69,11 +124,27 @@ func testIndexerTableChecks() map[string]indexerTestCase {
Node: "NoDe", Node: "NoDe",
Service: "SeRvIcE", Service: "SeRvIcE",
}, },
expected: []byte("node\x00service\x00"), expected: []byte("internal\x00node\x00service\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("node\x00service\x00"), expected: []byte("internal\x00node\x00service\x00"),
},
extra: []indexerTestCase{
{
read: indexValue{
source: NodeServiceQuery{
Node: "NoDe",
PeerName: "Peer1",
Service: "SeRvIcE",
},
expected: []byte("peer1\x00node\x00service\x00"),
},
write: indexValue{
source: objWPeer,
expected: []byte("peer1\x00node\x00service\x00"),
},
},
}, },
}, },
indexNode: { indexNode: {
@ -81,11 +152,26 @@ func testIndexerTableChecks() map[string]indexerTestCase {
source: Query{ source: Query{
Value: "NoDe", Value: "NoDe",
}, },
expected: []byte("node\x00"), expected: []byte("internal\x00node\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("node\x00"), expected: []byte("internal\x00node\x00"),
},
extra: []indexerTestCase{
{
read: indexValue{
source: Query{
Value: "NoDe",
PeerName: "Peer1",
},
expected: []byte("peer1\x00node\x00"),
},
write: indexValue{
source: objWPeer,
expected: []byte("peer1\x00node\x00"),
},
},
}, },
}, },
} }
@ -186,11 +272,11 @@ func testIndexerTableNodes() map[string]indexerTestCase {
indexID: { indexID: {
read: indexValue{ read: indexValue{
source: Query{Value: "NoDeId"}, source: Query{Value: "NoDeId"},
expected: []byte("nodeid\x00"), expected: []byte("internal\x00nodeid\x00"),
}, },
write: indexValue{ write: indexValue{
source: &structs.Node{Node: "NoDeId"}, source: &structs.Node{Node: "NoDeId"},
expected: []byte("nodeid\x00"), expected: []byte("internal\x00nodeid\x00"),
}, },
prefix: []indexValue{ prefix: []indexValue{
{ {
@ -203,38 +289,90 @@ func testIndexerTableNodes() map[string]indexerTestCase {
}, },
{ {
source: Query{Value: "NoDeId"}, source: Query{Value: "NoDeId"},
expected: []byte("nodeid\x00"), expected: []byte("internal\x00nodeid\x00"),
},
{
source: Query{},
expected: []byte("internal\x00"),
},
},
extra: []indexerTestCase{
{
read: indexValue{
source: Query{Value: "NoDeId", PeerName: "Peer1"},
expected: []byte("peer1\x00nodeid\x00"),
},
write: indexValue{
source: &structs.Node{Node: "NoDeId", PeerName: "Peer1"},
expected: []byte("peer1\x00nodeid\x00"),
},
prefix: []indexValue{
{
source: Query{PeerName: "Peer1"},
expected: []byte("peer1\x00"),
},
{
source: Query{Value: "NoDeId", PeerName: "Peer1"},
expected: []byte("peer1\x00nodeid\x00"),
},
},
}, },
}, },
}, },
indexUUID: { indexUUID: {
read: indexValue{ read: indexValue{
source: Query{Value: uuid}, source: Query{Value: uuid},
expected: uuidBuf, expected: append([]byte("internal\x00"), uuidBuf...),
}, },
write: indexValue{ write: indexValue{
source: &structs.Node{ source: &structs.Node{
ID: types.NodeID(uuid), ID: types.NodeID(uuid),
Node: "NoDeId", Node: "NoDeId",
}, },
expected: uuidBuf, expected: append([]byte("internal\x00"), uuidBuf...),
}, },
prefix: []indexValue{ prefix: []indexValue{
{
source: (*acl.EnterpriseMeta)(nil),
expected: nil,
},
{
source: acl.EnterpriseMeta{},
expected: nil,
},
{ // partial length { // partial length
source: Query{Value: uuid[:6]}, source: Query{Value: uuid[:6]},
expected: uuidBuf[:3], expected: append([]byte("internal\x00"), uuidBuf[:3]...),
}, },
{ // full length { // full length
source: Query{Value: uuid}, source: Query{Value: uuid},
expected: uuidBuf, expected: append([]byte("internal\x00"), uuidBuf...),
},
{
source: Query{},
expected: []byte("internal\x00"),
},
},
extra: []indexerTestCase{
{
read: indexValue{
source: Query{Value: uuid, PeerName: "Peer1"},
expected: append([]byte("peer1\x00"), uuidBuf...),
},
write: indexValue{
source: &structs.Node{
ID: types.NodeID(uuid),
PeerName: "Peer1",
Node: "NoDeId",
},
expected: append([]byte("peer1\x00"), uuidBuf...),
},
prefix: []indexValue{
{ // partial length
source: Query{Value: uuid[:6], PeerName: "Peer1"},
expected: append([]byte("peer1\x00"), uuidBuf[:3]...),
},
{ // full length
source: Query{Value: uuid, PeerName: "Peer1"},
expected: append([]byte("peer1\x00"), uuidBuf...),
},
{
source: Query{PeerName: "Peer1"},
expected: []byte("peer1\x00"),
},
},
}, },
}, },
}, },
@ -244,7 +382,7 @@ func testIndexerTableNodes() map[string]indexerTestCase {
Key: "KeY", Key: "KeY",
Value: "VaLuE", Value: "VaLuE",
}, },
expected: []byte("KeY\x00VaLuE\x00"), expected: []byte("internal\x00KeY\x00VaLuE\x00"),
}, },
writeMulti: indexValueMulti{ writeMulti: indexValueMulti{
source: &structs.Node{ source: &structs.Node{
@ -255,8 +393,34 @@ func testIndexerTableNodes() map[string]indexerTestCase {
}, },
}, },
expected: [][]byte{ expected: [][]byte{
[]byte("MaP-kEy-1\x00mAp-VaL-1\x00"), []byte("internal\x00MaP-kEy-1\x00mAp-VaL-1\x00"),
[]byte("mAp-KeY-2\x00MaP-vAl-2\x00"), []byte("internal\x00mAp-KeY-2\x00MaP-vAl-2\x00"),
},
},
extra: []indexerTestCase{
{
read: indexValue{
source: KeyValueQuery{
Key: "KeY",
Value: "VaLuE",
PeerName: "Peer1",
},
expected: []byte("peer1\x00KeY\x00VaLuE\x00"),
},
writeMulti: indexValueMulti{
source: &structs.Node{
Node: "NoDeId",
Meta: map[string]string{
"MaP-kEy-1": "mAp-VaL-1",
"mAp-KeY-2": "MaP-vAl-2",
},
PeerName: "Peer1",
},
expected: [][]byte{
[]byte("peer1\x00MaP-kEy-1\x00mAp-VaL-1\x00"),
[]byte("peer1\x00mAp-KeY-2\x00MaP-vAl-2\x00"),
},
},
}, },
}, },
}, },
@ -271,6 +435,12 @@ func testIndexerTableServices() map[string]indexerTestCase {
ServiceID: "SeRviCe", ServiceID: "SeRviCe",
ServiceName: "ServiceName", ServiceName: "ServiceName",
} }
objWPeer := &structs.ServiceNode{
Node: "NoDeId",
ServiceID: "SeRviCe",
ServiceName: "ServiceName",
PeerName: "Peer1",
}
return map[string]indexerTestCase{ return map[string]indexerTestCase{
indexID: { indexID: {
@ -279,11 +449,11 @@ func testIndexerTableServices() map[string]indexerTestCase {
Node: "NoDeId", Node: "NoDeId",
Service: "SeRvIcE", Service: "SeRvIcE",
}, },
expected: []byte("nodeid\x00service\x00"), expected: []byte("internal\x00nodeid\x00service\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("nodeid\x00service\x00"), expected: []byte("internal\x00nodeid\x00service\x00"),
}, },
prefix: []indexValue{ prefix: []indexValue{
{ {
@ -294,9 +464,39 @@ func testIndexerTableServices() map[string]indexerTestCase {
source: acl.EnterpriseMeta{}, source: acl.EnterpriseMeta{},
expected: nil, expected: nil,
}, },
{
source: Query{},
expected: []byte("internal\x00"),
},
{ {
source: Query{Value: "NoDeId"}, source: Query{Value: "NoDeId"},
expected: []byte("nodeid\x00"), expected: []byte("internal\x00nodeid\x00"),
},
},
extra: []indexerTestCase{
{
read: indexValue{
source: NodeServiceQuery{
Node: "NoDeId",
PeerName: "Peer1",
Service: "SeRvIcE",
},
expected: []byte("peer1\x00nodeid\x00service\x00"),
},
write: indexValue{
source: objWPeer,
expected: []byte("peer1\x00nodeid\x00service\x00"),
},
prefix: []indexValue{
{
source: Query{Value: "NoDeId", PeerName: "Peer1"},
expected: []byte("peer1\x00nodeid\x00"),
},
{
source: Query{PeerName: "Peer1"},
expected: []byte("peer1\x00"),
},
},
}, },
}, },
}, },
@ -305,34 +505,61 @@ func testIndexerTableServices() map[string]indexerTestCase {
source: Query{ source: Query{
Value: "NoDeId", Value: "NoDeId",
}, },
expected: []byte("nodeid\x00"), expected: []byte("internal\x00nodeid\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("nodeid\x00"), expected: []byte("internal\x00nodeid\x00"),
},
extra: []indexerTestCase{
{
read: indexValue{
source: Query{
Value: "NoDeId",
PeerName: "Peer1",
},
expected: []byte("peer1\x00nodeid\x00"),
},
write: indexValue{
source: objWPeer,
expected: []byte("peer1\x00nodeid\x00"),
},
},
}, },
}, },
indexService: { indexService: {
read: indexValue{ read: indexValue{
source: Query{Value: "ServiceName"}, source: Query{Value: "ServiceName"},
expected: []byte("servicename\x00"), expected: []byte("internal\x00servicename\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("servicename\x00"), expected: []byte("internal\x00servicename\x00"),
},
extra: []indexerTestCase{
{
read: indexValue{
source: Query{Value: "ServiceName", PeerName: "Peer1"},
expected: []byte("peer1\x00servicename\x00"),
},
write: indexValue{
source: objWPeer,
expected: []byte("peer1\x00servicename\x00"),
},
},
}, },
}, },
indexConnect: { indexConnect: {
read: indexValue{ read: indexValue{
source: Query{Value: "ConnectName"}, source: Query{Value: "ConnectName"},
expected: []byte("connectname\x00"), expected: []byte("internal\x00connectname\x00"),
}, },
write: indexValue{ write: indexValue{
source: &structs.ServiceNode{ source: &structs.ServiceNode{
ServiceName: "ConnectName", ServiceName: "ConnectName",
ServiceConnect: structs.ServiceConnect{Native: true}, ServiceConnect: structs.ServiceConnect{Native: true},
}, },
expected: []byte("connectname\x00"), expected: []byte("internal\x00connectname\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -344,7 +571,20 @@ func testIndexerTableServices() map[string]indexerTestCase {
DestinationServiceName: "ConnectName", DestinationServiceName: "ConnectName",
}, },
}, },
expected: []byte("connectname\x00"), expected: []byte("internal\x00connectname\x00"),
},
},
{
write: indexValue{
source: &structs.ServiceNode{
ServiceName: "ServiceName",
ServiceKind: structs.ServiceKindConnectProxy,
ServiceProxy: structs.ConnectProxyConfig{
DestinationServiceName: "ConnectName",
},
PeerName: "Peer1",
},
expected: []byte("peer1\x00connectname\x00"),
}, },
}, },
{ {
@ -362,18 +602,32 @@ func testIndexerTableServices() map[string]indexerTestCase {
expectedIndexMissing: true, expectedIndexMissing: true,
}, },
}, },
{
read: indexValue{
source: Query{Value: "ConnectName", PeerName: "Peer1"},
expected: []byte("peer1\x00connectname\x00"),
},
write: indexValue{
source: &structs.ServiceNode{
ServiceName: "ConnectName",
ServiceConnect: structs.ServiceConnect{Native: true},
PeerName: "Peer1",
},
expected: []byte("peer1\x00connectname\x00"),
},
},
}, },
}, },
indexKind: { indexKind: {
read: indexValue{ read: indexValue{
source: Query{Value: "connect-proxy"}, source: Query{Value: "connect-proxy"},
expected: []byte("connect-proxy\x00"), expected: []byte("internal\x00connect-proxy\x00"),
}, },
write: indexValue{ write: indexValue{
source: &structs.ServiceNode{ source: &structs.ServiceNode{
ServiceKind: structs.ServiceKindConnectProxy, ServiceKind: structs.ServiceKindConnectProxy,
}, },
expected: []byte("connect-proxy\x00"), expected: []byte("internal\x00connect-proxy\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -382,7 +636,30 @@ func testIndexerTableServices() map[string]indexerTestCase {
ServiceName: "ServiceName", ServiceName: "ServiceName",
ServiceKind: structs.ServiceKindTypical, ServiceKind: structs.ServiceKindTypical,
}, },
expected: []byte("\x00"), expected: []byte("internal\x00\x00"),
},
},
{
write: indexValue{
source: &structs.ServiceNode{
ServiceName: "ServiceName",
ServiceKind: structs.ServiceKindTypical,
PeerName: "Peer1",
},
expected: []byte("peer1\x00\x00"),
},
},
{
read: indexValue{
source: Query{Value: "connect-proxy", PeerName: "Peer1"},
expected: []byte("peer1\x00connect-proxy\x00"),
},
write: indexValue{
source: &structs.ServiceNode{
ServiceKind: structs.ServiceKindConnectProxy,
PeerName: "Peer1",
},
expected: []byte("peer1\x00connect-proxy\x00"),
}, },
}, },
}, },
@ -440,7 +717,7 @@ func testIndexerTableKindServiceNames() map[string]indexerTestCase {
}, },
indexKind: { indexKind: {
read: indexValue{ read: indexValue{
source: structs.ServiceKindConnectProxy, source: Query{Value: string(structs.ServiceKindConnectProxy)},
expected: []byte("connect-proxy\x00"), expected: []byte("connect-proxy\x00"),
}, },
write: indexValue{ write: indexValue{

View File

@ -48,9 +48,9 @@ func nodesTableSchema() *memdb.TableSchema {
AllowMissing: false, AllowMissing: false,
Unique: true, Unique: true,
Indexer: indexerSingleWithPrefix{ Indexer: indexerSingleWithPrefix{
readIndex: indexFromQuery, readIndex: indexWithPeerName(indexFromQuery),
writeIndex: indexFromNode, writeIndex: indexWithPeerName(indexFromNode),
prefixIndex: prefixIndexFromQueryNoNamespace, prefixIndex: prefixIndexFromQueryWithPeer,
}, },
}, },
indexUUID: { indexUUID: {
@ -58,9 +58,9 @@ func nodesTableSchema() *memdb.TableSchema {
AllowMissing: true, AllowMissing: true,
Unique: true, Unique: true,
Indexer: indexerSingleWithPrefix{ Indexer: indexerSingleWithPrefix{
readIndex: indexFromUUIDQuery, readIndex: indexWithPeerName(indexFromUUIDQuery),
writeIndex: indexIDFromNode, writeIndex: indexWithPeerName(indexIDFromNode),
prefixIndex: prefixIndexFromUUIDQuery, prefixIndex: prefixIndexFromUUIDWithPeerQuery,
}, },
}, },
indexMeta: { indexMeta: {
@ -68,8 +68,8 @@ func nodesTableSchema() *memdb.TableSchema {
AllowMissing: true, AllowMissing: true,
Unique: false, Unique: false,
Indexer: indexerMulti{ Indexer: indexerMulti{
readIndex: indexFromKeyValueQuery, readIndex: indexWithPeerName(indexFromKeyValueQuery),
writeIndexMulti: indexMetaFromNode, writeIndexMulti: multiIndexWithPeerName(indexMetaFromNode),
}, },
}, },
}, },
@ -146,9 +146,9 @@ func servicesTableSchema() *memdb.TableSchema {
AllowMissing: false, AllowMissing: false,
Unique: true, Unique: true,
Indexer: indexerSingleWithPrefix{ Indexer: indexerSingleWithPrefix{
readIndex: indexFromNodeServiceQuery, readIndex: indexWithPeerName(indexFromNodeServiceQuery),
writeIndex: indexFromServiceNode, writeIndex: indexWithPeerName(indexFromServiceNode),
prefixIndex: prefixIndexFromQuery, prefixIndex: prefixIndexFromQueryWithPeer,
}, },
}, },
indexNode: { indexNode: {
@ -156,8 +156,8 @@ func servicesTableSchema() *memdb.TableSchema {
AllowMissing: false, AllowMissing: false,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
readIndex: indexFromQuery, readIndex: indexWithPeerName(indexFromQuery),
writeIndex: indexFromNodeIdentity, writeIndex: indexWithPeerName(indexFromNodeIdentity),
}, },
}, },
indexService: { indexService: {
@ -165,8 +165,8 @@ func servicesTableSchema() *memdb.TableSchema {
AllowMissing: true, AllowMissing: true,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
readIndex: indexFromQuery, readIndex: indexWithPeerName(indexFromQuery),
writeIndex: indexServiceNameFromServiceNode, writeIndex: indexWithPeerName(indexServiceNameFromServiceNode),
}, },
}, },
indexConnect: { indexConnect: {
@ -174,8 +174,8 @@ func servicesTableSchema() *memdb.TableSchema {
AllowMissing: true, AllowMissing: true,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
readIndex: indexFromQuery, readIndex: indexWithPeerName(indexFromQuery),
writeIndex: indexConnectNameFromServiceNode, writeIndex: indexWithPeerName(indexConnectNameFromServiceNode),
}, },
}, },
indexKind: { indexKind: {
@ -183,8 +183,8 @@ func servicesTableSchema() *memdb.TableSchema {
AllowMissing: false, AllowMissing: false,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
readIndex: indexFromQuery, readIndex: indexWithPeerName(indexFromQuery),
writeIndex: indexKindFromServiceNode, writeIndex: indexWithPeerName(indexKindFromServiceNode),
}, },
}, },
}, },
@ -295,6 +295,61 @@ func indexKindFromServiceNode(raw interface{}) ([]byte, error) {
return b.Bytes(), nil return b.Bytes(), nil
} }
// indexWithPeerName adds peer name to the index.
func indexWithPeerName(
fn func(interface{}) ([]byte, error),
) func(interface{}) ([]byte, error) {
return func(raw interface{}) ([]byte, error) {
v, err := fn(raw)
if err != nil {
return nil, err
}
n, ok := raw.(peerIndexable)
if !ok {
return nil, fmt.Errorf("type must be peerIndexable: %T", raw)
}
peername := n.PeerOrEmpty()
if peername == "" {
peername = structs.LocalPeerKeyword
}
b := newIndexBuilder(len(v) + len(peername) + 1)
b.String(strings.ToLower(peername))
b.Raw(v)
return b.Bytes(), nil
}
}
// multiIndexWithPeerName adds peer name to multiple indices, and returns multiple indices.
func multiIndexWithPeerName(
fn func(interface{}) ([][]byte, error),
) func(interface{}) ([][]byte, error) {
return func(raw interface{}) ([][]byte, error) {
results, err := fn(raw)
if err != nil {
return nil, err
}
n, ok := raw.(peerIndexable)
if !ok {
return nil, fmt.Errorf("type must be peerIndexable: %T", raw)
}
peername := n.PeerOrEmpty()
if peername == "" {
peername = structs.LocalPeerKeyword
}
for i, v := range results {
b := newIndexBuilder(len(v) + len(peername) + 1)
b.String(strings.ToLower(peername))
b.Raw(v)
results[i] = b.Bytes()
}
return results, nil
}
}
// checksTableSchema returns a new table schema used for storing and indexing // checksTableSchema returns a new table schema used for storing and indexing
// health check information. Health checks have a number of different attributes // health check information. Health checks have a number of different attributes
// we want to filter by, so this table is a bit more complex. // we want to filter by, so this table is a bit more complex.
@ -307,9 +362,9 @@ func checksTableSchema() *memdb.TableSchema {
AllowMissing: false, AllowMissing: false,
Unique: true, Unique: true,
Indexer: indexerSingleWithPrefix{ Indexer: indexerSingleWithPrefix{
readIndex: indexFromNodeCheckQuery, readIndex: indexWithPeerName(indexFromNodeCheckQuery),
writeIndex: indexFromHealthCheck, writeIndex: indexWithPeerName(indexFromHealthCheck),
prefixIndex: prefixIndexFromQuery, prefixIndex: prefixIndexFromQueryWithPeer,
}, },
}, },
indexStatus: { indexStatus: {
@ -317,8 +372,8 @@ func checksTableSchema() *memdb.TableSchema {
AllowMissing: false, AllowMissing: false,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
readIndex: indexFromQuery, readIndex: indexWithPeerName(indexFromQuery),
writeIndex: indexStatusFromHealthCheck, writeIndex: indexWithPeerName(indexStatusFromHealthCheck),
}, },
}, },
indexService: { indexService: {
@ -326,8 +381,8 @@ func checksTableSchema() *memdb.TableSchema {
AllowMissing: true, AllowMissing: true,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
readIndex: indexFromQuery, readIndex: indexWithPeerName(indexFromQuery),
writeIndex: indexServiceNameFromHealthCheck, writeIndex: indexWithPeerName(indexServiceNameFromHealthCheck),
}, },
}, },
indexNode: { indexNode: {
@ -335,8 +390,8 @@ func checksTableSchema() *memdb.TableSchema {
AllowMissing: true, AllowMissing: true,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
readIndex: indexFromQuery, readIndex: indexWithPeerName(indexFromQuery),
writeIndex: indexFromNodeIdentity, writeIndex: indexWithPeerName(indexFromNodeIdentity),
}, },
}, },
indexNodeService: { indexNodeService: {
@ -344,8 +399,8 @@ func checksTableSchema() *memdb.TableSchema {
AllowMissing: true, AllowMissing: true,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
readIndex: indexFromNodeServiceQuery, readIndex: indexWithPeerName(indexFromNodeServiceQuery),
writeIndex: indexNodeServiceFromHealthCheck, writeIndex: indexWithPeerName(indexNodeServiceFromHealthCheck),
}, },
}, },
}, },
@ -588,11 +643,20 @@ type upstreamDownstream struct {
// NodeCheckQuery is used to query the ID index of the checks table. // NodeCheckQuery is used to query the ID index of the checks table.
type NodeCheckQuery struct { type NodeCheckQuery struct {
Node string Node string
CheckID string CheckID string
PeerName string
acl.EnterpriseMeta acl.EnterpriseMeta
} }
type peerIndexable interface {
PeerOrEmpty() string
}
func (q NodeCheckQuery) PeerOrEmpty() string {
return q.PeerName
}
// NamespaceOrDefault exists because structs.EnterpriseMeta uses a pointer // NamespaceOrDefault exists because structs.EnterpriseMeta uses a pointer
// receiver for this method. Remove once that is fixed. // receiver for this method. Remove once that is fixed.
func (q NodeCheckQuery) NamespaceOrDefault() string { func (q NodeCheckQuery) NamespaceOrDefault() string {
@ -680,7 +744,16 @@ type KindServiceName struct {
structs.RaftIndex structs.RaftIndex
} }
func (n *KindServiceName) PartitionOrDefault() string {
return n.Service.PartitionOrDefault()
}
func (n *KindServiceName) NamespaceOrDefault() string {
return n.Service.NamespaceOrDefault()
}
func kindServiceNameTableSchema() *memdb.TableSchema { func kindServiceNameTableSchema() *memdb.TableSchema {
// TODO(peering): make this peer-aware
return &memdb.TableSchema{ return &memdb.TableSchema{
Name: tableKindServiceNames, Name: tableKindServiceNames,
Indexes: map[string]*memdb.IndexSchema{ Indexes: map[string]*memdb.IndexSchema{
@ -693,8 +766,8 @@ func kindServiceNameTableSchema() *memdb.TableSchema {
writeIndex: indexFromKindServiceName, writeIndex: indexFromKindServiceName,
}, },
}, },
indexKindOnly: { indexKind: {
Name: indexKindOnly, Name: indexKind,
AllowMissing: false, AllowMissing: false,
Unique: false, Unique: false,
Indexer: indexerSingle{ Indexer: indexerSingle{
@ -732,20 +805,20 @@ func indexFromKindServiceNameKindOnly(raw interface{}) ([]byte, error) {
b.String(strings.ToLower(string(x.Kind))) b.String(strings.ToLower(string(x.Kind)))
return b.Bytes(), nil return b.Bytes(), nil
case structs.ServiceKind: case Query:
var b indexBuilder var b indexBuilder
b.String(strings.ToLower(string(x))) b.String(strings.ToLower(x.Value))
return b.Bytes(), nil return b.Bytes(), nil
default: default:
return nil, fmt.Errorf("type must be *KindServiceName or structs.ServiceKind: %T", raw) return nil, fmt.Errorf("type must be *KindServiceName or Query: %T", raw)
} }
} }
func kindServiceNamesMaxIndex(tx ReadTxn, ws memdb.WatchSet, kind structs.ServiceKind) uint64 { func kindServiceNamesMaxIndex(tx ReadTxn, ws memdb.WatchSet, kind string) uint64 {
return maxIndexWatchTxn(tx, ws, kindServiceNameIndexName(kind)) return maxIndexWatchTxn(tx, ws, kindServiceNameIndexName(kind))
} }
func kindServiceNameIndexName(kind structs.ServiceKind) string { func kindServiceNameIndexName(kind string) string {
return "kind_service_names." + kind.Normalized() return "kind_service_names." + kind
} }

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,7 @@ import (
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbsubscribe"
) )
// EventTopicCARoots is the streaming topic to which events will be published // EventTopicCARoots is the streaming topic to which events will be published
@ -29,6 +30,10 @@ func (e EventPayloadCARoots) HasReadPermission(authz acl.Authorizer) bool {
return authz.ServiceWriteAny(&authzContext) == acl.Allow return authz.ServiceWriteAny(&authzContext) == acl.Allow
} }
func (e EventPayloadCARoots) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
panic("EventPayloadCARoots does not implement ToSubscriptionEvent")
}
// caRootsChangeEvents returns an event on EventTopicCARoots whenever the list // caRootsChangeEvents returns an event on EventTopicCARoots whenever the list
// of active CA Roots changes. // of active CA Roots changes.
func caRootsChangeEvents(tx ReadTxn, changes Changes) ([]stream.Event, error) { func caRootsChangeEvents(tx ReadTxn, changes Changes) ([]stream.Event, error) {

View File

@ -181,7 +181,7 @@ func TestStateStore_Coordinate_Cleanup(t *testing.T) {
require.Equal(t, expected, coords) require.Equal(t, expected, coords)
// Now delete the node. // Now delete the node.
require.NoError(t, s.DeleteNode(3, "node1", nil)) require.NoError(t, s.DeleteNode(3, "node1", nil, ""))
// Make sure the coordinate is gone. // Make sure the coordinate is gone.
_, coords, err = s.Coordinate(nil, "node1", nil) _, coords, err = s.Coordinate(nil, "node1", nil)

View File

@ -997,8 +997,9 @@ func (s *Store) intentionTopologyTxn(tx ReadTxn, ws memdb.WatchSet,
// TODO(tproxy): One remaining improvement is that this includes non-Connect services (typical services without a proxy) // TODO(tproxy): One remaining improvement is that this includes non-Connect services (typical services without a proxy)
// Ideally those should be excluded as well, since they can't be upstreams/downstreams without a proxy. // Ideally those should be excluded as well, since they can't be upstreams/downstreams without a proxy.
// Maybe narrow serviceNamesOfKindTxn to services represented by proxies? (ingress, sidecar-proxy, terminating) // Maybe narrow serviceNamesOfKindTxn to services represented by proxies? (ingress, sidecar-
index, services, err := serviceNamesOfKindTxn(tx, ws, structs.ServiceKindTypical) wildcardMeta := structs.WildcardEnterpriseMetaInPartition(structs.WildcardSpecifier)
index, services, err := serviceNamesOfKindTxn(tx, ws, structs.ServiceKindTypical, *wildcardMeta)
if err != nil { if err != nil {
return index, nil, fmt.Errorf("failed to list ingress service names: %v", err) return index, nil, fmt.Errorf("failed to list ingress service names: %v", err)
} }
@ -1008,7 +1009,7 @@ func (s *Store) intentionTopologyTxn(tx ReadTxn, ws memdb.WatchSet,
if downstreams { if downstreams {
// Ingress gateways can only ever be downstreams, since mesh services don't dial them. // Ingress gateways can only ever be downstreams, since mesh services don't dial them.
index, ingress, err := serviceNamesOfKindTxn(tx, ws, structs.ServiceKindIngressGateway) index, ingress, err := serviceNamesOfKindTxn(tx, ws, structs.ServiceKindIngressGateway, *wildcardMeta)
if err != nil { if err != nil {
return index, nil, fmt.Errorf("failed to list ingress service names: %v", err) return index, nil, fmt.Errorf("failed to list ingress service names: %v", err)
} }

View File

@ -0,0 +1,486 @@
package state
import (
"fmt"
"github.com/golang/protobuf/proto"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
)
const (
tablePeering = "peering"
tablePeeringTrustBundles = "peering-trust-bundles"
)
func peeringTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: tablePeering,
Indexes: map[string]*memdb.IndexSchema{
indexID: {
Name: indexID,
AllowMissing: false,
Unique: true,
Indexer: indexerSingle{
readIndex: readIndex(indexFromUUIDString),
writeIndex: writeIndex(indexIDFromPeering),
},
},
indexName: {
Name: indexName,
AllowMissing: false,
Unique: true,
Indexer: indexerSingleWithPrefix{
readIndex: indexPeeringFromQuery,
writeIndex: indexFromPeering,
prefixIndex: prefixIndexFromQueryNoNamespace,
},
},
},
}
}
func peeringTrustBundlesTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: tablePeeringTrustBundles,
Indexes: map[string]*memdb.IndexSchema{
indexID: {
Name: indexID,
AllowMissing: false,
Unique: true,
Indexer: indexerSingle{
readIndex: indexPeeringFromQuery, // same as peering table since we'll use the query.Value
writeIndex: indexFromPeeringTrustBundle,
},
},
},
}
}
func indexIDFromPeering(raw interface{}) ([]byte, error) {
p, ok := raw.(*pbpeering.Peering)
if !ok {
return nil, fmt.Errorf("unexpected type %T for pbpeering.Peering index", raw)
}
if p.ID == "" {
return nil, errMissingValueForIndex
}
uuid, err := uuidStringToBytes(p.ID)
if err != nil {
return nil, err
}
var b indexBuilder
b.Raw(uuid)
return b.Bytes(), nil
}
func (s *Store) PeeringReadByID(ws memdb.WatchSet, id string) (uint64, *pbpeering.Peering, error) {
tx := s.db.ReadTxn()
defer tx.Abort()
peering, err := peeringReadByIDTxn(ws, tx, id)
if err != nil {
return 0, nil, fmt.Errorf("failed to read peering by id: %w", err)
}
if peering == nil {
// Return the tables index so caller can watch it for changes if the peering doesn't exist
return maxIndexWatchTxn(tx, ws, tablePeering), nil, nil
}
return peering.ModifyIndex, peering, nil
}
func (s *Store) PeeringRead(ws memdb.WatchSet, q Query) (uint64, *pbpeering.Peering, error) {
tx := s.db.ReadTxn()
defer tx.Abort()
watchCh, peeringRaw, err := tx.FirstWatch(tablePeering, indexName, q)
if err != nil {
return 0, nil, fmt.Errorf("failed peering lookup: %w", err)
}
peering, ok := peeringRaw.(*pbpeering.Peering)
if peering != nil && !ok {
return 0, nil, fmt.Errorf("invalid type %T", peering)
}
ws.Add(watchCh)
if peering == nil {
// Return the tables index so caller can watch it for changes if the peering doesn't exist
return maxIndexWatchTxn(tx, ws, partitionedIndexEntryName(tablePeering, q.PartitionOrDefault())), nil, nil
}
return peering.ModifyIndex, peering, nil
}
func peeringReadByIDTxn(ws memdb.WatchSet, tx ReadTxn, id string) (*pbpeering.Peering, error) {
watchCh, peeringRaw, err := tx.FirstWatch(tablePeering, indexID, id)
if err != nil {
return nil, fmt.Errorf("failed peering lookup: %w", err)
}
ws.Add(watchCh)
peering, ok := peeringRaw.(*pbpeering.Peering)
if peering != nil && !ok {
return nil, fmt.Errorf("invalid type %T", peering)
}
return peering, nil
}
func (s *Store) PeeringList(ws memdb.WatchSet, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.Peering, error) {
tx := s.db.ReadTxn()
defer tx.Abort()
var (
iter memdb.ResultIterator
err error
idx uint64
)
if entMeta.PartitionOrDefault() == structs.WildcardSpecifier {
iter, err = tx.Get(tablePeering, indexID)
idx = maxIndexWatchTxn(tx, ws, tablePeering)
} else {
iter, err = tx.Get(tablePeering, indexName+"_prefix", entMeta)
idx = maxIndexWatchTxn(tx, ws, partitionedIndexEntryName(tablePeering, entMeta.PartitionOrDefault()))
}
if err != nil {
return 0, nil, fmt.Errorf("failed peering lookup: %v", err)
}
var result []*pbpeering.Peering
for entry := iter.Next(); entry != nil; entry = iter.Next() {
result = append(result, entry.(*pbpeering.Peering))
}
return idx, result, nil
}
func generatePeeringUUID(tx ReadTxn) (string, error) {
for {
uuid, err := uuid.GenerateUUID()
if err != nil {
return "", fmt.Errorf("failed to generate UUID: %w", err)
}
existing, err := peeringReadByIDTxn(nil, tx, uuid)
if err != nil {
return "", fmt.Errorf("failed to read peering: %w", err)
}
if existing == nil {
return uuid, nil
}
}
}
func (s *Store) PeeringWrite(idx uint64, p *pbpeering.Peering) error {
tx := s.db.WriteTxn(idx)
defer tx.Abort()
q := Query{
Value: p.Name,
EnterpriseMeta: *structs.NodeEnterpriseMetaInPartition(p.Partition),
}
existingRaw, err := tx.First(tablePeering, indexName, q)
if err != nil {
return fmt.Errorf("failed peering lookup: %w", err)
}
existing, ok := existingRaw.(*pbpeering.Peering)
if existingRaw != nil && !ok {
return fmt.Errorf("invalid type %T", existingRaw)
}
if existing != nil {
p.CreateIndex = existing.CreateIndex
p.ID = existing.ID
} else {
// TODO(peering): consider keeping PeeringState enum elsewhere?
p.State = pbpeering.PeeringState_INITIAL
p.CreateIndex = idx
p.ID, err = generatePeeringUUID(tx)
if err != nil {
return fmt.Errorf("failed to generate peering id: %w", err)
}
}
p.ModifyIndex = idx
if err := tx.Insert(tablePeering, p); err != nil {
return fmt.Errorf("failed inserting peering: %w", err)
}
if err := updatePeeringTableIndexes(tx, idx, p.PartitionOrDefault()); err != nil {
return err
}
return tx.Commit()
}
// TODO(peering): replace with deferred deletion since this operation
// should involve cleanup of data associated with the peering.
func (s *Store) PeeringDelete(idx uint64, q Query) error {
tx := s.db.WriteTxn(idx)
defer tx.Abort()
existing, err := tx.First(tablePeering, indexName, q)
if err != nil {
return fmt.Errorf("failed peering lookup: %v", err)
}
if existing == nil {
return nil
}
if err := tx.Delete(tablePeering, existing); err != nil {
return fmt.Errorf("failed deleting peering: %v", err)
}
if err := updatePeeringTableIndexes(tx, idx, q.PartitionOrDefault()); err != nil {
return err
}
return tx.Commit()
}
func (s *Store) PeeringTerminateByID(idx uint64, id string) error {
tx := s.db.WriteTxn(idx)
defer tx.Abort()
existing, err := peeringReadByIDTxn(nil, tx, id)
if err != nil {
return fmt.Errorf("failed to read peering %q: %w", id, err)
}
if existing == nil {
return nil
}
c := proto.Clone(existing)
clone, ok := c.(*pbpeering.Peering)
if !ok {
return fmt.Errorf("invalid type %T, expected *pbpeering.Peering", existing)
}
clone.State = pbpeering.PeeringState_TERMINATED
clone.ModifyIndex = idx
if err := tx.Insert(tablePeering, clone); err != nil {
return fmt.Errorf("failed inserting peering: %w", err)
}
if err := updatePeeringTableIndexes(tx, idx, clone.PartitionOrDefault()); err != nil {
return err
}
return tx.Commit()
}
// ExportedServicesForPeer returns the list of typical and proxy services exported to a peer.
// TODO(peering): What to do about terminating gateways? Sometimes terminating gateways are the appropriate destination
// to dial for an upstream mesh service. However, that information is handled by observing the terminating gateway's
// config entry, which we wouldn't want to replicate. How would client peers know to route through terminating gateways
// when they're not dialing through a remote mesh gateway?
func (s *Store) ExportedServicesForPeer(ws memdb.WatchSet, peerID string) (uint64, []structs.ServiceName, error) {
tx := s.db.ReadTxn()
defer tx.Abort()
peering, err := peeringReadByIDTxn(ws, tx, peerID)
if err != nil {
return 0, nil, fmt.Errorf("failed to read peering: %w", err)
}
if peering == nil {
return 0, nil, nil
}
maxIdx := peering.ModifyIndex
entMeta := structs.NodeEnterpriseMetaInPartition(peering.Partition)
idx, raw, err := configEntryTxn(tx, ws, structs.ExportedServices, entMeta.PartitionOrDefault(), entMeta)
if err != nil {
return 0, nil, fmt.Errorf("failed to fetch exported-services config entry: %w", err)
}
if idx > maxIdx {
maxIdx = idx
}
if raw == nil {
return maxIdx, nil, nil
}
conf, ok := raw.(*structs.ExportedServicesConfigEntry)
if !ok {
return 0, nil, fmt.Errorf("expected type *structs.ExportedServicesConfigEntry, got %T", raw)
}
set := make(map[structs.ServiceName]struct{})
for _, svc := range conf.Services {
svcMeta := acl.NewEnterpriseMetaWithPartition(entMeta.PartitionOrDefault(), svc.Namespace)
sawPeer := false
for _, consumer := range svc.Consumers {
name := structs.NewServiceName(svc.Name, &svcMeta)
if _, ok := set[name]; ok {
// Service was covered by a wildcard that was already accounted for
continue
}
if consumer.PeerName != peering.Name {
continue
}
sawPeer = true
if svc.Name != structs.WildcardSpecifier {
set[name] = struct{}{}
}
}
// If the target peer is a consumer, and all services in the namespace are exported, query those service names.
if sawPeer && svc.Name == structs.WildcardSpecifier {
var typicalServices []*KindServiceName
idx, typicalServices, err = serviceNamesOfKindTxn(tx, ws, structs.ServiceKindTypical, svcMeta)
if err != nil {
return 0, nil, fmt.Errorf("failed to get service names: %w", err)
}
if idx > maxIdx {
maxIdx = idx
}
for _, s := range typicalServices {
set[s.Service] = struct{}{}
}
var proxyServices []*KindServiceName
idx, proxyServices, err = serviceNamesOfKindTxn(tx, ws, structs.ServiceKindConnectProxy, svcMeta)
if err != nil {
return 0, nil, fmt.Errorf("failed to get service names: %w", err)
}
if idx > maxIdx {
maxIdx = idx
}
for _, s := range proxyServices {
set[s.Service] = struct{}{}
}
}
}
var resp []structs.ServiceName
for svc := range set {
resp = append(resp, svc)
}
return maxIdx, resp, nil
}
func (s *Store) PeeringTrustBundleRead(ws memdb.WatchSet, q Query) (uint64, *pbpeering.PeeringTrustBundle, error) {
tx := s.db.ReadTxn()
defer tx.Abort()
watchCh, ptbRaw, err := tx.FirstWatch(tablePeeringTrustBundles, indexID, q)
if err != nil {
return 0, nil, fmt.Errorf("failed peering trust bundle lookup: %w", err)
}
ptb, ok := ptbRaw.(*pbpeering.PeeringTrustBundle)
if ptb != nil && !ok {
return 0, nil, fmt.Errorf("invalid type %T", ptb)
}
ws.Add(watchCh)
if ptb == nil {
// Return the tables index so caller can watch it for changes if the trust bundle doesn't exist
return maxIndexWatchTxn(tx, ws, partitionedIndexEntryName(tablePeeringTrustBundles, q.PartitionOrDefault())), nil, nil
}
return ptb.ModifyIndex, ptb, nil
}
// PeeringTrustBundleWrite writes ptb to the state store. If there is an existing trust bundle with the given peer name,
// it will be overwritten.
func (s *Store) PeeringTrustBundleWrite(idx uint64, ptb *pbpeering.PeeringTrustBundle) error {
tx := s.db.WriteTxn(idx)
defer tx.Abort()
q := Query{
Value: ptb.PeerName,
EnterpriseMeta: *structs.NodeEnterpriseMetaInPartition(ptb.Partition),
}
existingRaw, err := tx.First(tablePeeringTrustBundles, indexID, q)
if err != nil {
return fmt.Errorf("failed peering trust bundle lookup: %w", err)
}
existing, ok := existingRaw.(*pbpeering.PeeringTrustBundle)
if existingRaw != nil && !ok {
return fmt.Errorf("invalid type %T", existingRaw)
}
if existing != nil {
ptb.CreateIndex = existing.CreateIndex
} else {
ptb.CreateIndex = idx
}
ptb.ModifyIndex = idx
if err := tx.Insert(tablePeeringTrustBundles, ptb); err != nil {
return fmt.Errorf("failed inserting peering trust bundle: %w", err)
}
if err := updatePeeringTrustBundlesTableIndexes(tx, idx, ptb.PartitionOrDefault()); err != nil {
return err
}
return tx.Commit()
}
func (s *Store) PeeringTrustBundleDelete(idx uint64, q Query) error {
tx := s.db.WriteTxn(idx)
defer tx.Abort()
existing, err := tx.First(tablePeeringTrustBundles, indexID, q)
if err != nil {
return fmt.Errorf("failed peering trust bundle lookup: %v", err)
}
if existing == nil {
return nil
}
if err := tx.Delete(tablePeeringTrustBundles, existing); err != nil {
return fmt.Errorf("failed deleting peering trust bundle: %v", err)
}
if err := updatePeeringTrustBundlesTableIndexes(tx, idx, q.PartitionOrDefault()); err != nil {
return err
}
return tx.Commit()
}
func (s *Snapshot) Peerings() (memdb.ResultIterator, error) {
return s.tx.Get(tablePeering, indexName)
}
func (s *Snapshot) PeeringTrustBundles() (memdb.ResultIterator, error) {
return s.tx.Get(tablePeeringTrustBundles, indexID)
}
func (r *Restore) Peering(p *pbpeering.Peering) error {
if err := r.tx.Insert(tablePeering, p); err != nil {
return fmt.Errorf("failed restoring peering: %w", err)
}
if err := updatePeeringTableIndexes(r.tx, p.ModifyIndex, p.PartitionOrDefault()); err != nil {
return err
}
return nil
}
func (r *Restore) PeeringTrustBundle(ptb *pbpeering.PeeringTrustBundle) error {
if err := r.tx.Insert(tablePeeringTrustBundles, ptb); err != nil {
return fmt.Errorf("failed restoring peering trust bundle: %w", err)
}
if err := updatePeeringTrustBundlesTableIndexes(r.tx, ptb.ModifyIndex, ptb.PartitionOrDefault()); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,66 @@
//go:build !consulent
// +build !consulent
package state
import (
"fmt"
"strings"
"github.com/hashicorp/consul/proto/pbpeering"
)
func indexPeeringFromQuery(raw interface{}) ([]byte, error) {
q, ok := raw.(Query)
if !ok {
return nil, fmt.Errorf("unexpected type %T for Query index", raw)
}
var b indexBuilder
b.String(strings.ToLower(q.Value))
return b.Bytes(), nil
}
func indexFromPeering(raw interface{}) ([]byte, error) {
p, ok := raw.(*pbpeering.Peering)
if !ok {
return nil, fmt.Errorf("unexpected type %T for structs.Peering index", raw)
}
if p.Name == "" {
return nil, errMissingValueForIndex
}
var b indexBuilder
b.String(strings.ToLower(p.Name))
return b.Bytes(), nil
}
func indexFromPeeringTrustBundle(raw interface{}) ([]byte, error) {
ptb, ok := raw.(*pbpeering.PeeringTrustBundle)
if !ok {
return nil, fmt.Errorf("unexpected type %T for pbpeering.PeeringTrustBundle index", raw)
}
if ptb.PeerName == "" {
return nil, errMissingValueForIndex
}
var b indexBuilder
b.String(strings.ToLower(ptb.PeerName))
return b.Bytes(), nil
}
func updatePeeringTableIndexes(tx WriteTxn, idx uint64, _ string) error {
if err := tx.Insert(tableIndex, &IndexEntry{Key: tablePeering, Value: idx}); err != nil {
return fmt.Errorf("failed updating table index: %w", err)
}
return nil
}
func updatePeeringTrustBundlesTableIndexes(tx WriteTxn, idx uint64, _ string) error {
if err := tx.Insert(tableIndex, &IndexEntry{Key: tablePeeringTrustBundles, Value: idx}); err != nil {
return fmt.Errorf("failed updating table index: %w", err)
}
return nil
}

View File

@ -0,0 +1,811 @@
package state
import (
"fmt"
"math/rand"
"testing"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
)
func insertTestPeerings(t *testing.T, s *Store) {
t.Helper()
tx := s.db.WriteTxn(0)
defer tx.Abort()
err := tx.Insert(tablePeering, &pbpeering.Peering{
Name: "foo",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
ID: "9e650110-ac74-4c5a-a6a8-9348b2bed4e9",
State: pbpeering.PeeringState_INITIAL,
CreateIndex: 1,
ModifyIndex: 1,
})
require.NoError(t, err)
err = tx.Insert(tablePeering, &pbpeering.Peering{
Name: "bar",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
ID: "5ebcff30-5509-4858-8142-a8e580f1863f",
State: pbpeering.PeeringState_FAILING,
CreateIndex: 2,
ModifyIndex: 2,
})
require.NoError(t, err)
err = tx.Insert(tableIndex, &IndexEntry{
Key: tablePeering,
Value: 2,
})
require.NoError(t, err)
require.NoError(t, tx.Commit())
}
func insertTestPeeringTrustBundles(t *testing.T, s *Store) {
t.Helper()
tx := s.db.WriteTxn(0)
defer tx.Abort()
err := tx.Insert(tablePeeringTrustBundles, &pbpeering.PeeringTrustBundle{
TrustDomain: "foo.com",
PeerName: "foo",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
RootPEMs: []string{"foo certificate bundle"},
CreateIndex: 1,
ModifyIndex: 1,
})
require.NoError(t, err)
err = tx.Insert(tablePeeringTrustBundles, &pbpeering.PeeringTrustBundle{
TrustDomain: "bar.com",
PeerName: "bar",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
RootPEMs: []string{"bar certificate bundle"},
CreateIndex: 2,
ModifyIndex: 2,
})
require.NoError(t, err)
err = tx.Insert(tableIndex, &IndexEntry{
Key: tablePeeringTrustBundles,
Value: 2,
})
require.NoError(t, err)
require.NoError(t, tx.Commit())
}
func TestStateStore_PeeringReadByID(t *testing.T) {
s := NewStateStore(nil)
insertTestPeerings(t, s)
type testcase struct {
name string
id string
expect *pbpeering.Peering
}
run := func(t *testing.T, tc testcase) {
_, peering, err := s.PeeringReadByID(nil, tc.id)
require.NoError(t, err)
require.Equal(t, tc.expect, peering)
}
tcs := []testcase{
{
name: "get foo",
id: "9e650110-ac74-4c5a-a6a8-9348b2bed4e9",
expect: &pbpeering.Peering{
Name: "foo",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
ID: "9e650110-ac74-4c5a-a6a8-9348b2bed4e9",
State: pbpeering.PeeringState_INITIAL,
CreateIndex: 1,
ModifyIndex: 1,
},
},
{
name: "get bar",
id: "5ebcff30-5509-4858-8142-a8e580f1863f",
expect: &pbpeering.Peering{
Name: "bar",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
ID: "5ebcff30-5509-4858-8142-a8e580f1863f",
State: pbpeering.PeeringState_FAILING,
CreateIndex: 2,
ModifyIndex: 2,
},
},
{
name: "get non-existent",
id: "05f54e2f-7813-4d4d-ba03-534554c88a18",
expect: nil,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestStateStore_PeeringRead(t *testing.T) {
s := NewStateStore(nil)
insertTestPeerings(t, s)
type testcase struct {
name string
query Query
expect *pbpeering.Peering
}
run := func(t *testing.T, tc testcase) {
_, peering, err := s.PeeringRead(nil, tc.query)
require.NoError(t, err)
require.Equal(t, tc.expect, peering)
}
tcs := []testcase{
{
name: "get foo",
query: Query{
Value: "foo",
},
expect: &pbpeering.Peering{
Name: "foo",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
ID: "9e650110-ac74-4c5a-a6a8-9348b2bed4e9",
State: pbpeering.PeeringState_INITIAL,
CreateIndex: 1,
ModifyIndex: 1,
},
},
{
name: "get non-existent baz",
query: Query{
Value: "baz",
},
expect: nil,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestStore_Peering_Watch(t *testing.T) {
s := NewStateStore(nil)
var lastIdx uint64
lastIdx++
// set up initial write
err := s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "foo",
})
require.NoError(t, err)
newWatch := func(t *testing.T, q Query) memdb.WatchSet {
t.Helper()
// set up a watch
ws := memdb.NewWatchSet()
_, _, err := s.PeeringRead(ws, q)
require.NoError(t, err)
return ws
}
t.Run("insert fires watch", func(t *testing.T) {
// watch on non-existent bar
ws := newWatch(t, Query{Value: "bar"})
lastIdx++
err := s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "bar",
})
require.NoError(t, err)
require.True(t, watchFired(ws))
// should find bar peering
idx, p, err := s.PeeringRead(ws, Query{Value: "bar"})
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.NotNil(t, p)
})
t.Run("update fires watch", func(t *testing.T) {
// watch on existing foo
ws := newWatch(t, Query{Value: "foo"})
// unrelated write shouldn't fire watch
lastIdx++
err := s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "bar",
})
require.NoError(t, err)
require.False(t, watchFired(ws))
// foo write should fire watch
lastIdx++
err = s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "foo",
State: pbpeering.PeeringState_FAILING,
})
require.NoError(t, err)
require.True(t, watchFired(ws))
// check foo is updated
idx, p, err := s.PeeringRead(ws, Query{Value: "foo"})
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.Equal(t, pbpeering.PeeringState_FAILING, p.State)
})
t.Run("delete fires watch", func(t *testing.T) {
// watch on existing foo
ws := newWatch(t, Query{Value: "foo"})
// delete on bar shouldn't fire watch
lastIdx++
require.NoError(t, s.PeeringWrite(lastIdx, &pbpeering.Peering{Name: "bar"}))
lastIdx++
require.NoError(t, s.PeeringDelete(lastIdx, Query{Value: "bar"}))
require.False(t, watchFired(ws))
// delete on foo should fire watch
lastIdx++
err := s.PeeringDelete(lastIdx, Query{Value: "foo"})
require.NoError(t, err)
require.True(t, watchFired(ws))
// check foo is gone
idx, p, err := s.PeeringRead(ws, Query{Value: "foo"})
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.Nil(t, p)
})
}
func TestStore_PeeringList(t *testing.T) {
s := NewStateStore(nil)
insertTestPeerings(t, s)
_, pps, err := s.PeeringList(nil, acl.EnterpriseMeta{})
require.NoError(t, err)
expect := []*pbpeering.Peering{
{
Name: "foo",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
ID: "9e650110-ac74-4c5a-a6a8-9348b2bed4e9",
State: pbpeering.PeeringState_INITIAL,
CreateIndex: 1,
ModifyIndex: 1,
},
{
Name: "bar",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
ID: "5ebcff30-5509-4858-8142-a8e580f1863f",
State: pbpeering.PeeringState_FAILING,
CreateIndex: 2,
ModifyIndex: 2,
},
}
require.ElementsMatch(t, expect, pps)
}
func TestStore_PeeringList_Watch(t *testing.T) {
s := NewStateStore(nil)
var lastIdx uint64
lastIdx++ // start at 1
// track number of expected peerings in state store
var count int
newWatch := func(t *testing.T, entMeta acl.EnterpriseMeta) memdb.WatchSet {
t.Helper()
// set up a watch
ws := memdb.NewWatchSet()
_, _, err := s.PeeringList(ws, entMeta)
require.NoError(t, err)
return ws
}
t.Run("insert fires watch", func(t *testing.T) {
ws := newWatch(t, acl.EnterpriseMeta{})
lastIdx++
// insert a peering
err := s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "bar",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
})
require.NoError(t, err)
count++
require.True(t, watchFired(ws))
// should find bar peering
idx, pp, err := s.PeeringList(ws, acl.EnterpriseMeta{})
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.Len(t, pp, count)
})
t.Run("update fires watch", func(t *testing.T) {
// set up initial write
lastIdx++
err := s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "foo",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
})
require.NoError(t, err)
count++
ws := newWatch(t, acl.EnterpriseMeta{})
// update peering
lastIdx++
err = s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "foo",
State: pbpeering.PeeringState_FAILING,
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
})
require.NoError(t, err)
require.True(t, watchFired(ws))
idx, pp, err := s.PeeringList(ws, acl.EnterpriseMeta{})
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.Len(t, pp, count)
})
t.Run("delete fires watch", func(t *testing.T) {
// set up initial write
lastIdx++
err := s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "baz",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
})
require.NoError(t, err)
count++
ws := newWatch(t, acl.EnterpriseMeta{})
// delete peering
lastIdx++
err = s.PeeringDelete(lastIdx, Query{Value: "baz"})
require.NoError(t, err)
count--
require.True(t, watchFired(ws))
idx, pp, err := s.PeeringList(ws, acl.EnterpriseMeta{})
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.Len(t, pp, count)
})
}
func TestStore_PeeringWrite(t *testing.T) {
s := NewStateStore(nil)
insertTestPeerings(t, s)
type testcase struct {
name string
input *pbpeering.Peering
}
run := func(t *testing.T, tc testcase) {
require.NoError(t, s.PeeringWrite(10, tc.input))
q := Query{
Value: tc.input.Name,
EnterpriseMeta: *structs.NodeEnterpriseMetaInPartition(tc.input.Partition),
}
_, p, err := s.PeeringRead(nil, q)
require.NoError(t, err)
require.NotNil(t, p)
if tc.input.State == 0 {
require.Equal(t, pbpeering.PeeringState_INITIAL, p.State)
}
require.Equal(t, tc.input.Name, p.Name)
}
tcs := []testcase{
{
name: "create baz",
input: &pbpeering.Peering{
Name: "baz",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
},
},
{
name: "update foo",
input: &pbpeering.Peering{
Name: "foo",
State: pbpeering.PeeringState_FAILING,
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestStore_PeeringWrite_GenerateUUID(t *testing.T) {
rand.Seed(1)
s := NewStateStore(nil)
entMeta := structs.NodeEnterpriseMetaInDefaultPartition()
partition := entMeta.PartitionOrDefault()
for i := 1; i < 11; i++ {
require.NoError(t, s.PeeringWrite(uint64(i), &pbpeering.Peering{
Name: fmt.Sprintf("peering-%d", i),
Partition: partition,
}))
}
idx, peerings, err := s.PeeringList(nil, *entMeta)
require.NoError(t, err)
require.Equal(t, uint64(10), idx)
require.Len(t, peerings, 10)
// Ensure that all assigned UUIDs are unique.
uniq := make(map[string]struct{})
for _, p := range peerings {
uniq[p.ID] = struct{}{}
}
require.Len(t, uniq, 10)
// Ensure that the ID of an existing peering cannot be overwritten.
updated := &pbpeering.Peering{
Name: peerings[0].Name,
Partition: peerings[0].Partition,
}
// Attempt to overwrite ID.
updated.ID, err = uuid.GenerateUUID()
require.NoError(t, err)
require.NoError(t, s.PeeringWrite(11, updated))
q := Query{
Value: updated.Name,
EnterpriseMeta: *entMeta,
}
idx, got, err := s.PeeringRead(nil, q)
require.NoError(t, err)
require.Equal(t, uint64(11), idx)
require.Equal(t, peerings[0].ID, got.ID)
}
func TestStore_PeeringDelete(t *testing.T) {
s := NewStateStore(nil)
insertTestPeerings(t, s)
q := Query{Value: "foo"}
require.NoError(t, s.PeeringDelete(10, q))
_, p, err := s.PeeringRead(nil, q)
require.NoError(t, err)
require.Nil(t, p)
}
func TestStore_PeeringTerminateByID(t *testing.T) {
s := NewStateStore(nil)
insertTestPeerings(t, s)
// id corresponding to default/foo
id := "9e650110-ac74-4c5a-a6a8-9348b2bed4e9"
require.NoError(t, s.PeeringTerminateByID(10, id))
_, p, err := s.PeeringReadByID(nil, id)
require.NoError(t, err)
require.Equal(t, pbpeering.PeeringState_TERMINATED, p.State)
}
func TestStateStore_PeeringTrustBundleRead(t *testing.T) {
s := NewStateStore(nil)
insertTestPeeringTrustBundles(t, s)
type testcase struct {
name string
query Query
expect *pbpeering.PeeringTrustBundle
}
run := func(t *testing.T, tc testcase) {
_, ptb, err := s.PeeringTrustBundleRead(nil, tc.query)
require.NoError(t, err)
require.Equal(t, tc.expect, ptb)
}
entMeta := structs.NodeEnterpriseMetaInDefaultPartition()
tcs := []testcase{
{
name: "get foo",
query: Query{
Value: "foo",
EnterpriseMeta: *entMeta,
},
expect: &pbpeering.PeeringTrustBundle{
TrustDomain: "foo.com",
PeerName: "foo",
Partition: entMeta.PartitionOrEmpty(),
RootPEMs: []string{"foo certificate bundle"},
CreateIndex: 1,
ModifyIndex: 1,
},
},
{
name: "get non-existent baz",
query: Query{
Value: "baz",
},
expect: nil,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestStore_PeeringTrustBundleWrite(t *testing.T) {
s := NewStateStore(nil)
insertTestPeeringTrustBundles(t, s)
type testcase struct {
name string
input *pbpeering.PeeringTrustBundle
}
run := func(t *testing.T, tc testcase) {
require.NoError(t, s.PeeringTrustBundleWrite(10, tc.input))
q := Query{
Value: tc.input.PeerName,
EnterpriseMeta: *structs.NodeEnterpriseMetaInPartition(tc.input.Partition),
}
_, ptb, err := s.PeeringTrustBundleRead(nil, q)
require.NoError(t, err)
require.NotNil(t, ptb)
require.Equal(t, tc.input.TrustDomain, ptb.TrustDomain)
require.Equal(t, tc.input.PeerName, ptb.PeerName)
}
tcs := []testcase{
{
name: "create baz",
input: &pbpeering.PeeringTrustBundle{
TrustDomain: "baz.com",
PeerName: "baz",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
},
},
{
name: "update foo",
input: &pbpeering.PeeringTrustBundle{
TrustDomain: "foo-updated.com",
PeerName: "foo",
Partition: structs.NodeEnterpriseMetaInDefaultPartition().PartitionOrEmpty(),
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestStore_PeeringTrustBundleDelete(t *testing.T) {
s := NewStateStore(nil)
insertTestPeeringTrustBundles(t, s)
q := Query{Value: "foo"}
require.NoError(t, s.PeeringTrustBundleDelete(10, q))
_, ptb, err := s.PeeringRead(nil, q)
require.NoError(t, err)
require.Nil(t, ptb)
}
func TestStateStore_ExportedServicesForPeer(t *testing.T) {
s := NewStateStore(nil)
var lastIdx uint64
lastIdx++
err := s.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "my-peering",
})
require.NoError(t, err)
q := Query{Value: "my-peering"}
_, p, err := s.PeeringRead(nil, q)
require.NoError(t, err)
require.NotNil(t, p)
id := p.ID
ws := memdb.NewWatchSet()
runStep(t, "no exported services", func(t *testing.T) {
idx, exported, err := s.ExportedServicesForPeer(ws, id)
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.Empty(t, exported)
})
runStep(t, "config entry with exact service names", func(t *testing.T) {
entry := &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "mysql",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
{
Name: "redis",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
{
Name: "mongo",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-other-peering",
},
},
},
},
}
lastIdx++
err = s.EnsureConfigEntry(lastIdx, entry)
require.NoError(t, err)
require.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
expect := []structs.ServiceName{
{
Name: "mysql",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
{
Name: "redis",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
}
idx, got, err := s.ExportedServicesForPeer(ws, id)
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.ElementsMatch(t, expect, got)
})
runStep(t, "config entry with wildcard service name picks up existing service", func(t *testing.T) {
lastIdx++
require.NoError(t, s.EnsureNode(lastIdx, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
lastIdx++
require.NoError(t, s.EnsureService(lastIdx, "foo", &structs.NodeService{ID: "billing", Service: "billing", Port: 5000}))
entry := &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "*",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
},
}
lastIdx++
err = s.EnsureConfigEntry(lastIdx, entry)
require.NoError(t, err)
require.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
expect := []structs.ServiceName{
{
Name: "billing",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
}
idx, got, err := s.ExportedServicesForPeer(ws, id)
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.Equal(t, expect, got)
})
runStep(t, "config entry with wildcard service names picks up new registrations", func(t *testing.T) {
lastIdx++
require.NoError(t, s.EnsureService(lastIdx, "foo", &structs.NodeService{ID: "payments", Service: "payments", Port: 5000}))
lastIdx++
proxy := structs.NodeService{
Kind: structs.ServiceKindConnectProxy,
ID: "payments-proxy",
Service: "payments-proxy",
Port: 5000,
}
require.NoError(t, s.EnsureService(lastIdx, "foo", &proxy))
require.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
expect := []structs.ServiceName{
{
Name: "billing",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
{
Name: "payments",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
{
Name: "payments-proxy",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
}
idx, got, err := s.ExportedServicesForPeer(ws, id)
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.ElementsMatch(t, expect, got)
})
runStep(t, "config entry with wildcard service names picks up service deletions", func(t *testing.T) {
lastIdx++
require.NoError(t, s.DeleteService(lastIdx, "foo", "billing", nil, ""))
require.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
expect := []structs.ServiceName{
{
Name: "payments",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
{
Name: "payments-proxy",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
}
idx, got, err := s.ExportedServicesForPeer(ws, id)
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.ElementsMatch(t, expect, got)
})
runStep(t, "deleting the config entry clears exported services", func(t *testing.T) {
require.NoError(t, s.DeleteConfigEntry(lastIdx, structs.ExportedServices, "default", structs.DefaultEnterpriseMetaInDefaultPartition()))
idx, exported, err := s.ExportedServicesForPeer(ws, id)
require.NoError(t, err)
require.Equal(t, lastIdx, idx)
require.Empty(t, exported)
})
}

View File

@ -12,10 +12,15 @@ import (
// Query is a type used to query any single value index that may include an // Query is a type used to query any single value index that may include an
// enterprise identifier. // enterprise identifier.
type Query struct { type Query struct {
Value string Value string
PeerName string
acl.EnterpriseMeta acl.EnterpriseMeta
} }
func (q Query) PeerOrEmpty() string {
return q.PeerName
}
func (q Query) IDValue() string { func (q Query) IDValue() string {
return q.Value return q.Value
} }
@ -137,11 +142,16 @@ func (q BoolQuery) PartitionOrDefault() string {
// KeyValueQuery is a type used to query for both a key and a value that may // KeyValueQuery is a type used to query for both a key and a value that may
// include an enterprise identifier. // include an enterprise identifier.
type KeyValueQuery struct { type KeyValueQuery struct {
Key string Key string
Value string Value string
PeerName string
acl.EnterpriseMeta acl.EnterpriseMeta
} }
func (q KeyValueQuery) PeerOrEmpty() string {
return q.PeerName
}
// NamespaceOrDefault exists because structs.EnterpriseMeta uses a pointer // NamespaceOrDefault exists because structs.EnterpriseMeta uses a pointer
// receiver for this method. Remove once that is fixed. // receiver for this method. Remove once that is fixed.
func (q KeyValueQuery) NamespaceOrDefault() string { func (q KeyValueQuery) NamespaceOrDefault() string {

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
) )
func prefixIndexFromQuery(arg interface{}) ([]byte, error) { func prefixIndexFromQuery(arg interface{}) ([]byte, error) {
@ -28,6 +29,29 @@ func prefixIndexFromQuery(arg interface{}) ([]byte, error) {
return nil, fmt.Errorf("unexpected type %T for Query prefix index", arg) return nil, fmt.Errorf("unexpected type %T for Query prefix index", arg)
} }
func prefixIndexFromQueryWithPeer(arg interface{}) ([]byte, error) {
var b indexBuilder
switch v := arg.(type) {
case *acl.EnterpriseMeta:
return nil, nil
case acl.EnterpriseMeta:
return nil, nil
case Query:
if v.PeerOrEmpty() == "" {
b.String(structs.LocalPeerKeyword)
} else {
b.String(strings.ToLower(v.PeerOrEmpty()))
}
if v.Value == "" {
return b.Bytes(), nil
}
b.String(strings.ToLower(v.Value))
return b.Bytes(), nil
}
return nil, fmt.Errorf("unexpected type %T for Query prefix index", arg)
}
func prefixIndexFromQueryNoNamespace(arg interface{}) ([]byte, error) { func prefixIndexFromQueryNoNamespace(arg interface{}) ([]byte, error) {
return prefixIndexFromQuery(arg) return prefixIndexFromQuery(arg)
} }

View File

@ -22,12 +22,16 @@ func newDBSchema() *memdb.DBSchema {
configTableSchema, configTableSchema,
coordinatesTableSchema, coordinatesTableSchema,
federationStateTableSchema, federationStateTableSchema,
freeVirtualIPTableSchema,
gatewayServicesTableSchema, gatewayServicesTableSchema,
indexTableSchema, indexTableSchema,
intentionsTableSchema, intentionsTableSchema,
kindServiceNameTableSchema,
kvsTableSchema, kvsTableSchema,
meshTopologyTableSchema, meshTopologyTableSchema,
nodesTableSchema, nodesTableSchema,
peeringTableSchema,
peeringTrustBundlesTableSchema,
policiesTableSchema, policiesTableSchema,
preparedQueriesTableSchema, preparedQueriesTableSchema,
rolesTableSchema, rolesTableSchema,
@ -39,8 +43,6 @@ func newDBSchema() *memdb.DBSchema {
tokensTableSchema, tokensTableSchema,
tombstonesTableSchema, tombstonesTableSchema,
usageTableSchema, usageTableSchema,
freeVirtualIPTableSchema,
kindServiceNameTableSchema,
) )
withEnterpriseSchema(db) withEnterpriseSchema(db)
return db return db

View File

@ -3,7 +3,12 @@
package state package state
import "github.com/hashicorp/consul/acl" import (
"fmt"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
)
func partitionedIndexEntryName(entry string, _ string) string { func partitionedIndexEntryName(entry string, _ string) string {
return entry return entry
@ -12,3 +17,11 @@ func partitionedIndexEntryName(entry string, _ string) string {
func partitionedAndNamespacedIndexEntryName(entry string, _ *acl.EnterpriseMeta) string { func partitionedAndNamespacedIndexEntryName(entry string, _ *acl.EnterpriseMeta) string {
return entry return entry
} }
// peeredIndexEntryName returns the peered index key for an importable entity (e.g. checks, services, or nodes).
func peeredIndexEntryName(entry, peerName string) string {
if peerName == "" {
peerName = structs.LocalPeerKeyword
}
return fmt.Sprintf("peer.%s:%s", peerName, entry)
}

View File

@ -553,7 +553,7 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if err := s.DeleteNode(15, "foo", nil); err != nil { if err := s.DeleteNode(15, "foo", nil, ""); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if !watchFired(ws) { if !watchFired(ws) {
@ -608,7 +608,7 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if err := s.DeleteService(15, "foo", "api", nil); err != nil { if err := s.DeleteService(15, "foo", "api", nil, ""); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if !watchFired(ws) { if !watchFired(ws) {
@ -709,7 +709,7 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if err := s.DeleteCheck(15, "foo", "bar", nil); err != nil { if err := s.DeleteCheck(15, "foo", "bar", nil, ""); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if !watchFired(ws) { if !watchFired(ws) {
@ -777,7 +777,7 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if err := s.DeleteNode(6, "foo", nil); err != nil { if err := s.DeleteNode(6, "foo", nil, ""); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if !watchFired(ws) { if !watchFired(ws) {
@ -859,7 +859,7 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if err := s.DeleteNode(6, "foo", nil); err != nil { if err := s.DeleteNode(6, "foo", nil, ""); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if !watchFired(ws) { if !watchFired(ws) {

View File

@ -291,10 +291,9 @@ func maxIndexWatchTxn(tx ReadTxn, ws memdb.WatchSet, tables ...string) uint64 {
return lindex return lindex
} }
// indexUpdateMaxTxn is used when restoring entries and sets the table's index to // indexUpdateMaxTxn sets the table's index to the given idx only if it's greater than the current index.
// the given idx only if it's greater than the current index. func indexUpdateMaxTxn(tx WriteTxn, idx uint64, key string) error {
func indexUpdateMaxTxn(tx WriteTxn, idx uint64, table string) error { ti, err := tx.First(tableIndex, indexID, key)
ti, err := tx.First(tableIndex, indexID, table)
if err != nil { if err != nil {
return fmt.Errorf("failed to retrieve existing index: %s", err) return fmt.Errorf("failed to retrieve existing index: %s", err)
} }
@ -311,7 +310,7 @@ func indexUpdateMaxTxn(tx WriteTxn, idx uint64, table string) error {
} }
} }
if err := tx.Insert(tableIndex, &IndexEntry{table, idx}); err != nil { if err := tx.Insert(tableIndex, &IndexEntry{key, idx}); err != nil {
return fmt.Errorf("failed updating index %s", err) return fmt.Errorf("failed updating index %s", err)
} }
return nil return nil

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbsubscribe"
) )
func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) { func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
@ -399,7 +400,7 @@ var topicService topic = "test-topic-service"
func (s *Store) topicServiceTestHandler(req stream.SubscribeRequest, snap stream.SnapshotAppender) (uint64, error) { func (s *Store) topicServiceTestHandler(req stream.SubscribeRequest, snap stream.SnapshotAppender) (uint64, error) {
key := req.Subject.String() key := req.Subject.String()
idx, nodes, err := s.ServiceNodes(nil, key, nil) idx, nodes, err := s.ServiceNodes(nil, key, nil, structs.TODOPeerKeyword)
if err != nil { if err != nil {
return idx, err return idx, err
} }
@ -434,6 +435,10 @@ func (p nodePayload) Subject() stream.Subject {
return stream.StringSubject(p.key) return stream.StringSubject(p.key)
} }
func (e nodePayload) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
panic("EventPayloadCARoots does not implement ToSubscriptionEvent")
}
func createTokenAndWaitForACLEventPublish(t *testing.T, s *Store) *structs.ACLToken { func createTokenAndWaitForACLEventPublish(t *testing.T, s *Store) *structs.ACLToken {
token := &structs.ACLToken{ token := &structs.ACLToken{
AccessorID: "3af117a9-2233-4cf4-8ff8-3c749c9906b4", AccessorID: "3af117a9-2233-4cf4-8ff8-3c749c9906b4",

View File

@ -153,9 +153,9 @@ func (s *Store) txnNode(tx WriteTxn, idx uint64, op *structs.TxnNodeOp) (structs
getNode := func() (*structs.Node, error) { getNode := func() (*structs.Node, error) {
if op.Node.ID != "" { if op.Node.ID != "" {
return getNodeIDTxn(tx, op.Node.ID, op.Node.GetEnterpriseMeta()) return getNodeIDTxn(tx, op.Node.ID, op.Node.GetEnterpriseMeta(), op.Node.PeerName)
} else { } else {
return getNodeTxn(tx, op.Node.Node, op.Node.GetEnterpriseMeta()) return getNodeTxn(tx, op.Node.Node, op.Node.GetEnterpriseMeta(), op.Node.PeerName)
} }
} }
@ -182,11 +182,11 @@ func (s *Store) txnNode(tx WriteTxn, idx uint64, op *structs.TxnNodeOp) (structs
entry, err = getNode() entry, err = getNode()
case api.NodeDelete: case api.NodeDelete:
err = s.deleteNodeTxn(tx, idx, op.Node.Node, op.Node.GetEnterpriseMeta()) err = s.deleteNodeTxn(tx, idx, op.Node.Node, op.Node.GetEnterpriseMeta(), op.Node.PeerName)
case api.NodeDeleteCAS: case api.NodeDeleteCAS:
var ok bool var ok bool
ok, err = s.deleteNodeCASTxn(tx, idx, op.Node.ModifyIndex, op.Node.Node, op.Node.GetEnterpriseMeta()) ok, err = s.deleteNodeCASTxn(tx, idx, op.Node.ModifyIndex, op.Node.Node, op.Node.GetEnterpriseMeta(), op.Node.PeerName)
if !ok && err == nil { if !ok && err == nil {
err = fmt.Errorf("failed to delete node %q, index is stale", op.Node.Node) err = fmt.Errorf("failed to delete node %q, index is stale", op.Node.Node)
} }
@ -219,7 +219,7 @@ func (s *Store) txnNode(tx WriteTxn, idx uint64, op *structs.TxnNodeOp) (structs
func (s *Store) txnService(tx WriteTxn, idx uint64, op *structs.TxnServiceOp) (structs.TxnResults, error) { func (s *Store) txnService(tx WriteTxn, idx uint64, op *structs.TxnServiceOp) (structs.TxnResults, error) {
switch op.Verb { switch op.Verb {
case api.ServiceGet: case api.ServiceGet:
entry, err := getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta) entry, err := getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta, op.Service.PeerName)
switch { switch {
case err != nil: case err != nil:
return nil, err return nil, err
@ -233,7 +233,7 @@ func (s *Store) txnService(tx WriteTxn, idx uint64, op *structs.TxnServiceOp) (s
if err := ensureServiceTxn(tx, idx, op.Node, false, &op.Service); err != nil { if err := ensureServiceTxn(tx, idx, op.Node, false, &op.Service); err != nil {
return nil, err return nil, err
} }
entry, err := getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta) entry, err := getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta, op.Service.PeerName)
return newTxnResultFromNodeServiceEntry(entry), err return newTxnResultFromNodeServiceEntry(entry), err
case api.ServiceCAS: case api.ServiceCAS:
@ -246,15 +246,15 @@ func (s *Store) txnService(tx WriteTxn, idx uint64, op *structs.TxnServiceOp) (s
return nil, err return nil, err
} }
entry, err := getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta) entry, err := getNodeServiceTxn(tx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta, op.Service.PeerName)
return newTxnResultFromNodeServiceEntry(entry), err return newTxnResultFromNodeServiceEntry(entry), err
case api.ServiceDelete: case api.ServiceDelete:
err := s.deleteServiceTxn(tx, idx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta) err := s.deleteServiceTxn(tx, idx, op.Node, op.Service.ID, &op.Service.EnterpriseMeta, op.Service.PeerName)
return nil, err return nil, err
case api.ServiceDeleteCAS: case api.ServiceDeleteCAS:
ok, err := s.deleteServiceCASTxn(tx, idx, op.Service.ModifyIndex, op.Node, op.Service.ID, &op.Service.EnterpriseMeta) ok, err := s.deleteServiceCASTxn(tx, idx, op.Service.ModifyIndex, op.Node, op.Service.ID, &op.Service.EnterpriseMeta, op.Service.PeerName)
if !ok && err == nil { if !ok && err == nil {
return nil, fmt.Errorf("failed to delete service %q on node %q, index is stale", op.Service.ID, op.Node) return nil, fmt.Errorf("failed to delete service %q on node %q, index is stale", op.Service.ID, op.Node)
} }
@ -284,7 +284,7 @@ func (s *Store) txnCheck(tx WriteTxn, idx uint64, op *structs.TxnCheckOp) (struc
switch op.Verb { switch op.Verb {
case api.CheckGet: case api.CheckGet:
_, entry, err = getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta) _, entry, err = getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta, op.Check.PeerName)
if entry == nil && err == nil { if entry == nil && err == nil {
err = fmt.Errorf("check %q on node %q doesn't exist", op.Check.CheckID, op.Check.Node) err = fmt.Errorf("check %q on node %q doesn't exist", op.Check.CheckID, op.Check.Node)
} }
@ -292,7 +292,7 @@ func (s *Store) txnCheck(tx WriteTxn, idx uint64, op *structs.TxnCheckOp) (struc
case api.CheckSet: case api.CheckSet:
err = s.ensureCheckTxn(tx, idx, false, &op.Check) err = s.ensureCheckTxn(tx, idx, false, &op.Check)
if err == nil { if err == nil {
_, entry, err = getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta) _, entry, err = getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta, op.Check.PeerName)
} }
case api.CheckCAS: case api.CheckCAS:
@ -303,14 +303,14 @@ func (s *Store) txnCheck(tx WriteTxn, idx uint64, op *structs.TxnCheckOp) (struc
err = fmt.Errorf("failed to set check %q on node %q, index is stale", entry.CheckID, entry.Node) err = fmt.Errorf("failed to set check %q on node %q, index is stale", entry.CheckID, entry.Node)
break break
} }
_, entry, err = getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta) _, entry, err = getNodeCheckTxn(tx, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta, op.Check.PeerName)
case api.CheckDelete: case api.CheckDelete:
err = s.deleteCheckTxn(tx, idx, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta) err = s.deleteCheckTxn(tx, idx, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta, op.Check.PeerName)
case api.CheckDeleteCAS: case api.CheckDeleteCAS:
var ok bool var ok bool
ok, err = s.deleteCheckCASTxn(tx, idx, op.Check.ModifyIndex, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta) ok, err = s.deleteCheckCASTxn(tx, idx, op.Check.ModifyIndex, op.Check.Node, op.Check.CheckID, &op.Check.EnterpriseMeta, op.Check.PeerName)
if !ok && err == nil { if !ok && err == nil {
err = fmt.Errorf("failed to delete check %q on node %q, index is stale", op.Check.CheckID, op.Check.Node) err = fmt.Errorf("failed to delete check %q on node %q, index is stale", op.Check.CheckID, op.Check.Node)
} }

View File

@ -196,7 +196,7 @@ func TestStateStore_Txn_Node(t *testing.T) {
require.Equal(t, expected, results) require.Equal(t, expected, results)
// Pull the resulting state store contents. // Pull the resulting state store contents.
idx, actual, err := s.Nodes(nil, nil) idx, actual, err := s.Nodes(nil, nil, "")
require.NoError(t, err) require.NoError(t, err)
if idx != 8 { if idx != 8 {
t.Fatalf("bad index: %d", idx) t.Fatalf("bad index: %d", idx)
@ -311,7 +311,7 @@ func TestStateStore_Txn_Service(t *testing.T) {
require.Equal(t, expected, results) require.Equal(t, expected, results)
// Pull the resulting state store contents. // Pull the resulting state store contents.
idx, actual, err := s.NodeServices(nil, "node1", nil) idx, actual, err := s.NodeServices(nil, "node1", nil, "")
require.NoError(t, err) require.NoError(t, err)
if idx != 6 { if idx != 6 {
t.Fatalf("bad index: %d", idx) t.Fatalf("bad index: %d", idx)
@ -464,7 +464,7 @@ func TestStateStore_Txn_Checks(t *testing.T) {
require.Equal(t, expected, results) require.Equal(t, expected, results)
// Pull the resulting state store contents. // Pull the resulting state store contents.
idx, actual, err := s.NodeChecks(nil, "node1", nil) idx, actual, err := s.NodeChecks(nil, "node1", nil, "")
require.NoError(t, err) require.NoError(t, err)
if idx != 6 { if idx != 6 {
t.Fatalf("bad index: %d", idx) t.Fatalf("bad index: %d", idx)

View File

@ -38,7 +38,7 @@ func TestStateStore_Usage_NodeUsage_Delete(t *testing.T) {
require.Equal(t, idx, uint64(1)) require.Equal(t, idx, uint64(1))
require.Equal(t, usage.Nodes, 2) require.Equal(t, usage.Nodes, 2)
require.NoError(t, s.DeleteNode(2, "node2", nil)) require.NoError(t, s.DeleteNode(2, "node2", nil, ""))
idx, usage, err = s.NodeUsage() idx, usage, err = s.NodeUsage()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, idx, uint64(2)) require.Equal(t, idx, uint64(2))
@ -152,7 +152,7 @@ func TestStateStore_Usage_ServiceUsage_DeleteNode(t *testing.T) {
require.Equal(t, 1, usage.ConnectServiceInstances[string(structs.ServiceKindConnectProxy)]) require.Equal(t, 1, usage.ConnectServiceInstances[string(structs.ServiceKindConnectProxy)])
require.Equal(t, 1, usage.ConnectServiceInstances[connectNativeInstancesTable]) require.Equal(t, 1, usage.ConnectServiceInstances[connectNativeInstancesTable])
require.NoError(t, s.DeleteNode(4, "node1", nil)) require.NoError(t, s.DeleteNode(4, "node1", nil, ""))
idx, usage, err = s.ServiceUsage() idx, usage, err = s.ServiceUsage()
require.NoError(t, err) require.NoError(t, err)

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/proto/pbsubscribe"
) )
// Topic is an identifier that partitions events. A subscription will only receive // Topic is an identifier that partitions events. A subscription will only receive
@ -46,6 +47,10 @@ type Payload interface {
// it is usually the normalized resource name (including the partition and // it is usually the normalized resource name (including the partition and
// namespace if applicable). // namespace if applicable).
Subject() Subject Subject() Subject
// ToSubscriptionEvent is used to convert streaming events to their
// serializable equivalent.
ToSubscriptionEvent(idx uint64) *pbsubscribe.Event
} }
// PayloadEvents is a Payload that may be returned by Subscription.Next when // PayloadEvents is a Payload that may be returned by Subscription.Next when
@ -109,6 +114,26 @@ func (PayloadEvents) Subject() Subject {
panic("PayloadEvents does not implement Subject") panic("PayloadEvents does not implement Subject")
} }
func (p PayloadEvents) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
return &pbsubscribe.Event{
Index: idx,
Payload: &pbsubscribe.Event_EventBatch{
EventBatch: &pbsubscribe.EventBatch{
Events: batchEventsFromEventSlice(p.Items),
},
},
}
}
func batchEventsFromEventSlice(events []Event) []*pbsubscribe.Event {
result := make([]*pbsubscribe.Event, len(events))
for i := range events {
event := events[i]
result[i] = event.Payload.ToSubscriptionEvent(event.Index)
}
return result
}
// IsEndOfSnapshot returns true if this is a framing event that indicates the // IsEndOfSnapshot returns true if this is a framing event that indicates the
// snapshot has completed. Subsequent events from Subscription.Next will be // snapshot has completed. Subsequent events from Subscription.Next will be
// streamed as they occur. // streamed as they occur.
@ -142,18 +167,42 @@ func (framingEvent) Subject() Subject {
panic("framing events do not implement Subject") panic("framing events do not implement Subject")
} }
func (framingEvent) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
panic("framingEvent does not implement ToSubscriptionEvent")
}
type endOfSnapshot struct { type endOfSnapshot struct {
framingEvent framingEvent
} }
func (s endOfSnapshot) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
return &pbsubscribe.Event{
Index: idx,
Payload: &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true},
}
}
type newSnapshotToFollow struct { type newSnapshotToFollow struct {
framingEvent framingEvent
} }
func (s newSnapshotToFollow) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
return &pbsubscribe.Event{
Index: idx,
Payload: &pbsubscribe.Event_NewSnapshotToFollow{NewSnapshotToFollow: true},
}
}
type closeSubscriptionPayload struct { type closeSubscriptionPayload struct {
tokensSecretIDs []string tokensSecretIDs []string
} }
// closeSubscriptionPayload is only used internally and does not correspond to
// a subscription event that would be sent to clients.
func (s closeSubscriptionPayload) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
panic("closeSubscriptionPayload does not implement ToSubscriptionEvent")
}
func (closeSubscriptionPayload) HasReadPermission(acl.Authorizer) bool { func (closeSubscriptionPayload) HasReadPermission(acl.Authorizer) bool {
return false return false
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/proto/pbsubscribe"
) )
type intTopic int type intTopic int
@ -84,6 +85,10 @@ func (p simplePayload) HasReadPermission(acl.Authorizer) bool {
func (p simplePayload) Subject() Subject { return StringSubject(p.key) } func (p simplePayload) Subject() Subject { return StringSubject(p.key) }
func (p simplePayload) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
panic("simplePayload does not implement ToSubscriptionEvent")
}
func registerTestSnapshotHandlers(t *testing.T, publisher *EventPublisher) { func registerTestSnapshotHandlers(t *testing.T, publisher *EventPublisher) {
t.Helper() t.Helper()

View File

@ -234,7 +234,7 @@ func TestTxn_Apply(t *testing.T) {
t.Fatalf("bad: %v", d) t.Fatalf("bad: %v", d)
} }
_, n, err := state.GetNode("foo", nil) _, n, err := state.GetNode("foo", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -242,7 +242,7 @@ func TestTxn_Apply(t *testing.T) {
t.Fatalf("bad: %v", err) t.Fatalf("bad: %v", err)
} }
_, s, err := state.NodeService("foo", "svc-foo", nil) _, s, err := state.NodeService("foo", "svc-foo", nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -250,7 +250,7 @@ func TestTxn_Apply(t *testing.T) {
t.Fatalf("bad: %v", err) t.Fatalf("bad: %v", err)
} }
_, c, err := state.NodeCheck("foo", types.CheckID("check-foo"), nil) _, c, err := state.NodeCheck("foo", types.CheckID("check-foo"), nil, "")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -2,7 +2,6 @@ package subscribe
import ( import (
"errors" "errors"
"fmt"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -13,7 +12,6 @@ import (
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/pbsubscribe" "github.com/hashicorp/consul/proto/pbsubscribe"
) )
@ -61,7 +59,7 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub
return status.Error(codes.InvalidArgument, "Key is required") return status.Error(codes.InvalidArgument, "Key is required")
} }
sub, err := h.Backend.Subscribe(toStreamSubscribeRequest(req, entMeta)) sub, err := h.Backend.Subscribe(state.PBToStreamSubscribeRequest(req, entMeta))
if err != nil { if err != nil {
return err return err
} }
@ -84,25 +82,15 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub
} }
elog.Trace(event) elog.Trace(event)
e := newEventFromStreamEvent(event)
// TODO: This conversion could be cached if needed
e := event.Payload.ToSubscriptionEvent(event.Index)
if err := serverStream.Send(e); err != nil { if err := serverStream.Send(e); err != nil {
return err return err
} }
} }
} }
func toStreamSubscribeRequest(req *pbsubscribe.SubscribeRequest, entMeta acl.EnterpriseMeta) *stream.SubscribeRequest {
return &stream.SubscribeRequest{
Topic: req.Topic,
Subject: state.EventSubjectService{
Key: req.Key,
EnterpriseMeta: entMeta,
},
Token: req.Token,
Index: req.Index,
}
}
func forwardToDC( func forwardToDC(
req *pbsubscribe.SubscribeRequest, req *pbsubscribe.SubscribeRequest,
serverStream pbsubscribe.StateChangeSubscription_SubscribeServer, serverStream pbsubscribe.StateChangeSubscription_SubscribeServer,
@ -129,48 +117,3 @@ func forwardToDC(
} }
} }
} }
func newEventFromStreamEvent(event stream.Event) *pbsubscribe.Event {
e := &pbsubscribe.Event{Index: event.Index}
switch {
case event.IsEndOfSnapshot():
e.Payload = &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}
return e
case event.IsNewSnapshotToFollow():
e.Payload = &pbsubscribe.Event_NewSnapshotToFollow{NewSnapshotToFollow: true}
return e
}
setPayload(e, event.Payload)
return e
}
func setPayload(e *pbsubscribe.Event, payload stream.Payload) {
switch p := payload.(type) {
case *stream.PayloadEvents:
e.Payload = &pbsubscribe.Event_EventBatch{
EventBatch: &pbsubscribe.EventBatch{
Events: batchEventsFromEventSlice(p.Items),
},
}
case state.EventPayloadCheckServiceNode:
e.Payload = &pbsubscribe.Event_ServiceHealth{
ServiceHealth: &pbsubscribe.ServiceHealthUpdate{
Op: p.Op,
// TODO: this could be cached
CheckServiceNode: pbservice.NewCheckServiceNodeFromStructs(p.Value),
},
}
default:
panic(fmt.Sprintf("unexpected payload: %T: %#v", p, p))
}
}
func batchEventsFromEventSlice(events []stream.Event) []*pbsubscribe.Event {
result := make([]*pbsubscribe.Event, len(events))
for i := range events {
event := events[i]
result[i] = &pbsubscribe.Event{Index: event.Index}
setPayload(result[i], event.Payload)
}
return result
}

View File

@ -956,7 +956,7 @@ func TestNewEventFromSteamEvent(t *testing.T) {
fn := func(t *testing.T, tc testCase) { fn := func(t *testing.T, tc testCase) {
expected := tc.expected expected := tc.expected
actual := newEventFromStreamEvent(tc.event) actual := tc.event.Payload.ToSubscriptionEvent(tc.event.Index)
prototest.AssertDeepEqual(t, expected, actual, cmpopts.EquateEmpty()) prototest.AssertDeepEqual(t, expected, actual, cmpopts.EquateEmpty())
} }

View File

@ -5,14 +5,15 @@ import (
"errors" "errors"
"strings" "strings"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb"
acl "github.com/hashicorp/consul/acl" acl "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/grpc/public" "github.com/hashicorp/consul/agent/grpc/public"
structs "github.com/hashicorp/consul/agent/structs" structs "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto-public/pbdataplane" "github.com/hashicorp/consul/proto-public/pbdataplane"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb"
) )
func (s *Server) GetEnvoyBootstrapParams(ctx context.Context, req *pbdataplane.GetEnvoyBootstrapParamsRequest) (*pbdataplane.GetEnvoyBootstrapParamsResponse, error) { func (s *Server) GetEnvoyBootstrapParams(ctx context.Context, req *pbdataplane.GetEnvoyBootstrapParamsRequest) (*pbdataplane.GetEnvoyBootstrapParamsResponse, error) {
@ -31,7 +32,7 @@ func (s *Server) GetEnvoyBootstrapParams(ctx context.Context, req *pbdataplane.G
store := s.GetStore() store := s.GetStore()
_, svc, err := store.ServiceNode(req.GetNodeId(), req.GetNodeName(), req.GetServiceId(), &entMeta) _, svc, err := store.ServiceNode(req.GetNodeId(), req.GetNodeName(), req.GetServiceId(), &entMeta, structs.DefaultPeerKeyword)
if err != nil { if err != nil {
logger.Error("Error looking up service", "error", err) logger.Error("Error looking up service", "error", err)
if errors.Is(err, state.ErrNodeNotFound) { if errors.Is(err, state.ErrNodeNotFound) {

View File

@ -23,7 +23,7 @@ type Config struct {
} }
type StateStore interface { type StateStore interface {
ServiceNode(string, string, string, *acl.EnterpriseMeta) (uint64, *structs.ServiceNode, error) ServiceNode(string, string, string, *acl.EnterpriseMeta, string) (uint64, *structs.ServiceNode, error)
} }
//go:generate mockery --name ACLResolver --inpackage //go:generate mockery --name ACLResolver --inpackage

View File

@ -194,6 +194,8 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re
return nil, nil return nil, nil
} }
s.parsePeerName(req, &args)
// Check for tags // Check for tags
params := req.URL.Query() params := req.URL.Query()
if _, ok := params["tag"]; ok { if _, ok := params["tag"]; ok {

View File

@ -607,129 +607,163 @@ func TestHealthServiceNodes(t *testing.T) {
t.Parallel() t.Parallel()
a := NewTestAgent(t, "") a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1") testrpc.WaitForTestAgent(t, a.RPC, "dc1")
req, _ := http.NewRequest("GET", "/v1/health/service/consul?dc=dc1", nil) testingPeerNames := []string{"", "my-peer"}
resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req) suffix := func(peerName string) string {
if err != nil { if peerName == "" {
t.Fatalf("err: %v", err) return ""
}
// TODO(peering): after streaming works, remove the "&near=_agent" part
return "&peer=" + peerName + "&near=_agent"
} }
assertIndex(t, resp) for _, peerName := range testingPeerNames {
req, err := http.NewRequest("GET", "/v1/health/service/consul?dc=dc1"+suffix(peerName), nil)
// Should be 1 health check for consul require.NoError(t, err)
nodes := obj.(structs.CheckServiceNodes)
if len(nodes) != 1 {
t.Fatalf("bad: %v", obj)
}
req, _ = http.NewRequest("GET", "/v1/health/service/nope?dc=dc1", nil)
resp = httptest.NewRecorder()
obj, err = a.srv.HealthServiceNodes(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
assertIndex(t, resp)
// Should be a non-nil empty list
nodes = obj.(structs.CheckServiceNodes)
if nodes == nil || len(nodes) != 0 {
t.Fatalf("bad: %v", obj)
}
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "127.0.0.1",
Service: &structs.NodeService{
ID: "test",
Service: "test",
},
}
var out struct{}
if err := a.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
req, _ = http.NewRequest("GET", "/v1/health/service/test?dc=dc1", nil)
resp = httptest.NewRecorder()
obj, err = a.srv.HealthServiceNodes(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
assertIndex(t, resp)
// Should be a non-nil empty list for checks
nodes = obj.(structs.CheckServiceNodes)
if len(nodes) != 1 || nodes[0].Checks == nil || len(nodes[0].Checks) != 0 {
t.Fatalf("bad: %v", obj)
}
// Test caching
{
// List instances with cache enabled
req, _ := http.NewRequest("GET", "/v1/health/service/test?cached", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req) obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(t, err) require.NoError(t, err)
nodes := obj.(structs.CheckServiceNodes)
assert.Len(t, nodes, 1)
// Should be a cache miss assertIndex(t, resp)
assert.Equal(t, "MISS", resp.Header().Get("X-Cache"))
nodes := obj.(structs.CheckServiceNodes)
if peerName == "" {
// Should be 1 health check for consul
require.Len(t, nodes, 1)
} else {
require.NotNil(t, nodes)
require.Len(t, nodes, 0)
}
req, err = http.NewRequest("GET", "/v1/health/service/nope?dc=dc1"+suffix(peerName), nil)
require.NoError(t, err)
resp = httptest.NewRecorder()
obj, err = a.srv.HealthServiceNodes(resp, req)
require.NoError(t, err)
assertIndex(t, resp)
// Should be a non-nil empty list
nodes = obj.(structs.CheckServiceNodes)
require.NotNil(t, nodes)
require.Len(t, nodes, 0)
} }
{ // TODO(peering): will have to seed this data differently in the future
// List instances with cache enabled originalRegister := make(map[string]*structs.RegisterRequest)
req, _ := http.NewRequest("GET", "/v1/health/service/test?cached", nil) for _, peerName := range testingPeerNames {
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "127.0.0.1",
PeerName: peerName,
Service: &structs.NodeService{
ID: "test",
Service: "test",
PeerName: peerName,
},
}
var out struct{}
require.NoError(t, a.RPC("Catalog.Register", args, &out))
originalRegister[peerName] = args
}
verify := func(t *testing.T, peerName string, nodes structs.CheckServiceNodes) {
require.Len(t, nodes, 1)
require.Equal(t, peerName, nodes[0].Node.PeerName)
require.Equal(t, "bar", nodes[0].Node.Node)
require.Equal(t, peerName, nodes[0].Service.PeerName)
require.Equal(t, "test", nodes[0].Service.Service)
require.NotNil(t, nodes[0].Checks)
require.Len(t, nodes[0].Checks, 0)
}
for _, peerName := range testingPeerNames {
req, err := http.NewRequest("GET", "/v1/health/service/test?dc=dc1"+suffix(peerName), nil)
require.NoError(t, err)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req) obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(t, err) require.NoError(t, err)
nodes := obj.(structs.CheckServiceNodes)
assert.Len(t, nodes, 1)
// Should be a cache HIT now! assertIndex(t, resp)
assert.Equal(t, "HIT", resp.Header().Get("X-Cache"))
// Should be a non-nil empty list for checks
nodes := obj.(structs.CheckServiceNodes)
verify(t, peerName, nodes)
// Test caching
{
// List instances with cache enabled
req, err := http.NewRequest("GET", "/v1/health/service/test?cached"+suffix(peerName), nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(t, err)
nodes := obj.(structs.CheckServiceNodes)
verify(t, peerName, nodes)
// Should be a cache miss
require.Equal(t, "MISS", resp.Header().Get("X-Cache"))
}
{
// List instances with cache enabled
req, err := http.NewRequest("GET", "/v1/health/service/test?cached"+suffix(peerName), nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(t, err)
nodes := obj.(structs.CheckServiceNodes)
verify(t, peerName, nodes)
// Should be a cache HIT now!
require.Equal(t, "HIT", resp.Header().Get("X-Cache"))
}
} }
// Ensure background refresh works // Ensure background refresh works
{ {
// Register a new instance of the service // TODO(peering): will have to seed this data differently in the future
args2 := args for _, peerName := range testingPeerNames {
args2.Node = "baz" args := originalRegister[peerName]
args2.Address = "127.0.0.2" // Register a new instance of the service
require.NoError(t, a.RPC("Catalog.Register", args, &out)) args2 := *args
args2.Node = "baz"
args2.Address = "127.0.0.2"
var out struct{}
require.NoError(t, a.RPC("Catalog.Register", &args2, &out))
}
retry.Run(t, func(r *retry.R) { for _, peerName := range testingPeerNames {
// List it again retry.Run(t, func(r *retry.R) {
req, _ := http.NewRequest("GET", "/v1/health/service/test?cached", nil) // List it again
resp := httptest.NewRecorder() req, err := http.NewRequest("GET", "/v1/health/service/test?cached"+suffix(peerName), nil)
obj, err := a.srv.HealthServiceNodes(resp, req) require.NoError(r, err)
r.Check(err) resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(r, err)
nodes := obj.(structs.CheckServiceNodes) nodes := obj.(structs.CheckServiceNodes)
if len(nodes) != 2 { require.Len(r, nodes, 2)
r.Fatalf("Want 2 nodes")
}
header := resp.Header().Get("X-Consul-Index")
if header == "" || header == "0" {
r.Fatalf("Want non-zero header: %q", header)
}
_, err = strconv.ParseUint(header, 10, 64)
r.Check(err)
// Should be a cache hit! The data should've updated in the cache header := resp.Header().Get("X-Consul-Index")
// in the background so this should've been fetched directly from if header == "" || header == "0" {
// the cache. r.Fatalf("Want non-zero header: %q", header)
if resp.Header().Get("X-Cache") != "HIT" { }
r.Fatalf("should be a cache hit") _, err = strconv.ParseUint(header, 10, 64)
} require.NoError(r, err)
})
// Should be a cache hit! The data should've updated in the cache
// in the background so this should've been fetched directly from
// the cache.
if resp.Header().Get("X-Cache") != "HIT" {
r.Fatalf("should be a cache hit")
}
})
}
} }
} }

View File

@ -1105,6 +1105,12 @@ func (s *HTTPHandlers) parseSource(req *http.Request, source *structs.QuerySourc
} }
} }
func (s *HTTPHandlers) parsePeerName(req *http.Request, args *structs.ServiceSpecificRequest) {
if peer := req.URL.Query().Get("peer"); peer != "" {
args.PeerName = peer
}
}
// parseMetaFilter is used to parse the ?node-meta=key:value query parameter, used for // parseMetaFilter is used to parse the ?node-meta=key:value query parameter, used for
// filtering results to nodes with the given metadata key/value // filtering results to nodes with the given metadata key/value
func (s *HTTPHandlers) parseMetaFilter(req *http.Request) map[string]string { func (s *HTTPHandlers) parseMetaFilter(req *http.Request) map[string]string {

View File

@ -103,6 +103,10 @@ func init() {
registerEndpoint("/v1/operator/autopilot/configuration", []string{"GET", "PUT"}, (*HTTPHandlers).OperatorAutopilotConfiguration) registerEndpoint("/v1/operator/autopilot/configuration", []string{"GET", "PUT"}, (*HTTPHandlers).OperatorAutopilotConfiguration)
registerEndpoint("/v1/operator/autopilot/health", []string{"GET"}, (*HTTPHandlers).OperatorServerHealth) registerEndpoint("/v1/operator/autopilot/health", []string{"GET"}, (*HTTPHandlers).OperatorServerHealth)
registerEndpoint("/v1/operator/autopilot/state", []string{"GET"}, (*HTTPHandlers).OperatorAutopilotState) registerEndpoint("/v1/operator/autopilot/state", []string{"GET"}, (*HTTPHandlers).OperatorAutopilotState)
registerEndpoint("/v1/peering/token", []string{"POST"}, (*HTTPHandlers).PeeringGenerateToken)
registerEndpoint("/v1/peering/initiate", []string{"POST"}, (*HTTPHandlers).PeeringInitiate)
registerEndpoint("/v1/peering/", []string{"GET"}, (*HTTPHandlers).PeeringRead)
registerEndpoint("/v1/peerings", []string{"GET"}, (*HTTPHandlers).PeeringList)
registerEndpoint("/v1/query", []string{"GET", "POST"}, (*HTTPHandlers).PreparedQueryGeneral) registerEndpoint("/v1/query", []string{"GET", "POST"}, (*HTTPHandlers).PreparedQueryGeneral)
// specific prepared query endpoints have more complex rules for allowed methods, so // specific prepared query endpoints have more complex rules for allowed methods, so
// the prefix is registered with no methods. // the prefix is registered with no methods.

118
agent/peering_endpoint.go Normal file
View File

@ -0,0 +1,118 @@
package agent
import (
"fmt"
"net/http"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/proto/pbpeering"
)
// PeeringRead fetches a peering that matches the request parameters.
func (s *HTTPHandlers) PeeringRead(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
name, err := getPathSuffixUnescaped(req.URL.Path, "/v1/peering/")
if err != nil {
return nil, err
}
if name == "" {
return nil, BadRequestError{Reason: "Must specify a name to fetch."}
}
entMeta := s.agent.AgentEnterpriseMeta()
if err := s.parseEntMetaPartition(req, entMeta); err != nil {
return nil, err
}
args := pbpeering.PeeringReadRequest{
Name: name,
Datacenter: s.agent.config.Datacenter,
Partition: entMeta.PartitionOrEmpty(), // should be "" in OSS
}
result, err := s.agent.rpcClientPeering.PeeringRead(req.Context(), &args)
if err != nil {
return nil, err
}
if result.Peering == nil {
return nil, NotFoundError{}
}
// TODO(peering): replace with API types
return result.Peering, nil
}
// PeeringList fetches all peerings in the datacenter in OSS or in a given partition in Consul Enterprise.
func (s *HTTPHandlers) PeeringList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
entMeta := s.agent.AgentEnterpriseMeta()
if err := s.parseEntMetaPartition(req, entMeta); err != nil {
return nil, err
}
args := pbpeering.PeeringListRequest{
Datacenter: s.agent.config.Datacenter,
Partition: entMeta.PartitionOrEmpty(), // should be "" in OSS
}
pbresp, err := s.agent.rpcClientPeering.PeeringList(req.Context(), &args)
if err != nil {
return nil, err
}
return pbresp.Peerings, nil
}
// PeeringGenerateToken handles POSTs to the /v1/peering/token endpoint. The request
// will always be forwarded via RPC to the local leader.
func (s *HTTPHandlers) PeeringGenerateToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
args := pbpeering.GenerateTokenRequest{
Datacenter: s.agent.config.Datacenter,
}
if req.Body == nil {
return nil, BadRequestError{Reason: "The peering arguments must be provided in the body"}
}
if err := lib.DecodeJSON(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Body decoding failed: %v", err)}
}
if args.PeerName == "" {
return nil, BadRequestError{Reason: "PeerName is required in the payload when generating a new peering token."}
}
entMeta := s.agent.AgentEnterpriseMeta()
if err := s.parseEntMetaPartition(req, entMeta); err != nil {
return nil, err
}
if args.Partition == "" {
args.Partition = entMeta.PartitionOrEmpty()
}
return s.agent.rpcClientPeering.GenerateToken(req.Context(), &args)
}
// PeeringInitiate handles POSTs to the /v1/peering/initiate endpoint. The request
// will always be forwarded via RPC to the local leader.
func (s *HTTPHandlers) PeeringInitiate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
args := pbpeering.InitiateRequest{
Datacenter: s.agent.config.Datacenter,
}
if req.Body == nil {
return nil, BadRequestError{Reason: "The peering arguments must be provided in the body"}
}
if err := lib.DecodeJSON(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Body decoding failed: %v", err)}
}
if args.PeerName == "" {
return nil, BadRequestError{Reason: "PeerName is required in the payload when initiating a peering."}
}
if args.PeeringToken == "" {
return nil, BadRequestError{Reason: "PeeringToken is required in the payload when initiating a peering."}
}
return s.agent.rpcClientPeering.Initiate(req.Context(), &args)
}

View File

@ -0,0 +1,45 @@
//go:build !consulent
// +build !consulent
package agent
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/testrpc"
)
func TestHTTP_Peering_GenerateToken_OSS_Failure(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := NewTestAgent(t, "")
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
t.Run("Doesn't allow partitions in OSS HTTP requests", func(t *testing.T) {
reqBody := &pbpeering.GenerateTokenRequest{
PeerName: "peering-a",
}
reqBodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err)
req, err := http.NewRequest("POST", "/v1/peering/token?partition=foo",
bytes.NewReader(reqBodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "Partitions are a Consul Enterprise feature")
})
}

View File

@ -0,0 +1,312 @@
package agent
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/testrpc"
)
var validCA = `
-----BEGIN CERTIFICATE-----
MIICmDCCAj6gAwIBAgIBBzAKBggqhkjOPQQDAjAWMRQwEgYDVQQDEwtDb25zdWwg
Q0EgNzAeFw0xODA1MjExNjMzMjhaFw0yODA1MTgxNjMzMjhaMBYxFDASBgNVBAMT
C0NvbnN1bCBDQSA3MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAER0qlxjnRcMEr
iSGlH7G7dYU7lzBEmLUSMZkyBbClmyV8+e8WANemjn+PLnCr40If9cmpr7RnC9Qk
GTaLnLiF16OCAXswggF3MA4GA1UdDwEB/wQEAwIBhjAPBgNVHRMBAf8EBTADAQH/
MGgGA1UdDgRhBF8xZjo5MTpjYTo0MTo4ZjphYzo2NzpiZjo1OTpjMjpmYTo0ZTo3
NTo1YzpkODpmMDo1NTpkZTpiZTo3NTpiODozMzozMTpkNToyNDpiMDowNDpiMzpl
ODo5Nzo1Yjo3ZTBqBgNVHSMEYzBhgF8xZjo5MTpjYTo0MTo4ZjphYzo2NzpiZjo1
OTpjMjpmYTo0ZTo3NTo1YzpkODpmMDo1NTpkZTpiZTo3NTpiODozMzozMTpkNToy
NDpiMDowNDpiMzplODo5Nzo1Yjo3ZTA/BgNVHREEODA2hjRzcGlmZmU6Ly8xMjRk
ZjVhMC05ODIwLTc2YzMtOWFhOS02ZjYyMTY0YmExYzIuY29uc3VsMD0GA1UdHgEB
/wQzMDGgLzAtgisxMjRkZjVhMC05ODIwLTc2YzMtOWFhOS02ZjYyMTY0YmExYzIu
Y29uc3VsMAoGCCqGSM49BAMCA0gAMEUCIQDzkkI7R+0U12a+zq2EQhP/n2mHmta+
fs2hBxWIELGwTAIgLdO7RRw+z9nnxCIA6kNl//mIQb+PGItespiHZKAz74Q=
-----END CERTIFICATE-----
`
func TestHTTP_Peering_GenerateToken(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := NewTestAgent(t, "")
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
t.Run("No Body", func(t *testing.T) {
req, err := http.NewRequest("POST", "/v1/peering/token", nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "The peering arguments must be provided in the body")
})
t.Run("Body Invalid", func(t *testing.T) {
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader([]byte("abc")))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "Body decoding failed:")
})
t.Run("No Name", func(t *testing.T) {
req, err := http.NewRequest("POST", "/v1/peering/token",
bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "PeerName is required")
})
// TODO(peering): add more failure cases
t.Run("Success", func(t *testing.T) {
body := &pbpeering.GenerateTokenRequest{
PeerName: "peering-a",
}
bodyBytes, err := json.Marshal(body)
require.NoError(t, err)
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
var r pbpeering.GenerateTokenResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&r))
tokenJSON, err := base64.StdEncoding.DecodeString(r.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
require.Nil(t, token.CA)
require.Equal(t, []string{fmt.Sprintf("127.0.0.1:%d", a.config.ServerPort)}, token.ServerAddresses)
require.Equal(t, "server.dc1.consul", token.ServerName)
// The PeerID in the token is randomly generated so we don't assert on its value.
require.NotEmpty(t, token.PeerID)
})
}
func TestHTTP_Peering_Initiate(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := NewTestAgent(t, "")
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
t.Run("No Body", func(t *testing.T) {
req, err := http.NewRequest("POST", "/v1/peering/initiate", nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "The peering arguments must be provided in the body")
})
t.Run("Body Invalid", func(t *testing.T) {
req, err := http.NewRequest("POST", "/v1/peering/initiate", bytes.NewReader([]byte("abc")))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "Body decoding failed:")
})
t.Run("No Name", func(t *testing.T) {
req, err := http.NewRequest("POST", "/v1/peering/initiate",
bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "PeerName is required")
})
t.Run("No Token", func(t *testing.T) {
req, err := http.NewRequest("POST", "/v1/peering/initiate",
bytes.NewReader([]byte(`{"PeerName": "peer1-usw1"}`)))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusBadRequest, resp.Code)
body, _ := io.ReadAll(resp.Body)
require.Contains(t, string(body), "PeeringToken is required")
})
// TODO(peering): add more failure cases
t.Run("Success", func(t *testing.T) {
token := structs.PeeringToken{
CA: []string{validCA},
ServerName: "server.dc1.consul",
ServerAddresses: []string{fmt.Sprintf("1.2.3.4:%d", 443)},
PeerID: "a0affd3e-f1c8-4bb9-9168-90fd902c441d",
}
tokenJSON, _ := json.Marshal(&token)
tokenB64 := base64.StdEncoding.EncodeToString(tokenJSON)
body := &pbpeering.InitiateRequest{
PeerName: "peering-a",
PeeringToken: tokenB64,
}
bodyBytes, err := json.Marshal(body)
require.NoError(t, err)
req, err := http.NewRequest("POST", "/v1/peering/initiate", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
// success response does not currently return a value so {} is correct
require.Equal(t, "{}", resp.Body.String())
})
}
func TestHTTP_Peering_Read(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := NewTestAgent(t, "")
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Insert peerings directly to state store.
// Note that the state store holds reference to the underlying
// variables; do not modify them after writing.
foo := &pbpeering.PeeringWriteRequest{
Peering: &pbpeering.Peering{
Name: "foo",
State: pbpeering.PeeringState_INITIAL,
PeerCAPems: nil,
PeerServerName: "fooservername",
PeerServerAddresses: []string{"addr1"},
},
}
_, err := a.rpcClientPeering.PeeringWrite(ctx, foo)
require.NoError(t, err)
bar := &pbpeering.PeeringWriteRequest{
Peering: &pbpeering.Peering{
Name: "bar",
State: pbpeering.PeeringState_ACTIVE,
PeerCAPems: nil,
PeerServerName: "barservername",
PeerServerAddresses: []string{"addr1"},
},
}
_, err = a.rpcClientPeering.PeeringWrite(ctx, bar)
require.NoError(t, err)
t.Run("return foo", func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/peering/foo", nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code)
// TODO(peering): replace with API types
var pbresp pbpeering.Peering
require.NoError(t, json.NewDecoder(resp.Body).Decode(&pbresp))
require.Equal(t, foo.Peering.Name, pbresp.Name)
})
t.Run("not found", func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/peering/baz", nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusNotFound, resp.Code)
})
}
func TestHTTP_Peering_List(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := NewTestAgent(t, "")
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Insert peerings directly to state store.
// Note that the state store holds reference to the underlying
// variables; do not modify them after writing.
foo := &pbpeering.PeeringWriteRequest{
Peering: &pbpeering.Peering{
Name: "foo",
State: pbpeering.PeeringState_INITIAL,
PeerCAPems: nil,
PeerServerName: "fooservername",
PeerServerAddresses: []string{"addr1"},
},
}
_, err := a.rpcClientPeering.PeeringWrite(ctx, foo)
require.NoError(t, err)
bar := &pbpeering.PeeringWriteRequest{
Peering: &pbpeering.Peering{
Name: "bar",
State: pbpeering.PeeringState_ACTIVE,
PeerCAPems: nil,
PeerServerName: "barservername",
PeerServerAddresses: []string{"addr1"},
},
}
_, err = a.rpcClientPeering.PeeringWrite(ctx, bar)
require.NoError(t, err)
t.Run("return all", func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/peerings", nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code)
// TODO(peering): replace with API types
var pbresp []*pbpeering.Peering
require.NoError(t, json.NewDecoder(resp.Body).Decode(&pbresp))
require.Len(t, pbresp, 2)
})
}

View File

@ -0,0 +1,741 @@
package peering
import (
"context"
"errors"
"fmt"
"io"
"strconv"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"google.golang.org/genproto/googleapis/rpc/code"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/pbstatus"
)
var (
errPeeringTokenEmptyCA = errors.New("peering token CA value is empty")
errPeeringTokenInvalidCA = errors.New("peering token CA value is invalid")
errPeeringTokenEmptyServerAddresses = errors.New("peering token server addresses value is empty")
errPeeringTokenEmptyServerName = errors.New("peering token server name value is empty")
errPeeringTokenEmptyPeerID = errors.New("peering token peer ID value is empty")
)
// errPeeringInvalidServerAddress is returned when an initiate request contains
// an invalid server address.
type errPeeringInvalidServerAddress struct {
addr string
}
// Error implements the error interface
func (e *errPeeringInvalidServerAddress) Error() string {
return fmt.Sprintf("%s is not a valid peering server address", e.addr)
}
// Service implements pbpeering.PeeringService to provide RPC operations for
// managing peering relationships.
type Service struct {
Backend Backend
logger hclog.Logger
streams *streamTracker
}
func NewService(logger hclog.Logger, backend Backend) *Service {
return &Service{
Backend: backend,
logger: logger,
streams: newStreamTracker(),
}
}
var _ pbpeering.PeeringServiceServer = (*Service)(nil)
// Backend defines the core integrations the Peering endpoint depends on. A
// functional implementation will integrate with various subcomponents of Consul
// such as the State store for reading and writing data, the CA machinery for
// providing access to CA data and the RPC system for forwarding requests to
// other servers.
type Backend interface {
// Forward should forward the request to the leader when necessary.
Forward(info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error)
// GetAgentCACertificates returns the CA certificate to be returned in the peering token data
GetAgentCACertificates() ([]string, error)
// GetServerAddresses returns the addresses used for establishing a peering connection
GetServerAddresses() ([]string, error)
// GetServerName returns the SNI to be returned in the peering token data which
// will be used by peers when establishing peering connections over TLS.
GetServerName() string
// EncodeToken packages a peering token into a slice of bytes.
EncodeToken(tok *structs.PeeringToken) ([]byte, error)
// DecodeToken unpackages a peering token from a slice of bytes.
DecodeToken([]byte) (*structs.PeeringToken, error)
EnterpriseCheckPartitions(partition string) error
Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error)
Store() Store
Apply() Apply
}
// Store provides a read-only interface for querying Peering data.
type Store interface {
PeeringRead(ws memdb.WatchSet, q state.Query) (uint64, *pbpeering.Peering, error)
PeeringList(ws memdb.WatchSet, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.Peering, error)
ExportedServicesForPeer(ws memdb.WatchSet, peerID string) (uint64, []structs.ServiceName, error)
AbandonCh() <-chan struct{}
}
// Apply provides a write-only interface for persisting Peering data.
type Apply interface {
PeeringWrite(req *pbpeering.PeeringWriteRequest) error
PeeringDelete(req *pbpeering.PeeringDeleteRequest) error
PeeringTerminateByID(req *pbpeering.PeeringTerminateByIDRequest) error
}
// GenerateToken implements the PeeringService RPC method to generate a
// peering token which is the initial step in establishing a peering relationship
// with other Consul clusters.
func (s *Service) GenerateToken(
ctx context.Context,
req *pbpeering.GenerateTokenRequest,
) (*pbpeering.GenerateTokenResponse, error) {
if err := s.Backend.EnterpriseCheckPartitions(req.Partition); err != nil {
return nil, grpcstatus.Error(codes.InvalidArgument, err.Error())
}
// validate prior to forwarding to the leader, this saves a network hop
if err := dns.ValidateLabel(req.PeerName); err != nil {
return nil, fmt.Errorf("%s is not a valid peer name: %w", req.PeerName, err)
}
// TODO(peering): add metrics
// TODO(peering): add tracing
resp := &pbpeering.GenerateTokenResponse{}
handled, err := s.Backend.Forward(req, func(conn *grpc.ClientConn) error {
var err error
resp, err = pbpeering.NewPeeringServiceClient(conn).GenerateToken(ctx, req)
return err
})
if handled || err != nil {
return resp, err
}
ca, err := s.Backend.GetAgentCACertificates()
if err != nil {
return nil, err
}
serverAddrs, err := s.Backend.GetServerAddresses()
if err != nil {
return nil, err
}
writeReq := pbpeering.PeeringWriteRequest{
Peering: &pbpeering.Peering{
Name: req.PeerName,
// TODO(peering): Normalize from ACL token once this endpoint is guarded by ACLs.
Partition: req.PartitionOrDefault(),
},
}
if err := s.Backend.Apply().PeeringWrite(&writeReq); err != nil {
return nil, fmt.Errorf("failed to write peering: %w", err)
}
q := state.Query{
Value: strings.ToLower(req.PeerName),
EnterpriseMeta: *structs.NodeEnterpriseMetaInPartition(req.Partition),
}
_, peering, err := s.Backend.Store().PeeringRead(nil, q)
if err != nil {
return nil, err
}
if peering == nil {
return nil, fmt.Errorf("peering was deleted while token generation request was in flight")
}
tok := structs.PeeringToken{
// Store the UUID so that we can do a global search when handling inbound streams.
PeerID: peering.ID,
CA: ca,
ServerAddresses: serverAddrs,
ServerName: s.Backend.GetServerName(),
}
encoded, err := s.Backend.EncodeToken(&tok)
if err != nil {
return nil, err
}
resp.PeeringToken = string(encoded)
return resp, err
}
// Initiate implements the PeeringService RPC method to finalize peering
// registration. Given a valid token output from a peer's GenerateToken endpoint,
// a peering is registered.
func (s *Service) Initiate(
ctx context.Context,
req *pbpeering.InitiateRequest,
) (*pbpeering.InitiateResponse, error) {
// validate prior to forwarding to the leader, this saves a network hop
if err := dns.ValidateLabel(req.PeerName); err != nil {
return nil, fmt.Errorf("%s is not a valid peer name: %w", req.PeerName, err)
}
tok, err := s.Backend.DecodeToken([]byte(req.PeeringToken))
if err != nil {
return nil, err
}
if err := validatePeeringToken(tok); err != nil {
return nil, err
}
resp := &pbpeering.InitiateResponse{}
handled, err := s.Backend.Forward(req, func(conn *grpc.ClientConn) error {
var err error
resp, err = pbpeering.NewPeeringServiceClient(conn).Initiate(ctx, req)
return err
})
if handled || err != nil {
return resp, err
}
defer metrics.MeasureSince([]string{"peering", "initiate"}, time.Now())
// convert ServiceAddress values to strings
serverAddrs := make([]string, len(tok.ServerAddresses))
for i, addr := range tok.ServerAddresses {
serverAddrs[i] = addr
}
// as soon as a peering is written with a list of ServerAddresses that is
// non-empty, the leader routine will see the peering and attempt to establish
// a connection with the remote peer.
writeReq := &pbpeering.PeeringWriteRequest{
Peering: &pbpeering.Peering{
Name: req.PeerName,
PeerCAPems: tok.CA,
PeerServerAddresses: serverAddrs,
PeerServerName: tok.ServerName,
// uncomment once #1613 lands
// PeerID: tok.PeerID,
},
}
if err = s.Backend.Apply().PeeringWrite(writeReq); err != nil {
return nil, fmt.Errorf("failed to write peering: %w", err)
}
// resp.Status == 0
return resp, nil
}
func (s *Service) PeeringRead(ctx context.Context, req *pbpeering.PeeringReadRequest) (*pbpeering.PeeringReadResponse, error) {
if err := s.Backend.EnterpriseCheckPartitions(req.Partition); err != nil {
return nil, grpcstatus.Error(codes.InvalidArgument, err.Error())
}
var resp *pbpeering.PeeringReadResponse
handled, err := s.Backend.Forward(req, func(conn *grpc.ClientConn) error {
var err error
resp, err = pbpeering.NewPeeringServiceClient(conn).PeeringRead(ctx, req)
return err
})
if handled || err != nil {
return resp, err
}
defer metrics.MeasureSince([]string{"peering", "read"}, time.Now())
// TODO(peering): ACL check request token
// TODO(peering): handle blocking queries
q := state.Query{
Value: strings.ToLower(req.Name),
EnterpriseMeta: *structs.NodeEnterpriseMetaInPartition(req.Partition)}
_, peering, err := s.Backend.Store().PeeringRead(nil, q)
if err != nil {
return nil, err
}
return &pbpeering.PeeringReadResponse{Peering: peering}, nil
}
func (s *Service) PeeringList(ctx context.Context, req *pbpeering.PeeringListRequest) (*pbpeering.PeeringListResponse, error) {
if err := s.Backend.EnterpriseCheckPartitions(req.Partition); err != nil {
return nil, grpcstatus.Error(codes.InvalidArgument, err.Error())
}
var resp *pbpeering.PeeringListResponse
handled, err := s.Backend.Forward(req, func(conn *grpc.ClientConn) error {
var err error
resp, err = pbpeering.NewPeeringServiceClient(conn).PeeringList(ctx, req)
return err
})
if handled || err != nil {
return resp, err
}
defer metrics.MeasureSince([]string{"peering", "list"}, time.Now())
// TODO(peering): ACL check request token
// TODO(peering): handle blocking queries
_, peerings, err := s.Backend.Store().PeeringList(nil, *structs.NodeEnterpriseMetaInPartition(req.Partition))
if err != nil {
return nil, err
}
return &pbpeering.PeeringListResponse{Peerings: peerings}, nil
}
// TODO(peering): As of writing, this method is only used in tests to set up Peerings in the state store.
// Consider removing if we can find another way to populate state store in peering_endpoint_test.go
func (s *Service) PeeringWrite(ctx context.Context, req *pbpeering.PeeringWriteRequest) (*pbpeering.PeeringWriteResponse, error) {
if err := s.Backend.EnterpriseCheckPartitions(req.Peering.Partition); err != nil {
return nil, grpcstatus.Error(codes.InvalidArgument, err.Error())
}
var resp *pbpeering.PeeringWriteResponse
handled, err := s.Backend.Forward(req, func(conn *grpc.ClientConn) error {
var err error
resp, err = pbpeering.NewPeeringServiceClient(conn).PeeringWrite(ctx, req)
return err
})
if handled || err != nil {
return resp, err
}
defer metrics.MeasureSince([]string{"peering", "write"}, time.Now())
// TODO(peering): ACL check request token
// TODO(peering): handle blocking queries
err = s.Backend.Apply().PeeringWrite(req)
if err != nil {
return nil, err
}
return &pbpeering.PeeringWriteResponse{}, nil
}
func (s *Service) PeeringDelete(ctx context.Context, req *pbpeering.PeeringDeleteRequest) (*pbpeering.PeeringDeleteResponse, error) {
if err := s.Backend.EnterpriseCheckPartitions(req.Partition); err != nil {
return nil, grpcstatus.Error(codes.InvalidArgument, err.Error())
}
var resp *pbpeering.PeeringDeleteResponse
handled, err := s.Backend.Forward(req, func(conn *grpc.ClientConn) error {
var err error
resp, err = pbpeering.NewPeeringServiceClient(conn).PeeringDelete(ctx, req)
return err
})
if handled || err != nil {
return resp, err
}
defer metrics.MeasureSince([]string{"peering", "delete"}, time.Now())
// TODO(peering): ACL check request token
// TODO(peering): handle blocking queries
err = s.Backend.Apply().PeeringDelete(req)
if err != nil {
return nil, err
}
return &pbpeering.PeeringDeleteResponse{}, nil
}
type BidirectionalStream interface {
Send(*pbpeering.ReplicationMessage) error
Recv() (*pbpeering.ReplicationMessage, error)
Context() context.Context
}
// StreamResources handles incoming streaming connections.
func (s *Service) StreamResources(stream pbpeering.PeeringService_StreamResourcesServer) error {
// Initial message on a new stream must be a new subscription request.
first, err := stream.Recv()
if err != nil {
s.logger.Error("failed to establish stream", "error", err)
return err
}
// TODO(peering) Make request contain a list of resources, so that roots and services can be
// subscribed to with a single request. See:
// https://github.com/envoyproxy/data-plane-api/blob/main/envoy/service/discovery/v3/discovery.proto#L46
req := first.GetRequest()
if req == nil {
return grpcstatus.Error(codes.InvalidArgument, "first message when initiating a peering must be a subscription request")
}
s.logger.Trace("received initial replication request from peer")
logTraceRecv(s.logger, req)
if req.PeerID == "" {
return grpcstatus.Error(codes.InvalidArgument, "initial subscription request must specify a PeerID")
}
if req.Nonce != "" {
return grpcstatus.Error(codes.InvalidArgument, "initial subscription request must not contain a nonce")
}
if req.ResourceURL != pbpeering.TypeURLService {
return grpcstatus.Error(codes.InvalidArgument, fmt.Sprintf("subscription request to unknown resource URL: %s", req.ResourceURL))
}
// TODO(peering): Validate that a peering exists for this peer
// TODO(peering): If the peering is marked as deleted, send a Terminated message and return
// TODO(peering): Store subscription request so that an event publisher can separately handle pushing messages for it
s.logger.Info("accepted initial replication request from peer", "peer_id", req.PeerID)
// For server peers both of these ID values are the same, because we generated a token with a local ID,
// and the client peer dials using that same ID.
return s.HandleStream(req.PeerID, req.PeerID, stream)
}
// The localID provided is the locally-generated identifier for the peering.
// The remoteID is an identifier that the remote peer recognizes for the peering.
func (s *Service) HandleStream(localID, remoteID string, stream BidirectionalStream) error {
logger := s.logger.Named("stream").With("peer_id", localID)
logger.Trace("handling stream for peer")
status, err := s.streams.connected(localID)
if err != nil {
return fmt.Errorf("failed to register stream: %v", err)
}
// TODO(peering) Also need to clear subscriptions associated with the peer
defer s.streams.disconnected(localID)
mgr := newSubscriptionManager(stream.Context(), logger, s.Backend)
subCh := mgr.subscribe(stream.Context(), localID)
sub := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
PeerID: remoteID,
},
},
}
logTraceSend(logger, sub)
if err := stream.Send(sub); err != nil {
if err == io.EOF {
logger.Info("stream ended by peer")
status.trackReceiveError(err.Error())
return nil
}
// TODO(peering) Test error handling in calls to Send/Recv
status.trackSendError(err.Error())
return fmt.Errorf("failed to send to stream: %v", err)
}
// TODO(peering): Should this be buffered?
recvChan := make(chan *pbpeering.ReplicationMessage)
go func() {
defer close(recvChan)
for {
msg, err := stream.Recv()
if err == io.EOF {
logger.Info("stream ended by peer")
status.trackReceiveError(err.Error())
return
}
if e, ok := grpcstatus.FromError(err); ok {
// Cancelling the stream is not an error, that means we or our peer intended to terminate the peering.
if e.Code() == codes.Canceled {
return
}
}
if err != nil {
logger.Error("failed to receive from stream", "error", err)
status.trackReceiveError(err.Error())
return
}
logTraceRecv(logger, msg)
recvChan <- msg
}
}()
for {
select {
// When the doneCh is closed that means that the peering was deleted locally.
case <-status.doneCh:
logger.Info("ending stream")
term := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Terminated_{
Terminated: &pbpeering.ReplicationMessage_Terminated{},
},
}
logTraceSend(logger, term)
if err := stream.Send(term); err != nil {
status.trackSendError(err.Error())
return fmt.Errorf("failed to send to stream: %v", err)
}
logger.Trace("deleting stream status")
s.streams.deleteStatus(localID)
return nil
case msg, open := <-recvChan:
if !open {
// No longer receiving data on the stream.
return nil
}
if req := msg.GetRequest(); req != nil {
switch {
case req.Nonce == "":
// TODO(peering): This can happen on a client peer since they don't try to receive subscriptions before entering HandleStream.
// Should change that behavior or only allow it that one time.
case req.Error != nil && (req.Error.Code != int32(code.Code_OK) || req.Error.Message != ""):
logger.Warn("client peer was unable to apply resource", "code", req.Error.Code, "error", req.Error.Message)
status.trackNack(fmt.Sprintf("client peer was unable to apply resource: %s", req.Error.Message))
default:
status.trackAck()
}
continue
}
if resp := msg.GetResponse(); resp != nil {
req, err := processResponse(resp)
if err != nil {
logger.Error("failed to persist resource", "resourceURL", resp.ResourceURL, "resourceID", resp.ResourceID)
status.trackReceiveError(err.Error())
} else {
status.trackReceiveSuccess()
}
logTraceSend(logger, req)
if err := stream.Send(req); err != nil {
status.trackSendError(err.Error())
return fmt.Errorf("failed to send to stream: %v", err)
}
continue
}
if term := msg.GetTerminated(); term != nil {
logger.Info("received peering termination message, cleaning up imported resources")
// Once marked as terminated, a separate deferred deletion routine will clean up imported resources.
if err := s.Backend.Apply().PeeringTerminateByID(&pbpeering.PeeringTerminateByIDRequest{ID: localID}); err != nil {
return err
}
return nil
}
case update := <-subCh:
switch {
case strings.HasPrefix(update.CorrelationID, subExportedService):
if err := pushServiceResponse(logger, stream, status, update); err != nil {
return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err)
}
default:
logger.Warn("unrecognized update type from subscription manager: " + update.CorrelationID)
continue
}
}
}
}
// pushService response handles sending exported service instance updates to the peer cluster.
// Each cache.UpdateEvent will contain all instances for a service name.
// If there are no instances in the event, we consider that to be a de-registration.
func pushServiceResponse(logger hclog.Logger, stream BidirectionalStream, status *lockableStreamStatus, update cache.UpdateEvent) error {
csn, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
if !ok {
logger.Error(fmt.Sprintf("invalid type for response: %T, expected *pbservice.IndexedCheckServiceNodes", update.Result))
// Skip this update to avoid locking up peering due to a bad service update.
return nil
}
serviceName := strings.TrimPrefix(update.CorrelationID, subExportedService)
// If no nodes are present then it's due to one of:
// 1. The service is newly registered or exported and yielded a transient empty update.
// 2. All instances of the service were de-registered.
// 3. The service was un-exported.
//
// We don't distinguish when these three things occurred, but it's safe to send a DELETE Op in all cases, so we do that.
// Case #1 is a no-op for the importing peer.
if len(csn.Nodes) == 0 {
resp := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Response_{
Response: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
// TODO(peering): Nonce management
Nonce: "",
ResourceID: serviceName,
Operation: pbpeering.ReplicationMessage_Response_DELETE,
},
},
}
logTraceSend(logger, resp)
if err := stream.Send(resp); err != nil {
status.trackSendError(err.Error())
return fmt.Errorf("failed to send to stream: %v", err)
}
return nil
}
// If there are nodes in the response, we push them as an UPSERT operation.
any, err := ptypes.MarshalAny(csn)
if err != nil {
// Log the error and skip this response to avoid locking up peering due to a bad update event.
logger.Error("failed to marshal service endpoints", "error", err)
return nil
}
resp := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Response_{
Response: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
// TODO(peering): Nonce management
Nonce: "",
ResourceID: serviceName,
Operation: pbpeering.ReplicationMessage_Response_UPSERT,
Resource: any,
},
},
}
logTraceSend(logger, resp)
if err := stream.Send(resp); err != nil {
status.trackSendError(err.Error())
return fmt.Errorf("failed to send to stream: %v", err)
}
return nil
}
func (s *Service) StreamStatus(peer string) (resp StreamStatus, found bool) {
return s.streams.streamStatus(peer)
}
// ConnectedStreams returns a map of connected stream IDs to the corresponding channel for tearing them down.
func (s *Service) ConnectedStreams() map[string]chan struct{} {
return s.streams.connectedStreams()
}
func makeReply(resourceURL, nonce string, errCode code.Code, errMsg string) *pbpeering.ReplicationMessage {
var rpcErr *pbstatus.Status
if errCode != code.Code_OK || errMsg != "" {
rpcErr = &pbstatus.Status{
Code: int32(errCode),
Message: errMsg,
}
}
msg := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: resourceURL,
Nonce: nonce,
Error: rpcErr,
},
},
}
return msg
}
func processResponse(resp *pbpeering.ReplicationMessage_Response) (*pbpeering.ReplicationMessage, error) {
var (
err error
errCode code.Code
errMsg string
)
if resp.ResourceURL != pbpeering.TypeURLService {
errCode = code.Code_INVALID_ARGUMENT
err = fmt.Errorf("received response for unknown resource type %q", resp.ResourceURL)
return makeReply(resp.ResourceURL, resp.Nonce, errCode, err.Error()), err
}
switch resp.Operation {
case pbpeering.ReplicationMessage_Response_UPSERT:
err = handleUpsert(resp.ResourceURL, resp.Resource)
if err != nil {
errCode = code.Code_INTERNAL
errMsg = err.Error()
}
case pbpeering.ReplicationMessage_Response_DELETE:
err = handleDelete(resp.ResourceURL, resp.ResourceID)
if err != nil {
errCode = code.Code_INTERNAL
errMsg = err.Error()
}
default:
errCode = code.Code_INVALID_ARGUMENT
op := pbpeering.ReplicationMessage_Response_Operation_name[int32(resp.Operation)]
if op == "" {
op = strconv.FormatInt(int64(resp.Operation), 10)
}
errMsg = fmt.Sprintf("unsupported operation: %q", op)
err = errors.New(errMsg)
}
return makeReply(resp.ResourceURL, resp.Nonce, errCode, errMsg), err
}
func handleUpsert(resourceURL string, resource *anypb.Any) error {
// TODO(peering): implement
return nil
}
func handleDelete(resourceURL string, resourceID string) error {
// TODO(peering): implement
return nil
}
func logTraceRecv(logger hclog.Logger, pb proto.Message) {
logTraceProto(logger, pb, true)
}
func logTraceSend(logger hclog.Logger, pb proto.Message) {
logTraceProto(logger, pb, false)
}
func logTraceProto(logger hclog.Logger, pb proto.Message, received bool) {
if !logger.IsTrace() {
return
}
dir := "sent"
if received {
dir = "received"
}
m := jsonpb.Marshaler{
Indent: " ",
}
out, err := m.MarshalToString(pb)
if err != nil {
out = "<ERROR: " + err.Error() + ">"
}
logger.Trace("replication message", "direction", dir, "protobuf", out)
}

View File

@ -0,0 +1,39 @@
//go:build !consulent
// +build !consulent
package peering_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/proto/pbpeering"
)
func TestPeeringService_RejectsPartition(t *testing.T) {
s := newTestServer(t, nil)
client := pbpeering.NewPeeringServiceClient(s.ClientConn(t))
t.Run("read", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
req := &pbpeering.PeeringReadRequest{Name: "foo", Partition: "default"}
resp, err := client.PeeringRead(ctx, req)
require.Contains(t, err.Error(), "Partitions are a Consul Enterprise feature")
require.Nil(t, resp)
})
t.Run("list", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
req := &pbpeering.PeeringListRequest{Partition: "default"}
resp, err := client.PeeringList(ctx, req)
require.Contains(t, err.Error(), "Partitions are a Consul Enterprise feature")
require.Nil(t, resp)
})
}

View File

@ -0,0 +1,414 @@
package peering_test
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"path"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
gogrpc "google.golang.org/grpc"
grpc "github.com/hashicorp/consul/agent/grpc/private"
"github.com/hashicorp/consul/agent/grpc/private/resolver"
"github.com/hashicorp/consul/proto/prototest"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/router"
"github.com/hashicorp/consul/agent/rpc/middleware"
"github.com/hashicorp/consul/agent/rpc/peering"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
)
func TestPeeringService_GenerateToken(t *testing.T) {
dir := testutil.TempDir(t, "consul")
signer, _, _ := tlsutil.GeneratePrivateKey()
ca, _, _ := tlsutil.GenerateCA(tlsutil.CAOpts{Signer: signer})
cafile := path.Join(dir, "cacert.pem")
require.NoError(t, ioutil.WriteFile(cafile, []byte(ca), 0600))
// TODO(peering): see note on newTestServer, refactor to not use this
s := newTestServer(t, func(c *consul.Config) {
c.SerfLANConfig.MemberlistConfig.AdvertiseAddr = "127.0.0.1"
c.TLSConfig.InternalRPC.CAFile = cafile
c.DataDir = dir
})
client := pbpeering.NewPeeringServiceClient(s.ClientConn(t))
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
req := pbpeering.GenerateTokenRequest{PeerName: "peerB", Datacenter: "dc1"}
resp, err := client.GenerateToken(ctx, &req)
require.NoError(t, err)
tokenJSON, err := base64.StdEncoding.DecodeString(resp.PeeringToken)
require.NoError(t, err)
token := &structs.PeeringToken{}
require.NoError(t, json.Unmarshal(tokenJSON, token))
require.Equal(t, "server.dc1.consul", token.ServerName)
require.Len(t, token.ServerAddresses, 1)
require.Equal(t, "127.0.0.1:2345", token.ServerAddresses[0])
require.Equal(t, []string{ca}, token.CA)
require.NotEmpty(t, token.PeerID)
_, err = uuid.ParseUUID(token.PeerID)
require.NoError(t, err)
_, peers, err := s.Server.FSM().State().PeeringList(nil, *structs.DefaultEnterpriseMetaInDefaultPartition())
require.NoError(t, err)
require.Len(t, peers, 1)
peers[0].ModifyIndex = 0
peers[0].CreateIndex = 0
expect := &pbpeering.Peering{
Name: "peerB",
Partition: acl.DefaultPartitionName,
ID: token.PeerID,
State: pbpeering.PeeringState_INITIAL,
}
require.Equal(t, expect, peers[0])
}
func TestPeeringService_Initiate(t *testing.T) {
validToken := peering.TestPeeringToken("83474a06-cca4-4ff4-99a4-4152929c8160")
validTokenJSON, _ := json.Marshal(&validToken)
validTokenB64 := base64.StdEncoding.EncodeToString(validTokenJSON)
// TODO(peering): see note on newTestServer, refactor to not use this
s := newTestServer(t, nil)
client := pbpeering.NewPeeringServiceClient(s.ClientConn(t))
type testcase struct {
name string
req *pbpeering.InitiateRequest
expectResp *pbpeering.InitiateResponse
expectPeering *pbpeering.Peering
expectErr string
}
run := func(t *testing.T, tc testcase) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
resp, err := client.Initiate(ctx, tc.req)
if tc.expectErr != "" {
require.Contains(t, err.Error(), tc.expectErr)
return
}
require.NoError(t, err)
prototest.AssertDeepEqual(t, tc.expectResp, resp)
// if a peering was expected to be written, try to read it back
if tc.expectPeering != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
resp, err := client.PeeringRead(ctx, &pbpeering.PeeringReadRequest{Name: tc.expectPeering.Name})
require.NoError(t, err)
// check individual values we care about since we don't know exactly
// what the create/modify indexes will be
require.Equal(t, tc.expectPeering.Name, resp.Peering.Name)
require.Equal(t, tc.expectPeering.Partition, resp.Peering.Partition)
require.Equal(t, tc.expectPeering.State, resp.Peering.State)
require.Equal(t, tc.expectPeering.PeerCAPems, resp.Peering.PeerCAPems)
require.Equal(t, tc.expectPeering.PeerServerAddresses, resp.Peering.PeerServerAddresses)
require.Equal(t, tc.expectPeering.PeerServerName, resp.Peering.PeerServerName)
}
}
tcs := []testcase{
{
name: "invalid peer name",
req: &pbpeering.InitiateRequest{PeerName: "--AA--"},
expectErr: "--AA-- is not a valid peer name",
},
{
name: "invalid token (base64)",
req: &pbpeering.InitiateRequest{
PeerName: "peer1-usw1",
PeeringToken: "+++/+++",
},
expectErr: "illegal base64 data",
},
{
name: "invalid token (JSON)",
req: &pbpeering.InitiateRequest{
PeerName: "peer1-usw1",
PeeringToken: "Cg==", // base64 of "-"
},
expectErr: "unexpected end of JSON input",
},
{
name: "invalid token (empty)",
req: &pbpeering.InitiateRequest{
PeerName: "peer1-usw1",
PeeringToken: "e30K", // base64 of "{}"
},
expectErr: "peering token CA value is empty",
},
{
name: "success",
req: &pbpeering.InitiateRequest{
PeerName: "peer1-usw1",
PeeringToken: validTokenB64,
},
expectResp: &pbpeering.InitiateResponse{},
expectPeering: peering.TestPeering(
"peer1-usw1",
pbpeering.PeeringState_INITIAL,
),
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestPeeringService_Read(t *testing.T) {
// TODO(peering): see note on newTestServer, refactor to not use this
s := newTestServer(t, nil)
// insert peering directly to state store
p := &pbpeering.Peering{
Name: "foo",
State: pbpeering.PeeringState_INITIAL,
PeerCAPems: nil,
PeerServerName: "test",
PeerServerAddresses: []string{"addr1"},
}
err := s.Server.FSM().State().PeeringWrite(10, p)
require.NoError(t, err)
client := pbpeering.NewPeeringServiceClient(s.ClientConn(t))
type testcase struct {
name string
req *pbpeering.PeeringReadRequest
expect *pbpeering.PeeringReadResponse
expectErr string
}
run := func(t *testing.T, tc testcase) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
resp, err := client.PeeringRead(ctx, tc.req)
if tc.expectErr != "" {
require.Contains(t, err.Error(), tc.expectErr)
return
}
require.NoError(t, err)
prototest.AssertDeepEqual(t, tc.expect, resp)
}
tcs := []testcase{
{
name: "returns foo",
req: &pbpeering.PeeringReadRequest{Name: "foo"},
expect: &pbpeering.PeeringReadResponse{Peering: p},
expectErr: "",
},
{
name: "bar not found",
req: &pbpeering.PeeringReadRequest{Name: "bar"},
expect: &pbpeering.PeeringReadResponse{},
expectErr: "",
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestPeeringService_List(t *testing.T) {
// TODO(peering): see note on newTestServer, refactor to not use this
s := newTestServer(t, nil)
// Insert peerings directly to state store.
// Note that the state store holds reference to the underlying
// variables; do not modify them after writing.
foo := &pbpeering.Peering{
Name: "foo",
State: pbpeering.PeeringState_INITIAL,
PeerCAPems: nil,
PeerServerName: "fooservername",
PeerServerAddresses: []string{"addr1"},
}
require.NoError(t, s.Server.FSM().State().PeeringWrite(10, foo))
bar := &pbpeering.Peering{
Name: "bar",
State: pbpeering.PeeringState_ACTIVE,
PeerCAPems: nil,
PeerServerName: "barservername",
PeerServerAddresses: []string{"addr1"},
}
require.NoError(t, s.Server.FSM().State().PeeringWrite(15, bar))
client := pbpeering.NewPeeringServiceClient(s.ClientConn(t))
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
resp, err := client.PeeringList(ctx, &pbpeering.PeeringListRequest{})
require.NoError(t, err)
expect := &pbpeering.PeeringListResponse{
Peerings: []*pbpeering.Peering{bar, foo},
}
prototest.AssertDeepEqual(t, expect, resp)
}
// newTestServer is copied from partition/service_test.go, with the addition of certs/cas.
// TODO(peering): these are endpoint tests and should live in the agent/consul
// package. Instead, these can be written around a mock client (see testing.go)
// and a mock backend (future)
func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
t.Helper()
conf := consul.DefaultConfig()
dir := testutil.TempDir(t, "consul")
conf.Bootstrap = true
conf.Datacenter = "dc1"
conf.DataDir = dir
conf.RPCAddr = &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 2345}
conf.RaftConfig.ElectionTimeout = 200 * time.Millisecond
conf.RaftConfig.LeaderLeaseTimeout = 100 * time.Millisecond
conf.RaftConfig.HeartbeatTimeout = 200 * time.Millisecond
conf.TLSConfig.Domain = "consul"
nodeID, err := uuid.GenerateUUID()
if err != nil {
t.Fatal(err)
}
conf.NodeID = types.NodeID(nodeID)
if cb != nil {
cb(conf)
}
// Apply config to copied fields because many tests only set the old
// values.
conf.ACLResolverSettings.ACLsEnabled = conf.ACLsEnabled
conf.ACLResolverSettings.NodeName = conf.NodeName
conf.ACLResolverSettings.Datacenter = conf.Datacenter
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
deps := newDefaultDeps(t, conf)
server, err := consul.NewServer(conf, deps, gogrpc.NewServer())
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, server.Shutdown())
})
testrpc.WaitForLeader(t, server.RPC, conf.Datacenter)
backend := consul.NewPeeringBackend(server, deps.GRPCConnPool)
handler := &peering.Service{Backend: backend}
grpcServer := gogrpc.NewServer()
pbpeering.RegisterPeeringServiceServer(grpcServer, handler)
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() { lis.Close() })
g := new(errgroup.Group)
g.Go(func() error {
return grpcServer.Serve(lis)
})
t.Cleanup(func() {
if grpcServer.Stop(); err != nil {
t.Logf("grpc server shutdown: %v", err)
}
if err := g.Wait(); err != nil {
t.Logf("grpc server error: %v", err)
}
})
return testingServer{
Server: server,
Backend: backend,
Addr: lis.Addr(),
}
}
func (s testingServer) ClientConn(t *testing.T) *gogrpc.ClientConn {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
conn, err := gogrpc.DialContext(ctx, s.Addr.String(), gogrpc.WithInsecure())
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
return conn
}
type testingServer struct {
Server *consul.Server
Addr net.Addr
Backend peering.Backend
}
// TODO(peering): remove duplication between this and agent/consul tests
func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
t.Helper()
logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
Name: c.NodeName,
Level: hclog.Debug,
Output: testutil.NewLogBuffer(t),
})
tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger)
require.NoError(t, err, "failed to create tls configuration")
r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), nil)
builder := resolver.NewServerResolverBuilder(resolver.Config{})
resolver.Register(builder)
connPool := &pool.ConnPool{
Server: false,
SrcAddr: c.RPCSrcAddr,
Logger: logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
MaxTime: 2 * time.Minute,
MaxStreams: 4,
TLSConfigurator: tls,
Datacenter: c.Datacenter,
}
return consul.Deps{
Logger: logger,
TLSConfigurator: tls,
Tokens: new(token.Store),
Router: r,
ConnPool: connPool,
GRPCConnPool: grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(tls.OutgoingRPCWrapper()),
UseTLSForDC: tls.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: c.Datacenter,
}),
LeaderForwarder: builder,
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),
NewRequestRecorderFunc: middleware.NewRequestRecorder,
GetNetRPCInterceptorFunc: middleware.GetNetRPCInterceptor,
}
}

View File

@ -0,0 +1,810 @@
package peering
import (
"context"
"io"
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/require"
"google.golang.org/genproto/googleapis/rpc/code"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/pbstatus"
"github.com/hashicorp/consul/proto/prototest"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
)
func TestStreamResources_Server_FirstRequest(t *testing.T) {
type testCase struct {
name string
input *pbpeering.ReplicationMessage
wantErr error
}
run := func(t *testing.T, tc testCase) {
srv := NewService(testutil.Logger(t), nil)
client := newMockClient(context.Background())
errCh := make(chan error, 1)
client.errCh = errCh
go func() {
// Pass errors from server handler into errCh so that they can be seen by the client on Recv().
// This matches gRPC's behavior when an error is returned by a server.
err := srv.StreamResources(client.replicationStream)
if err != nil {
errCh <- err
}
}()
err := client.Send(tc.input)
require.NoError(t, err)
msg, err := client.Recv()
require.Nil(t, msg)
require.Error(t, err)
require.EqualError(t, err, tc.wantErr.Error())
}
tt := []testCase{
{
name: "unexpected response",
input: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Response_{
Response: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
ResourceID: "api-service",
Nonce: "2",
},
},
},
wantErr: status.Error(codes.InvalidArgument, "first message when initiating a peering must be a subscription request"),
},
{
name: "missing peer id",
input: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{},
},
},
wantErr: status.Error(codes.InvalidArgument, "initial subscription request must specify a PeerID"),
},
{
name: "unexpected nonce",
input: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: "63b60245-c475-426b-b314-4588d210859d",
Nonce: "1",
},
},
},
wantErr: status.Error(codes.InvalidArgument, "initial subscription request must not contain a nonce"),
},
{
name: "unknown resource",
input: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: "63b60245-c475-426b-b314-4588d210859d",
ResourceURL: "nomad.Job",
},
},
},
wantErr: status.Error(codes.InvalidArgument, "subscription request to unknown resource URL: nomad.Job"),
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func TestStreamResources_Server_Terminate(t *testing.T) {
publisher := stream.NewEventPublisher(10 * time.Second)
store := newStateStore(t, publisher)
srv := NewService(testutil.Logger(t), &testStreamBackend{
store: store,
pub: publisher,
})
it := incrementalTime{
base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
}
srv.streams.timeNow = it.Now
client := newMockClient(context.Background())
errCh := make(chan error, 1)
client.errCh = errCh
go func() {
// Pass errors from server handler into errCh so that they can be seen by the client on Recv().
// This matches gRPC's behavior when an error is returned by a server.
if err := srv.StreamResources(client.replicationStream); err != nil {
errCh <- err
}
}()
// Receive a subscription from a peer
peerID := "63b60245-c475-426b-b314-4588d210859d"
sub := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: peerID,
ResourceURL: pbpeering.TypeURLService,
},
},
}
err := client.Send(sub)
require.NoError(t, err)
runStep(t, "new stream gets tracked", func(t *testing.T) {
retry.Run(t, func(r *retry.R) {
status, ok := srv.StreamStatus(peerID)
require.True(r, ok)
require.True(r, status.Connected)
})
})
// Receive subscription to my-peer-B's resources
receivedSub, err := client.Recv()
require.NoError(t, err)
expect := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
PeerID: peerID,
},
},
}
prototest.AssertDeepEqual(t, expect, receivedSub)
runStep(t, "terminate the stream", func(t *testing.T) {
done := srv.ConnectedStreams()[peerID]
close(done)
retry.Run(t, func(r *retry.R) {
_, ok := srv.StreamStatus(peerID)
require.False(r, ok)
})
})
receivedTerm, err := client.Recv()
require.NoError(t, err)
expect = &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Terminated_{
Terminated: &pbpeering.ReplicationMessage_Terminated{},
},
}
prototest.AssertDeepEqual(t, expect, receivedTerm)
}
func TestStreamResources_Server_StreamTracker(t *testing.T) {
publisher := stream.NewEventPublisher(10 * time.Second)
store := newStateStore(t, publisher)
srv := NewService(testutil.Logger(t), &testStreamBackend{
store: store,
pub: publisher,
})
it := incrementalTime{
base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
}
srv.streams.timeNow = it.Now
client := newMockClient(context.Background())
errCh := make(chan error, 1)
go func() {
errCh <- srv.StreamResources(client.replicationStream)
}()
peerID := "63b60245-c475-426b-b314-4588d210859d"
sub := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: peerID,
ResourceURL: pbpeering.TypeURLService,
},
},
}
err := client.Send(sub)
require.NoError(t, err)
runStep(t, "new stream gets tracked", func(t *testing.T) {
retry.Run(t, func(r *retry.R) {
status, ok := srv.StreamStatus(peerID)
require.True(r, ok)
require.True(r, status.Connected)
})
})
runStep(t, "client receives initial subscription", func(t *testing.T) {
ack, err := client.Recv()
require.NoError(t, err)
expectAck := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
PeerID: peerID,
Nonce: "",
},
},
}
prototest.AssertDeepEqual(t, expectAck, ack)
})
var sequence uint64
var lastSendSuccess time.Time
runStep(t, "ack tracked as success", func(t *testing.T) {
ack := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: peerID,
ResourceURL: pbpeering.TypeURLService,
Nonce: "1",
// Acks do not have an Error populated in the request
},
},
}
err := client.Send(ack)
require.NoError(t, err)
sequence++
lastSendSuccess = it.base.Add(time.Duration(sequence) * time.Second).UTC()
expect := StreamStatus{
Connected: true,
LastAck: lastSendSuccess,
}
retry.Run(t, func(r *retry.R) {
status, ok := srv.StreamStatus(peerID)
require.True(r, ok)
require.Equal(r, expect, status)
})
})
var lastNack time.Time
var lastNackMsg string
runStep(t, "nack tracked as error", func(t *testing.T) {
nack := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: peerID,
ResourceURL: pbpeering.TypeURLService,
Nonce: "2",
Error: &pbstatus.Status{
Code: int32(code.Code_UNAVAILABLE),
Message: "bad bad not good",
},
},
},
}
err := client.Send(nack)
require.NoError(t, err)
sequence++
lastNackMsg = "client peer was unable to apply resource: bad bad not good"
lastNack = it.base.Add(time.Duration(sequence) * time.Second).UTC()
expect := StreamStatus{
Connected: true,
LastAck: lastSendSuccess,
LastNack: lastNack,
LastNackMessage: lastNackMsg,
}
retry.Run(t, func(r *retry.R) {
status, ok := srv.StreamStatus(peerID)
require.True(r, ok)
require.Equal(r, expect, status)
})
})
var lastRecvSuccess time.Time
runStep(t, "response applied locally", func(t *testing.T) {
resp := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Response_{
Response: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
ResourceID: "api",
Nonce: "21",
Operation: pbpeering.ReplicationMessage_Response_UPSERT,
},
},
}
err := client.Send(resp)
require.NoError(t, err)
sequence++
ack, err := client.Recv()
require.NoError(t, err)
expectAck := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
Nonce: "21",
},
},
}
prototest.AssertDeepEqual(t, expectAck, ack)
lastRecvSuccess = it.base.Add(time.Duration(sequence) * time.Second).UTC()
expect := StreamStatus{
Connected: true,
LastAck: lastSendSuccess,
LastNack: lastNack,
LastNackMessage: lastNackMsg,
LastReceiveSuccess: lastRecvSuccess,
}
retry.Run(t, func(r *retry.R) {
status, ok := srv.StreamStatus(peerID)
require.True(r, ok)
require.Equal(r, expect, status)
})
})
var lastRecvError time.Time
var lastRecvErrorMsg string
runStep(t, "response fails to apply locally", func(t *testing.T) {
resp := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Response_{
Response: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
ResourceID: "web",
Nonce: "24",
// Unknown operation gets NACKed
Operation: pbpeering.ReplicationMessage_Response_Unknown,
},
},
}
err := client.Send(resp)
require.NoError(t, err)
sequence++
ack, err := client.Recv()
require.NoError(t, err)
expectNack := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
Nonce: "24",
Error: &pbstatus.Status{
Code: int32(code.Code_INVALID_ARGUMENT),
Message: `unsupported operation: "Unknown"`,
},
},
},
}
prototest.AssertDeepEqual(t, expectNack, ack)
lastRecvError = it.base.Add(time.Duration(sequence) * time.Second).UTC()
lastRecvErrorMsg = `unsupported operation: "Unknown"`
expect := StreamStatus{
Connected: true,
LastAck: lastSendSuccess,
LastNack: lastNack,
LastNackMessage: lastNackMsg,
LastReceiveSuccess: lastRecvSuccess,
LastReceiveError: lastRecvError,
LastReceiveErrorMessage: lastRecvErrorMsg,
}
retry.Run(t, func(r *retry.R) {
status, ok := srv.StreamStatus(peerID)
require.True(r, ok)
require.Equal(r, expect, status)
})
})
runStep(t, "client disconnect marks stream as disconnected", func(t *testing.T) {
client.Close()
sequence++
lastRecvError := it.base.Add(time.Duration(sequence) * time.Second).UTC()
sequence++
disconnectTime := it.base.Add(time.Duration(sequence) * time.Second).UTC()
expect := StreamStatus{
Connected: false,
LastAck: lastSendSuccess,
LastNack: lastNack,
LastNackMessage: lastNackMsg,
DisconnectTime: disconnectTime,
LastReceiveSuccess: lastRecvSuccess,
LastReceiveErrorMessage: io.EOF.Error(),
LastReceiveError: lastRecvError,
}
retry.Run(t, func(r *retry.R) {
status, ok := srv.StreamStatus(peerID)
require.True(r, ok)
require.Equal(r, expect, status)
})
})
select {
case err := <-errCh:
// Client disconnect is not an error, but should make the handler return.
require.NoError(t, err)
case <-time.After(50 * time.Millisecond):
t.Fatalf("timed out waiting for handler to finish")
}
}
func TestStreamResources_Server_ServiceUpdates(t *testing.T) {
publisher := stream.NewEventPublisher(10 * time.Second)
store := newStateStore(t, publisher)
// Create a peering
var lastIdx uint64 = 1
err := store.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "my-peering",
})
require.NoError(t, err)
_, p, err := store.PeeringRead(nil, state.Query{Value: "my-peering"})
require.NoError(t, err)
require.NotNil(t, p)
srv := NewService(testutil.Logger(t), &testStreamBackend{
store: store,
pub: publisher,
})
client := newMockClient(context.Background())
errCh := make(chan error, 1)
client.errCh = errCh
go func() {
// Pass errors from server handler into errCh so that they can be seen by the client on Recv().
// This matches gRPC's behavior when an error is returned by a server.
if err := srv.StreamResources(client.replicationStream); err != nil {
errCh <- err
}
}()
// Issue a services subscription to server
init := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
PeerID: p.ID,
ResourceURL: pbpeering.TypeURLService,
},
},
}
require.NoError(t, client.Send(init))
// Receive a services subscription from server
receivedSub, err := client.Recv()
require.NoError(t, err)
expect := &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
PeerID: p.ID,
},
},
}
prototest.AssertDeepEqual(t, expect, receivedSub)
// Register a service that is not yet exported
mysql := &structs.CheckServiceNode{
Node: &structs.Node{Node: "foo", Address: "10.0.0.1"},
Service: &structs.NodeService{ID: "mysql-1", Service: "mysql", Port: 5000},
}
lastIdx++
require.NoError(t, store.EnsureNode(lastIdx, mysql.Node))
lastIdx++
require.NoError(t, store.EnsureService(lastIdx, "foo", mysql.Service))
runStep(t, "exporting mysql leads to an UPSERT event", func(t *testing.T) {
entry := &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "mysql",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
{
// Mongo does not get pushed because it does not have instances registered.
Name: "mongo",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
},
}
lastIdx++
err = store.EnsureConfigEntry(lastIdx, entry)
require.NoError(t, err)
retry.Run(t, func(r *retry.R) {
msg, err := client.RecvWithTimeout(100 * time.Millisecond)
require.NoError(r, err)
require.Equal(r, pbpeering.ReplicationMessage_Response_UPSERT, msg.GetResponse().Operation)
require.Equal(r, mysql.Service.CompoundServiceName().String(), msg.GetResponse().ResourceID)
var nodes pbservice.IndexedCheckServiceNodes
require.NoError(r, ptypes.UnmarshalAny(msg.GetResponse().Resource, &nodes))
require.Len(r, nodes.Nodes, 1)
})
})
mongo := &structs.CheckServiceNode{
Node: &structs.Node{Node: "zip", Address: "10.0.0.3"},
Service: &structs.NodeService{ID: "mongo-1", Service: "mongo", Port: 5000},
}
runStep(t, "registering mongo instance leads to an UPSERT event", func(t *testing.T) {
lastIdx++
require.NoError(t, store.EnsureNode(lastIdx, mongo.Node))
lastIdx++
require.NoError(t, store.EnsureService(lastIdx, "zip", mongo.Service))
retry.Run(t, func(r *retry.R) {
msg, err := client.RecvWithTimeout(100 * time.Millisecond)
require.NoError(r, err)
require.Equal(r, pbpeering.ReplicationMessage_Response_UPSERT, msg.GetResponse().Operation)
require.Equal(r, mongo.Service.CompoundServiceName().String(), msg.GetResponse().ResourceID)
var nodes pbservice.IndexedCheckServiceNodes
require.NoError(r, ptypes.UnmarshalAny(msg.GetResponse().Resource, &nodes))
require.Len(r, nodes.Nodes, 1)
})
})
runStep(t, "un-exporting mysql leads to a DELETE event for mysql", func(t *testing.T) {
entry := &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "mongo",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
},
}
lastIdx++
err = store.EnsureConfigEntry(lastIdx, entry)
require.NoError(t, err)
retry.Run(t, func(r *retry.R) {
msg, err := client.RecvWithTimeout(100 * time.Millisecond)
require.NoError(r, err)
require.Equal(r, pbpeering.ReplicationMessage_Response_DELETE, msg.GetResponse().Operation)
require.Equal(r, mysql.Service.CompoundServiceName().String(), msg.GetResponse().ResourceID)
require.Nil(r, msg.GetResponse().Resource)
})
})
runStep(t, "deleting the config entry leads to a DELETE event for mongo", func(t *testing.T) {
lastIdx++
err = store.DeleteConfigEntry(lastIdx, structs.ExportedServices, "default", nil)
require.NoError(t, err)
retry.Run(t, func(r *retry.R) {
msg, err := client.RecvWithTimeout(100 * time.Millisecond)
require.NoError(r, err)
require.Equal(r, pbpeering.ReplicationMessage_Response_DELETE, msg.GetResponse().Operation)
require.Equal(r, mongo.Service.CompoundServiceName().String(), msg.GetResponse().ResourceID)
require.Nil(r, msg.GetResponse().Resource)
})
})
}
type testStreamBackend struct {
pub state.EventPublisher
store *state.Store
}
func (b *testStreamBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) {
return b.pub.Subscribe(req)
}
func (b *testStreamBackend) Store() Store {
return b.store
}
func (b *testStreamBackend) Forward(info structs.RPCInfo, f func(conn *grpc.ClientConn) error) (handled bool, err error) {
return true, nil
}
func (b *testStreamBackend) GetAgentCACertificates() ([]string, error) {
return []string{}, nil
}
func (b *testStreamBackend) GetServerAddresses() ([]string, error) {
return []string{}, nil
}
func (b *testStreamBackend) GetServerName() string {
return ""
}
func (b *testStreamBackend) EncodeToken(tok *structs.PeeringToken) ([]byte, error) {
return nil, nil
}
func (b *testStreamBackend) DecodeToken([]byte) (*structs.PeeringToken, error) {
return nil, nil
}
func (b *testStreamBackend) EnterpriseCheckPartitions(partition string) error {
return nil
}
func (b *testStreamBackend) Apply() Apply {
return nil
}
func Test_processResponse(t *testing.T) {
type testCase struct {
name string
in *pbpeering.ReplicationMessage_Response
expect *pbpeering.ReplicationMessage
wantErr bool
}
run := func(t *testing.T, tc testCase) {
reply, err := processResponse(tc.in)
if tc.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tc.expect, reply)
}
tt := []testCase{
{
name: "valid upsert",
in: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
ResourceID: "api",
Nonce: "1",
Operation: pbpeering.ReplicationMessage_Response_UPSERT,
},
expect: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
Nonce: "1",
},
},
},
wantErr: false,
},
{
name: "valid delete",
in: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
ResourceID: "api",
Nonce: "1",
Operation: pbpeering.ReplicationMessage_Response_DELETE,
},
expect: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
Nonce: "1",
},
},
},
wantErr: false,
},
{
name: "invalid resource url",
in: &pbpeering.ReplicationMessage_Response{
ResourceURL: "nomad.Job",
Nonce: "1",
Operation: pbpeering.ReplicationMessage_Response_Unknown,
},
expect: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: "nomad.Job",
Nonce: "1",
Error: &pbstatus.Status{
Code: int32(code.Code_INVALID_ARGUMENT),
Message: `received response for unknown resource type "nomad.Job"`,
},
},
},
},
wantErr: true,
},
{
name: "unknown operation",
in: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
Nonce: "1",
Operation: pbpeering.ReplicationMessage_Response_Unknown,
},
expect: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
Nonce: "1",
Error: &pbstatus.Status{
Code: int32(code.Code_INVALID_ARGUMENT),
Message: `unsupported operation: "Unknown"`,
},
},
},
},
wantErr: true,
},
{
name: "out of range operation",
in: &pbpeering.ReplicationMessage_Response{
ResourceURL: pbpeering.TypeURLService,
Nonce: "1",
Operation: pbpeering.ReplicationMessage_Response_Operation(100000),
},
expect: &pbpeering.ReplicationMessage{
Payload: &pbpeering.ReplicationMessage_Request_{
Request: &pbpeering.ReplicationMessage_Request{
ResourceURL: pbpeering.TypeURLService,
Nonce: "1",
Error: &pbstatus.Status{
Code: int32(code.Code_INVALID_ARGUMENT),
Message: `unsupported operation: "100000"`,
},
},
},
},
wantErr: true,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}

View File

@ -0,0 +1,212 @@
package peering
import (
"fmt"
"sync"
"time"
)
// streamTracker contains a map of (PeerID -> StreamStatus).
// As streams are opened and closed we track details about their status.
type streamTracker struct {
mu sync.RWMutex
streams map[string]*lockableStreamStatus
// timeNow is a shim for testing.
timeNow func() time.Time
}
func newStreamTracker() *streamTracker {
return &streamTracker{
streams: make(map[string]*lockableStreamStatus),
timeNow: time.Now,
}
}
// connected registers a stream for a given peer, and marks it as connected.
// It also enforces that there is only one active stream for a peer.
func (t *streamTracker) connected(id string) (*lockableStreamStatus, error) {
t.mu.Lock()
defer t.mu.Unlock()
status, ok := t.streams[id]
if !ok {
status = newLockableStreamStatus(t.timeNow)
t.streams[id] = status
return status, nil
}
if status.connected() {
return nil, fmt.Errorf("there is an active stream for the given PeerID %q", id)
}
status.trackConnected()
return status, nil
}
// disconnected ensures that if a peer id's stream status is tracked, it is marked as disconnected.
func (t *streamTracker) disconnected(id string) {
t.mu.Lock()
defer t.mu.Unlock()
if status, ok := t.streams[id]; ok {
status.trackDisconnected()
}
}
func (t *streamTracker) streamStatus(id string) (resp StreamStatus, found bool) {
t.mu.RLock()
defer t.mu.RUnlock()
s, ok := t.streams[id]
if !ok {
return StreamStatus{}, false
}
return s.status(), true
}
func (t *streamTracker) connectedStreams() map[string]chan struct{} {
t.mu.RLock()
defer t.mu.RUnlock()
resp := make(map[string]chan struct{})
for peer, status := range t.streams {
if status.connected() {
resp[peer] = status.doneCh
}
}
return resp
}
func (t *streamTracker) deleteStatus(id string) {
t.mu.Lock()
defer t.mu.Unlock()
delete(t.streams, id)
}
type lockableStreamStatus struct {
mu sync.RWMutex
// timeNow is a shim for testing.
timeNow func() time.Time
// doneCh allows for shutting down a stream gracefully by sending a termination message
// to the peer before the stream's context is cancelled.
doneCh chan struct{}
StreamStatus
}
// StreamStatus contains information about the replication stream to a peer cluster.
// TODO(peering): There's a lot of fields here...
type StreamStatus struct {
// Connected is true when there is an open stream for the peer.
Connected bool
// If the status is not connected, DisconnectTime tracks when the stream was closed. Else it's zero.
DisconnectTime time.Time
// LastAck tracks the time we received the last ACK for a resource replicated TO the peer.
LastAck time.Time
// LastNack tracks the time we received the last NACK for a resource replicated to the peer.
LastNack time.Time
// LastNackMessage tracks the reported error message associated with the last NACK from a peer.
LastNackMessage string
// LastSendError tracks the time of the last error sending into the stream.
LastSendError time.Time
// LastSendErrorMessage tracks the last error message when sending into the stream.
LastSendErrorMessage string
// LastReceiveSuccess tracks the time we last successfully stored a resource replicated FROM the peer.
LastReceiveSuccess time.Time
// LastReceiveError tracks either:
// - The time we failed to store a resource replicated FROM the peer.
// - The time of the last error when receiving from the stream.
LastReceiveError time.Time
// LastReceiveError tracks either:
// - The error message when we failed to store a resource replicated FROM the peer.
// - The last error message when receiving from the stream.
LastReceiveErrorMessage string
}
func newLockableStreamStatus(now func() time.Time) *lockableStreamStatus {
return &lockableStreamStatus{
StreamStatus: StreamStatus{
Connected: true,
},
timeNow: now,
doneCh: make(chan struct{}),
}
}
func (s *lockableStreamStatus) trackAck() {
s.mu.Lock()
s.LastAck = s.timeNow().UTC()
s.mu.Unlock()
}
func (s *lockableStreamStatus) trackSendError(error string) {
s.mu.Lock()
s.LastSendError = s.timeNow().UTC()
s.LastSendErrorMessage = error
s.mu.Unlock()
}
func (s *lockableStreamStatus) trackReceiveSuccess() {
s.mu.Lock()
s.LastReceiveSuccess = s.timeNow().UTC()
s.mu.Unlock()
}
func (s *lockableStreamStatus) trackReceiveError(error string) {
s.mu.Lock()
s.LastReceiveError = s.timeNow().UTC()
s.LastReceiveErrorMessage = error
s.mu.Unlock()
}
func (s *lockableStreamStatus) trackNack(msg string) {
s.mu.Lock()
s.LastNack = s.timeNow().UTC()
s.LastNackMessage = msg
s.mu.Unlock()
}
func (s *lockableStreamStatus) trackConnected() {
s.mu.Lock()
s.Connected = true
s.DisconnectTime = time.Time{}
s.mu.Unlock()
}
func (s *lockableStreamStatus) trackDisconnected() {
s.mu.Lock()
s.Connected = false
s.DisconnectTime = s.timeNow().UTC()
s.mu.Unlock()
}
func (s *lockableStreamStatus) connected() bool {
var resp bool
s.mu.RLock()
resp = s.Connected
s.mu.RUnlock()
return resp
}
func (s *lockableStreamStatus) status() StreamStatus {
s.mu.RLock()
copy := s.StreamStatus
s.mu.RUnlock()
return copy
}

View File

@ -0,0 +1,162 @@
package peering
import (
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestStreamTracker_EnsureConnectedDisconnected(t *testing.T) {
tracker := newStreamTracker()
peerID := "63b60245-c475-426b-b314-4588d210859d"
it := incrementalTime{
base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
}
tracker.timeNow = it.Now
var (
statusPtr *lockableStreamStatus
err error
)
runStep(t, "new stream", func(t *testing.T) {
statusPtr, err = tracker.connected(peerID)
require.NoError(t, err)
expect := StreamStatus{
Connected: true,
}
status, ok := tracker.streamStatus(peerID)
require.True(t, ok)
require.Equal(t, expect, status)
})
runStep(t, "duplicate gets rejected", func(t *testing.T) {
_, err := tracker.connected(peerID)
require.Error(t, err)
require.Contains(t, err.Error(), `there is an active stream for the given PeerID "63b60245-c475-426b-b314-4588d210859d"`)
})
var sequence uint64
var lastSuccess time.Time
runStep(t, "stream updated", func(t *testing.T) {
statusPtr.trackAck()
sequence++
status, ok := tracker.streamStatus(peerID)
require.True(t, ok)
lastSuccess = it.base.Add(time.Duration(sequence) * time.Second).UTC()
expect := StreamStatus{
Connected: true,
LastAck: lastSuccess,
}
require.Equal(t, expect, status)
})
runStep(t, "disconnect", func(t *testing.T) {
tracker.disconnected(peerID)
sequence++
expect := StreamStatus{
Connected: false,
DisconnectTime: it.base.Add(time.Duration(sequence) * time.Second).UTC(),
LastAck: lastSuccess,
}
status, ok := tracker.streamStatus(peerID)
require.True(t, ok)
require.Equal(t, expect, status)
})
runStep(t, "re-connect", func(t *testing.T) {
_, err := tracker.connected(peerID)
require.NoError(t, err)
expect := StreamStatus{
Connected: true,
LastAck: lastSuccess,
// DisconnectTime gets cleared on re-connect.
}
status, ok := tracker.streamStatus(peerID)
require.True(t, ok)
require.Equal(t, expect, status)
})
runStep(t, "delete", func(t *testing.T) {
tracker.deleteStatus(peerID)
status, ok := tracker.streamStatus(peerID)
require.False(t, ok)
require.Zero(t, status)
})
}
func TestStreamTracker_connectedStreams(t *testing.T) {
type testCase struct {
name string
setup func(t *testing.T, s *streamTracker)
expect []string
}
run := func(t *testing.T, tc testCase) {
tracker := newStreamTracker()
if tc.setup != nil {
tc.setup(t, tracker)
}
streams := tracker.connectedStreams()
var keys []string
for key := range streams {
keys = append(keys, key)
}
sort.Strings(keys)
require.Equal(t, tc.expect, keys)
}
tt := []testCase{
{
name: "no streams",
expect: nil,
},
{
name: "all streams active",
setup: func(t *testing.T, s *streamTracker) {
_, err := s.connected("foo")
require.NoError(t, err)
_, err = s.connected("bar")
require.NoError(t, err)
},
expect: []string{"bar", "foo"},
},
{
name: "mixed active and inactive",
setup: func(t *testing.T, s *streamTracker) {
status, err := s.connected("foo")
require.NoError(t, err)
// Mark foo as disconnected to avoid showing it as an active stream
status.trackDisconnected()
_, err = s.connected("bar")
require.NoError(t, err)
},
expect: []string{"bar"},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}

View File

@ -0,0 +1,149 @@
package peering
import (
"context"
"errors"
"fmt"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/submatview"
"github.com/hashicorp/consul/lib/retry"
"github.com/hashicorp/consul/proto/pbservice"
)
type MaterializedViewStore interface {
Get(ctx context.Context, req submatview.Request) (submatview.Result, error)
Notify(ctx context.Context, req submatview.Request, cID string, ch chan<- cache.UpdateEvent) error
}
type SubscriptionBackend interface {
Subscriber
Store() Store
}
// subscriptionManager handlers requests to subscribe to events from an events publisher.
type subscriptionManager struct {
logger hclog.Logger
viewStore MaterializedViewStore
backend SubscriptionBackend
// watchedServices is a map of exported services to a cancel function for their subscription notifier.
watchedServices map[structs.ServiceName]context.CancelFunc
}
// TODO(peering): Maybe centralize so that there is a single manager per datacenter, rather than per peering.
func newSubscriptionManager(ctx context.Context, logger hclog.Logger, backend SubscriptionBackend) *subscriptionManager {
logger = logger.Named("subscriptions")
store := submatview.NewStore(logger.Named("viewstore"))
go store.Run(ctx)
return &subscriptionManager{
logger: logger,
viewStore: store,
backend: backend,
watchedServices: make(map[structs.ServiceName]context.CancelFunc),
}
}
// subscribe returns a channel that will contain updates to exported service instances for a given peer.
func (m *subscriptionManager) subscribe(ctx context.Context, peerID string) <-chan cache.UpdateEvent {
updateCh := make(chan cache.UpdateEvent, 1)
go m.syncSubscriptions(ctx, peerID, updateCh)
return updateCh
}
func (m *subscriptionManager) syncSubscriptions(ctx context.Context, peerID string, updateCh chan<- cache.UpdateEvent) {
waiter := &retry.Waiter{
MinFailures: 1,
Factor: 500 * time.Millisecond,
MaxWait: 60 * time.Second,
Jitter: retry.NewJitter(100),
}
for {
if err := m.syncSubscriptionsAndBlock(ctx, peerID, updateCh); err != nil {
m.logger.Error("failed to sync subscriptions", "error", err)
}
if err := waiter.Wait(ctx); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
m.logger.Error("failed to wait before re-trying sync", "error", err)
}
select {
case <-ctx.Done():
return
default:
}
}
}
// syncSubscriptionsAndBlock ensures that the subscriptions to the subscription backend
// match the list of services exported to the peer.
func (m *subscriptionManager) syncSubscriptionsAndBlock(ctx context.Context, peerID string, updateCh chan<- cache.UpdateEvent) error {
store := m.backend.Store()
ws := memdb.NewWatchSet()
ws.Add(store.AbandonCh())
ws.Add(ctx.Done())
// Get exported services for peer id
_, services, err := store.ExportedServicesForPeer(ws, peerID)
if err != nil {
return fmt.Errorf("failed to watch exported services for peer %q: %w", peerID, err)
}
// seen contains the set of exported service names and is used to reconcile the list of watched services.
seen := make(map[structs.ServiceName]struct{})
// Ensure there is a subscription for each service exported to the peer.
for _, svc := range services {
seen[svc] = struct{}{}
if _, ok := m.watchedServices[svc]; ok {
// Exported service is already being watched, nothing to do.
continue
}
notifyCtx, cancel := context.WithCancel(ctx)
m.watchedServices[svc] = cancel
if err := m.Notify(notifyCtx, svc, updateCh); err != nil {
m.logger.Error("failed to subscribe to service", "service", svc.String())
continue
}
}
// For every subscription without an exported service, call the associated cancel fn.
for svc, cancel := range m.watchedServices {
if _, ok := seen[svc]; !ok {
cancel()
// Send an empty event to the stream handler to trigger sending a DELETE message.
// Cancelling the subscription context above is necessary, but does not yield a useful signal on its own.
updateCh <- cache.UpdateEvent{
CorrelationID: subExportedService + svc.String(),
Result: &pbservice.IndexedCheckServiceNodes{},
}
}
}
// Block for any changes to the state store.
ws.WatchCh(ctx)
return nil
}
const (
subExportedService = "exported-service:"
)
// Notify the given channel when there are updates to the requested service.
func (m *subscriptionManager) Notify(ctx context.Context, svc structs.ServiceName, updateCh chan<- cache.UpdateEvent) error {
sr := newExportedServiceRequest(m.logger, svc, m.backend)
return m.viewStore.Notify(ctx, sr, subExportedService+svc.String(), updateCh)
}

View File

@ -0,0 +1,362 @@
package peering
import (
"context"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/sdk/testutil/retry"
)
type testSubscriptionBackend struct {
state.EventPublisher
store *state.Store
}
func (b *testSubscriptionBackend) Store() Store {
return b.store
}
func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
publisher := stream.NewEventPublisher(10 * time.Second)
store := newStateStore(t, publisher)
backend := testSubscriptionBackend{
EventPublisher: publisher,
store: store,
}
ctx := context.Background()
mgr := newSubscriptionManager(ctx, hclog.New(nil), &backend)
// Create a peering
var lastIdx uint64 = 1
err := store.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "my-peering",
})
require.NoError(t, err)
_, p, err := store.PeeringRead(nil, state.Query{Value: "my-peering"})
require.NoError(t, err)
require.NotNil(t, p)
id := p.ID
subCh := mgr.subscribe(ctx, id)
entry := &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "mysql",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
{
Name: "mongo",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-other-peering",
},
},
},
},
}
lastIdx++
err = store.EnsureConfigEntry(lastIdx, entry)
require.NoError(t, err)
mysql1 := &structs.CheckServiceNode{
Node: &structs.Node{Node: "foo", Address: "10.0.0.1"},
Service: &structs.NodeService{ID: "mysql-1", Service: "mysql", Port: 5000},
Checks: structs.HealthChecks{
&structs.HealthCheck{CheckID: "mysql-check", ServiceID: "mysql-1", Node: "foo"},
},
}
runStep(t, "registering exported service instance yields update", func(t *testing.T) {
lastIdx++
require.NoError(t, store.EnsureNode(lastIdx, mysql1.Node))
lastIdx++
require.NoError(t, store.EnsureService(lastIdx, "foo", mysql1.Service))
lastIdx++
require.NoError(t, store.EnsureCheck(lastIdx, mysql1.Checks[0]))
// Receive in a retry loop so that eventually we converge onto the expected CheckServiceNode.
retry.Run(t, func(r *retry.R) {
select {
case update := <-subCh:
nodes, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
require.True(r, ok)
require.Equal(r, uint64(5), nodes.Index)
require.Len(r, nodes.Nodes, 1)
require.Equal(r, "foo", nodes.Nodes[0].Node.Node)
require.Equal(r, "mysql-1", nodes.Nodes[0].Service.ID)
require.Len(r, nodes.Nodes[0].Checks, 1)
require.Equal(r, "mysql-check", nodes.Nodes[0].Checks[0].CheckID)
default:
r.Fatalf("invalid update")
}
})
})
mysql2 := &structs.CheckServiceNode{
Node: &structs.Node{Node: "bar", Address: "10.0.0.2"},
Service: &structs.NodeService{ID: "mysql-2", Service: "mysql", Port: 5000},
Checks: structs.HealthChecks{
&structs.HealthCheck{CheckID: "mysql-2-check", ServiceID: "mysql-2", Node: "bar"},
},
}
runStep(t, "additional instances are returned when registered", func(t *testing.T) {
lastIdx++
require.NoError(t, store.EnsureNode(lastIdx, mysql2.Node))
lastIdx++
require.NoError(t, store.EnsureService(lastIdx, "bar", mysql2.Service))
lastIdx++
require.NoError(t, store.EnsureCheck(lastIdx, mysql2.Checks[0]))
retry.Run(t, func(r *retry.R) {
select {
case update := <-subCh:
nodes, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
require.True(r, ok)
require.Equal(r, uint64(8), nodes.Index)
require.Len(r, nodes.Nodes, 2)
require.Equal(r, "bar", nodes.Nodes[0].Node.Node)
require.Equal(r, "mysql-2", nodes.Nodes[0].Service.ID)
require.Len(r, nodes.Nodes[0].Checks, 1)
require.Equal(r, "mysql-2-check", nodes.Nodes[0].Checks[0].CheckID)
require.Equal(r, "foo", nodes.Nodes[1].Node.Node)
require.Equal(r, "mysql-1", nodes.Nodes[1].Service.ID)
require.Len(r, nodes.Nodes[1].Checks, 1)
require.Equal(r, "mysql-check", nodes.Nodes[1].Checks[0].CheckID)
default:
r.Fatalf("invalid update")
}
})
})
runStep(t, "no updates are received for services not exported to my-peering", func(t *testing.T) {
mongo := &structs.CheckServiceNode{
Node: &structs.Node{Node: "zip", Address: "10.0.0.3"},
Service: &structs.NodeService{ID: "mongo", Service: "mongo", Port: 5000},
Checks: structs.HealthChecks{
&structs.HealthCheck{CheckID: "mongo-check", ServiceID: "mongo", Node: "zip"},
},
}
lastIdx++
require.NoError(t, store.EnsureNode(lastIdx, mongo.Node))
lastIdx++
require.NoError(t, store.EnsureService(lastIdx, "zip", mongo.Service))
lastIdx++
require.NoError(t, store.EnsureCheck(lastIdx, mongo.Checks[0]))
// Receive from subCh times out. The retry in the last step already consumed all the mysql events.
select {
case update := <-subCh:
nodes, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
if ok && len(nodes.Nodes) > 0 && nodes.Nodes[0].Node.Node == "zip" {
t.Fatalf("received update for mongo node zip")
}
case <-time.After(100 * time.Millisecond):
// Expect this to fire
}
})
runStep(t, "deregister an instance and it gets removed from the output", func(t *testing.T) {
lastIdx++
require.NoError(t, store.DeleteService(lastIdx, "foo", mysql1.Service.ID, nil, ""))
select {
case update := <-subCh:
nodes, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
require.True(t, ok)
require.Equal(t, uint64(12), nodes.Index)
require.Len(t, nodes.Nodes, 1)
require.Equal(t, "bar", nodes.Nodes[0].Node.Node)
require.Equal(t, "mysql-2", nodes.Nodes[0].Service.ID)
require.Len(t, nodes.Nodes[0].Checks, 1)
require.Equal(t, "mysql-2-check", nodes.Nodes[0].Checks[0].CheckID)
case <-time.After(100 * time.Millisecond):
t.Fatalf("timed out waiting for update")
}
})
runStep(t, "deregister the last instance and the output is empty", func(t *testing.T) {
lastIdx++
require.NoError(t, store.DeleteService(lastIdx, "bar", mysql2.Service.ID, nil, ""))
select {
case update := <-subCh:
nodes, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
require.True(t, ok)
require.Equal(t, uint64(13), nodes.Index)
require.Len(t, nodes.Nodes, 0)
case <-time.After(100 * time.Millisecond):
t.Fatalf("timed out waiting for update")
}
})
}
func TestSubscriptionManager_InitialSnapshot(t *testing.T) {
publisher := stream.NewEventPublisher(10 * time.Second)
store := newStateStore(t, publisher)
backend := testSubscriptionBackend{
EventPublisher: publisher,
store: store,
}
ctx := context.Background()
mgr := newSubscriptionManager(ctx, hclog.New(nil), &backend)
// Create a peering
var lastIdx uint64 = 1
err := store.PeeringWrite(lastIdx, &pbpeering.Peering{
Name: "my-peering",
})
require.NoError(t, err)
_, p, err := store.PeeringRead(nil, state.Query{Value: "my-peering"})
require.NoError(t, err)
require.NotNil(t, p)
id := p.ID
subCh := mgr.subscribe(ctx, id)
// Register two services that are not yet exported
mysql := &structs.CheckServiceNode{
Node: &structs.Node{Node: "foo", Address: "10.0.0.1"},
Service: &structs.NodeService{ID: "mysql-1", Service: "mysql", Port: 5000},
}
lastIdx++
require.NoError(t, store.EnsureNode(lastIdx, mysql.Node))
lastIdx++
require.NoError(t, store.EnsureService(lastIdx, "foo", mysql.Service))
mongo := &structs.CheckServiceNode{
Node: &structs.Node{Node: "zip", Address: "10.0.0.3"},
Service: &structs.NodeService{ID: "mongo-1", Service: "mongo", Port: 5000},
}
lastIdx++
require.NoError(t, store.EnsureNode(lastIdx, mongo.Node))
lastIdx++
require.NoError(t, store.EnsureService(lastIdx, "zip", mongo.Service))
// No updates should be received, because neither service is exported.
select {
case update := <-subCh:
nodes, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
if ok && len(nodes.Nodes) > 0 {
t.Fatalf("received unexpected update")
}
case <-time.After(100 * time.Millisecond):
// Expect this to fire
}
runStep(t, "exporting the two services yields an update for both", func(t *testing.T) {
entry := &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "mysql",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
{
Name: "mongo",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peering",
},
},
},
},
}
lastIdx++
err = store.EnsureConfigEntry(lastIdx, entry)
require.NoError(t, err)
var (
sawMySQL bool
sawMongo bool
)
retry.Run(t, func(r *retry.R) {
select {
case update := <-subCh:
nodes, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
require.True(r, ok)
require.Len(r, nodes.Nodes, 1)
switch nodes.Nodes[0].Service.Service {
case "mongo":
sawMongo = true
case "mysql":
sawMySQL = true
}
if !sawMySQL || !sawMongo {
r.Fatalf("missing an update")
}
default:
r.Fatalf("invalid update")
}
})
})
}
func newStateStore(t *testing.T, publisher *stream.EventPublisher) *state.Store {
gc, err := state.NewTombstoneGC(time.Second, time.Millisecond)
require.NoError(t, err)
store := state.NewStateStoreWithEventPublisher(gc, publisher)
require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealth, store.ServiceHealthSnapshot))
require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealthConnect, store.ServiceHealthSnapshot))
go publisher.Run(context.Background())
return store
}

View File

@ -0,0 +1,141 @@
package peering
import (
"fmt"
"sort"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/submatview"
"github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/pbsubscribe"
)
type Subscriber interface {
Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error)
}
type exportedServiceRequest struct {
logger hclog.Logger
req structs.ServiceSpecificRequest
sub Subscriber
}
func newExportedServiceRequest(logger hclog.Logger, svc structs.ServiceName, sub Subscriber) *exportedServiceRequest {
req := structs.ServiceSpecificRequest{
// TODO(peering): Need to subscribe to both Connect and not
Connect: false,
ServiceName: svc.Name,
EnterpriseMeta: svc.EnterpriseMeta,
}
return &exportedServiceRequest{
logger: logger,
req: req,
sub: sub,
}
}
// CacheInfo implements submatview.Request
func (e *exportedServiceRequest) CacheInfo() cache.RequestInfo {
return e.req.CacheInfo()
}
// NewMaterializer implements submatview.Request
func (e *exportedServiceRequest) NewMaterializer() (submatview.Materializer, error) {
reqFn := func(index uint64) *pbsubscribe.SubscribeRequest {
r := &pbsubscribe.SubscribeRequest{
Topic: pbsubscribe.Topic_ServiceHealth,
Key: e.req.ServiceName,
Token: e.req.Token,
Datacenter: e.req.Datacenter,
Index: index,
Namespace: e.req.EnterpriseMeta.NamespaceOrEmpty(),
Partition: e.req.EnterpriseMeta.PartitionOrEmpty(),
}
if e.req.Connect {
r.Topic = pbsubscribe.Topic_ServiceHealthConnect
}
return r
}
deps := submatview.Deps{
View: newExportedServicesView(),
Logger: e.logger,
Request: reqFn,
}
return submatview.NewLocalMaterializer(e.sub, deps), nil
}
// Type implements submatview.Request
func (e *exportedServiceRequest) Type() string {
return "leader.peering.stream.exportedServiceRequest"
}
// exportedServicesView implements submatview.View for storing the view state
// of an exported service's health result. We store it as a map to make updates and
// deletions a little easier but we could just store a result type
// (IndexedCheckServiceNodes) and update it in place for each event - that
// involves re-sorting each time etc. though.
//
// Unlike rpcclient.healthView, there is no need for a filter because for exported services
// we export all instances unconditionally.
type exportedServicesView struct {
state map[string]*pbservice.CheckServiceNode
}
func newExportedServicesView() *exportedServicesView {
return &exportedServicesView{
state: make(map[string]*pbservice.CheckServiceNode),
}
}
// Reset implements submatview.View
func (s *exportedServicesView) Reset() {
s.state = make(map[string]*pbservice.CheckServiceNode)
}
// Update implements submatview.View
func (s *exportedServicesView) Update(events []*pbsubscribe.Event) error {
for _, event := range events {
serviceHealth := event.GetServiceHealth()
if serviceHealth == nil {
return fmt.Errorf("unexpected event type for service health view: %T",
event.GetPayload())
}
id := serviceHealth.CheckServiceNode.UniqueID()
switch serviceHealth.Op {
case pbsubscribe.CatalogOp_Register:
s.state[id] = serviceHealth.CheckServiceNode
case pbsubscribe.CatalogOp_Deregister:
delete(s.state, id)
}
}
return nil
}
// Result returns the CheckServiceNodes stored by this view.
// Result implements submatview.View
func (s *exportedServicesView) Result(index uint64) interface{} {
result := pbservice.IndexedCheckServiceNodes{
Nodes: make([]*pbservice.CheckServiceNode, 0, len(s.state)),
Index: index,
}
for _, node := range s.state {
result.Nodes = append(result.Nodes, node)
}
sortCheckServiceNodes(&result)
return &result
}
// sortCheckServiceNodes stable sorts the results to match memdb semantics.
func sortCheckServiceNodes(n *pbservice.IndexedCheckServiceNodes) {
sort.SliceStable(n.Nodes, func(i, j int) bool {
return n.Nodes[i].UniqueID() < n.Nodes[j].UniqueID()
})
}

View File

@ -0,0 +1,338 @@
package peering
import (
"context"
"math/rand"
"sort"
"sync"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/submatview"
"github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/pbsubscribe"
)
// TestExportedServiceSubscription tests the exported services view and the backing submatview.LocalMaterializer.
func TestExportedServiceSubscription(t *testing.T) {
s := &stateMap{
states: make(map[string]*serviceState),
}
sh := snapshotHandler{stateMap: s}
pub := stream.NewEventPublisher(10 * time.Millisecond)
pub.RegisterHandler(pbsubscribe.Topic_ServiceHealth, sh.Snapshot)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go pub.Run(ctx)
apiSN := structs.NewServiceName("api", nil)
webSN := structs.NewServiceName("web", nil)
// List of updates to the state store:
// - api: {register api-1, register api-2, register api-3}
// - web: {register web-1, deregister web-1, register web-2}1
events := []map[string]stream.Event{
{
apiSN.String(): stream.Event{
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register,
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
ID: "api-1",
Service: "api",
},
},
},
},
webSN.String(): stream.Event{
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register,
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
ID: "web-1",
Service: "web",
},
},
},
},
},
{
apiSN.String(): stream.Event{
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register,
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
ID: "api-2",
Service: "api",
},
},
},
},
webSN.String(): stream.Event{
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Deregister,
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
ID: "web-1",
Service: "web",
},
},
},
},
},
{
apiSN.String(): stream.Event{
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register,
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
ID: "api-3",
Service: "api",
},
},
},
},
webSN.String(): stream.Event{
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register,
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
ID: "web-2",
Service: "web",
},
},
},
},
},
}
// store represents Consul's memdb state store.
// A stream of event updates
store := store{stateMap: s, pub: pub}
// This errgroup is used to issue simulate async updates to the state store,
// and also consume that fixed number of updates.
group, gctx := errgroup.WithContext(ctx)
group.Go(func() error {
store.simulateUpdates(gctx, events)
return nil
})
// viewStore is the store shared by the two service consumer's materializers.
// It is intentionally not run in the errgroup because it will block until the context is canceled.
viewStore := submatview.NewStore(hclog.New(nil))
go viewStore.Run(ctx)
// Each consumer represents a subscriber to exported service updates, and will consume
// stream events for the service name it is interested in.
consumers := make(map[string]*consumer)
for _, svc := range []structs.ServiceName{apiSN, webSN} {
c := &consumer{
viewStore: viewStore,
publisher: pub,
seenByIndex: make(map[uint64][]string),
}
service := svc
group.Go(func() error {
return c.consume(gctx, service.Name, len(events))
})
consumers[service.String()] = c
}
// Wait until all the events have been simulated and consumed.
done := make(chan struct{})
go func() {
defer close(done)
_ = group.Wait()
}()
select {
case <-done:
// finished
case <-time.After(500 * time.Millisecond):
// timed out, the Wait context will be cancelled by
t.Fatalf("timed out waiting for producers and consumers")
}
for svc, c := range consumers {
require.NotEmpty(t, c.seenByIndex)
// Note that store.states[svc].idsByIndex does not assert against a slice of expectations because
// the index that the different events will arrive in the simulation is not deterministic.
require.Equal(t, store.states[svc].idsByIndex, c.seenByIndex)
}
}
// stateMap is a map keyed by service to the state of the store at different indexes
type stateMap struct {
mu sync.Mutex
states map[string]*serviceState
}
type store struct {
*stateMap
pub *stream.EventPublisher
}
// simulateUpdates will publish events and also store the state at each index for later assertions.
func (s *store) simulateUpdates(ctx context.Context, events []map[string]stream.Event) {
idx := uint64(0)
for _, m := range events {
if ctx.Err() != nil {
return
}
for svc, event := range m {
idx++
event.Index = idx
s.pub.Publish([]stream.Event{event})
s.stateMap.mu.Lock()
svcState, ok := s.states[svc]
if !ok {
svcState = &serviceState{
current: make(map[string]*structs.CheckServiceNode),
idsByIndex: make(map[uint64][]string),
}
s.states[svc] = svcState
}
s.stateMap.mu.Unlock()
svcState.mu.Lock()
svcState.idx = idx
// Updating the svcState.current map allows us to capture snapshots from a stream of add/delete events.
payload := event.Payload.(state.EventPayloadCheckServiceNode)
switch payload.Op {
case pbsubscribe.CatalogOp_Register:
svcState.current[payload.Value.Service.ID] = payload.Value
default:
// If not a registration it must be a deregistration:
delete(svcState.current, payload.Value.Service.ID)
}
svcState.idsByIndex[idx] = serviceIDsFromMap(svcState.current)
svcState.mu.Unlock()
delay := time.Duration(rand.Intn(25)) * time.Millisecond
time.Sleep(5*time.Millisecond + delay)
}
}
}
func serviceIDsFromMap(m map[string]*structs.CheckServiceNode) []string {
var result []string
for id := range m {
result = append(result, id)
}
sort.Strings(result)
return result
}
type snapshotHandler struct {
*stateMap
}
type serviceState struct {
mu sync.Mutex
idx uint64
// The current snapshot of data, given the observed events.
current map[string]*structs.CheckServiceNode
// The list of service IDs seen at each index that an update was received for the given service name.
idsByIndex map[uint64][]string
}
// Snapshot dumps the currently registered service instances.
//
// Snapshot implements stream.SnapshotFunc.
func (s *snapshotHandler) Snapshot(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) {
s.stateMap.mu.Lock()
svcState, ok := s.states[req.Subject.String()]
if !ok {
svcState = &serviceState{
current: make(map[string]*structs.CheckServiceNode),
idsByIndex: make(map[uint64][]string),
}
s.states[req.Subject.String()] = svcState
}
s.stateMap.mu.Unlock()
svcState.mu.Lock()
defer svcState.mu.Unlock()
for _, node := range svcState.current {
event := stream.Event{
Topic: pbsubscribe.Topic_ServiceHealth,
Index: svcState.idx,
Payload: state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register,
Value: node,
},
}
buf.Append([]stream.Event{event})
}
return svcState.idx, nil
}
type consumer struct {
viewStore *submatview.Store
publisher *stream.EventPublisher
seenByIndex map[uint64][]string
}
func (c *consumer) consume(ctx context.Context, service string, countExpected int) error {
group, gctx := errgroup.WithContext(ctx)
updateCh := make(chan cache.UpdateEvent, 10)
group.Go(func() error {
sr := newExportedServiceRequest(hclog.New(nil), structs.NewServiceName(service, nil), c.publisher)
return c.viewStore.Notify(gctx, sr, "", updateCh)
})
group.Go(func() error {
var n int
for {
if n >= countExpected {
return nil
}
select {
case u := <-updateCh:
// Each update contains the current snapshot of registered services.
c.seenByIndex[u.Meta.Index] = serviceIDsFromUpdates(u)
n++
case <-gctx.Done():
return nil
}
}
})
return group.Wait()
}
func serviceIDsFromUpdates(u cache.UpdateEvent) []string {
var result []string
for _, node := range u.Result.(*pbservice.IndexedCheckServiceNodes).Nodes {
result = append(result, node.Service.ID)
}
sort.Strings(result)
return result
}

View File

@ -0,0 +1,199 @@
package peering
import (
"context"
"io"
"sync"
"testing"
"time"
"google.golang.org/grpc/metadata"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
)
// same certificate that appears in our connect tests
var validCA = `
-----BEGIN CERTIFICATE-----
MIICmDCCAj6gAwIBAgIBBzAKBggqhkjOPQQDAjAWMRQwEgYDVQQDEwtDb25zdWwg
Q0EgNzAeFw0xODA1MjExNjMzMjhaFw0yODA1MTgxNjMzMjhaMBYxFDASBgNVBAMT
C0NvbnN1bCBDQSA3MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAER0qlxjnRcMEr
iSGlH7G7dYU7lzBEmLUSMZkyBbClmyV8+e8WANemjn+PLnCr40If9cmpr7RnC9Qk
GTaLnLiF16OCAXswggF3MA4GA1UdDwEB/wQEAwIBhjAPBgNVHRMBAf8EBTADAQH/
MGgGA1UdDgRhBF8xZjo5MTpjYTo0MTo4ZjphYzo2NzpiZjo1OTpjMjpmYTo0ZTo3
NTo1YzpkODpmMDo1NTpkZTpiZTo3NTpiODozMzozMTpkNToyNDpiMDowNDpiMzpl
ODo5Nzo1Yjo3ZTBqBgNVHSMEYzBhgF8xZjo5MTpjYTo0MTo4ZjphYzo2NzpiZjo1
OTpjMjpmYTo0ZTo3NTo1YzpkODpmMDo1NTpkZTpiZTo3NTpiODozMzozMTpkNToy
NDpiMDowNDpiMzplODo5Nzo1Yjo3ZTA/BgNVHREEODA2hjRzcGlmZmU6Ly8xMjRk
ZjVhMC05ODIwLTc2YzMtOWFhOS02ZjYyMTY0YmExYzIuY29uc3VsMD0GA1UdHgEB
/wQzMDGgLzAtgisxMjRkZjVhMC05ODIwLTc2YzMtOWFhOS02ZjYyMTY0YmExYzIu
Y29uc3VsMAoGCCqGSM49BAMCA0gAMEUCIQDzkkI7R+0U12a+zq2EQhP/n2mHmta+
fs2hBxWIELGwTAIgLdO7RRw+z9nnxCIA6kNl//mIQb+PGItespiHZKAz74Q=
-----END CERTIFICATE-----
`
var invalidCA = `
-----BEGIN CERTIFICATE-----
not valid
-----END CERTIFICATE-----
`
var validAddress = "1.2.3.4:80"
var validServerName = "server.consul"
var validPeerID = "peer1"
// TODO(peering): the test methods below are exposed to prevent duplication,
// these should be removed at same time tests in peering_test get refactored.
// XXX: we can't put the existing tests in service_test.go into the peering
// package because it causes an import cycle by importing the top-level consul
// package (which correctly imports the agent/rpc/peering package)
// TestPeering is a test utility for generating a pbpeering.Peering with valid
// data along with the peerName, state and index.
func TestPeering(peerName string, state pbpeering.PeeringState) *pbpeering.Peering {
return &pbpeering.Peering{
Name: peerName,
PeerCAPems: []string{validCA},
PeerServerAddresses: []string{validAddress},
PeerServerName: validServerName,
State: state,
// uncomment once #1613 lands
// PeerID: validPeerID
}
}
// TestPeeringToken is a test utility for generating a valid peering token
// with the given peerID for use in test cases
func TestPeeringToken(peerID string) structs.PeeringToken {
return structs.PeeringToken{
CA: []string{validCA},
ServerAddresses: []string{validAddress},
ServerName: validServerName,
PeerID: peerID,
}
}
type mockClient struct {
mu sync.Mutex
errCh chan error
replicationStream *mockStream
}
func (c *mockClient) Send(r *pbpeering.ReplicationMessage) error {
c.replicationStream.recvCh <- r
return nil
}
func (c *mockClient) Recv() (*pbpeering.ReplicationMessage, error) {
select {
case err := <-c.errCh:
return nil, err
case r := <-c.replicationStream.sendCh:
return r, nil
case <-time.After(10 * time.Millisecond):
return nil, io.EOF
}
}
func (c *mockClient) RecvWithTimeout(dur time.Duration) (*pbpeering.ReplicationMessage, error) {
select {
case err := <-c.errCh:
return nil, err
case r := <-c.replicationStream.sendCh:
return r, nil
case <-time.After(dur):
return nil, io.EOF
}
}
func (c *mockClient) Close() {
close(c.replicationStream.recvCh)
}
func newMockClient(ctx context.Context) *mockClient {
return &mockClient{
replicationStream: newTestReplicationStream(ctx),
}
}
// mockStream mocks peering.PeeringService_StreamResourcesServer
type mockStream struct {
sendCh chan *pbpeering.ReplicationMessage
recvCh chan *pbpeering.ReplicationMessage
ctx context.Context
mu sync.Mutex
}
var _ pbpeering.PeeringService_StreamResourcesServer = (*mockStream)(nil)
func newTestReplicationStream(ctx context.Context) *mockStream {
return &mockStream{
sendCh: make(chan *pbpeering.ReplicationMessage, 1),
recvCh: make(chan *pbpeering.ReplicationMessage, 1),
ctx: ctx,
}
}
// Send implements pbpeering.PeeringService_StreamResourcesServer
func (s *mockStream) Send(r *pbpeering.ReplicationMessage) error {
s.sendCh <- r
return nil
}
// Recv implements pbpeering.PeeringService_StreamResourcesServer
func (s *mockStream) Recv() (*pbpeering.ReplicationMessage, error) {
r := <-s.recvCh
if r == nil {
return nil, io.EOF
}
return r, nil
}
// Context implements grpc.ServerStream and grpc.ClientStream
func (s *mockStream) Context() context.Context {
return s.ctx
}
// SendMsg implements grpc.ServerStream and grpc.ClientStream
func (s *mockStream) SendMsg(m interface{}) error {
return nil
}
// RecvMsg implements grpc.ServerStream and grpc.ClientStream
func (s *mockStream) RecvMsg(m interface{}) error {
return nil
}
// SetHeader implements grpc.ServerStream
func (s *mockStream) SetHeader(metadata.MD) error {
return nil
}
// SendHeader implements grpc.ServerStream
func (s *mockStream) SendHeader(metadata.MD) error {
return nil
}
// SetTrailer implements grpc.ServerStream
func (s *mockStream) SetTrailer(metadata.MD) {}
type incrementalTime struct {
base time.Time
next uint64
}
func (t *incrementalTime) Now() time.Time {
t.next++
return t.base.Add(time.Duration(t.next) * time.Second)
}
func runStep(t *testing.T, name string, fn func(t *testing.T)) {
t.Helper()
if !t.Run(name, fn) {
t.FailNow()
}
}

View File

@ -0,0 +1,16 @@
//go:build !consulent
// +build !consulent
package peering_test
import (
"testing"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/go-hclog"
)
func newDefaultDepsEnterprise(t *testing.T, logger hclog.Logger, c *consul.Config) consul.EnterpriseDeps {
t.Helper()
return consul.EnterpriseDeps{}
}

View File

@ -0,0 +1,62 @@
package peering
import (
"fmt"
"net"
"strconv"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
// TODO: replace this with net/netip when we upgrade to go1.18
"inet.af/netaddr"
)
// validatePeeringToken ensures that the token has valid values.
func validatePeeringToken(tok *structs.PeeringToken) error {
if len(tok.CA) == 0 {
return errPeeringTokenEmptyCA
}
// the CA values here should be valid x509 certs
for _, certStr := range tok.CA {
// TODO(peering): should we put these in a cert pool on the token?
// maybe there's a better place to do the parsing?
if _, err := connect.ParseCert(certStr); err != nil {
return fmt.Errorf("peering token invalid CA: %w", err)
}
}
if len(tok.ServerAddresses) == 0 {
return errPeeringTokenEmptyServerAddresses
}
for _, addr := range tok.ServerAddresses {
host, portRaw, err := net.SplitHostPort(addr)
if err != nil {
return &errPeeringInvalidServerAddress{addr}
}
port, err := strconv.Atoi(portRaw)
if err != nil {
return &errPeeringInvalidServerAddress{addr}
}
if port < 1 || port > 65535 {
return &errPeeringInvalidServerAddress{addr}
}
if _, err := netaddr.ParseIP(host); err != nil {
return &errPeeringInvalidServerAddress{addr}
}
}
// TODO(peering): validate name matches SNI?
// TODO(peering): validate name well formed?
if tok.ServerName == "" {
return errPeeringTokenEmptyServerName
}
if tok.PeerID == "" {
return errPeeringTokenEmptyPeerID
}
return nil
}

View File

@ -0,0 +1,107 @@
package peering
import (
"errors"
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/require"
)
func TestValidatePeeringToken(t *testing.T) {
type testCase struct {
name string
token *structs.PeeringToken
wantErr error
}
tt := []testCase{
{
name: "empty",
token: &structs.PeeringToken{},
wantErr: errPeeringTokenEmptyCA,
},
{
name: "empty CA",
token: &structs.PeeringToken{
CA: []string{},
},
wantErr: errPeeringTokenEmptyCA,
},
{
name: "invalid CA",
token: &structs.PeeringToken{
CA: []string{"notavalidcert"},
},
wantErr: errors.New("peering token invalid CA: no PEM-encoded data found"),
},
{
name: "invalid CA cert",
token: &structs.PeeringToken{
CA: []string{invalidCA},
},
wantErr: errors.New("peering token invalid CA: x509: malformed certificate"),
},
{
name: "invalid address port",
token: &structs.PeeringToken{
CA: []string{validCA},
ServerAddresses: []string{"1.2.3.4"},
},
wantErr: &errPeeringInvalidServerAddress{
"1.2.3.4",
},
},
{
name: "invalid address IP",
token: &structs.PeeringToken{
CA: []string{validCA},
ServerAddresses: []string{"foo.bar.baz"},
},
wantErr: &errPeeringInvalidServerAddress{
"foo.bar.baz",
},
},
{
name: "invalid server name",
token: &structs.PeeringToken{
CA: []string{validCA},
ServerAddresses: []string{"1.2.3.4:80"},
},
wantErr: errPeeringTokenEmptyServerName,
},
{
name: "invalid peer ID",
token: &structs.PeeringToken{
CA: []string{validCA},
ServerAddresses: []string{validAddress},
ServerName: validServerName,
},
wantErr: errPeeringTokenEmptyPeerID,
},
{
name: "valid token",
token: &structs.PeeringToken{
CA: []string{validCA},
ServerAddresses: []string{validAddress},
ServerName: validServerName,
PeerID: validPeerID,
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
err := validatePeeringToken(tc.token)
if tc.wantErr != nil {
if err == nil {
t.Error("expected error but got nil")
return
}
require.Contains(t, err.Error(), tc.wantErr.Error())
return
}
require.NoError(t, err)
})
}
}

View File

@ -133,15 +133,16 @@ func (r serviceRequest) Type() string {
return "agent.rpcclient.health.serviceRequest" return "agent.rpcclient.health.serviceRequest"
} }
func (r serviceRequest) NewMaterializer() (*submatview.Materializer, error) { func (r serviceRequest) NewMaterializer() (submatview.Materializer, error) {
view, err := newHealthView(r.ServiceSpecificRequest) view, err := newHealthView(r.ServiceSpecificRequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return submatview.NewMaterializer(submatview.Deps{ deps := submatview.Deps{
View: view, View: view,
Client: pbsubscribe.NewStateChangeSubscriptionClient(r.deps.Conn),
Logger: r.deps.Logger, Logger: r.deps.Logger,
Request: newMaterializerRequest(r.ServiceSpecificRequest), Request: newMaterializerRequest(r.ServiceSpecificRequest),
}), nil }
return submatview.NewRPCMaterializer(pbsubscribe.NewStateChangeSubscriptionClient(r.deps.Conn), deps), nil
} }

View File

@ -537,17 +537,17 @@ type serviceRequestStub struct {
streamClient submatview.StreamClient streamClient submatview.StreamClient
} }
func (r serviceRequestStub) NewMaterializer() (*submatview.Materializer, error) { func (r serviceRequestStub) NewMaterializer() (submatview.Materializer, error) {
view, err := newHealthView(r.ServiceSpecificRequest) view, err := newHealthView(r.ServiceSpecificRequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return submatview.NewMaterializer(submatview.Deps{ deps := submatview.Deps{
View: view, View: view,
Client: r.streamClient,
Logger: hclog.New(nil), Logger: hclog.New(nil),
Request: newMaterializerRequest(r.ServiceSpecificRequest), Request: newMaterializerRequest(r.ServiceSpecificRequest),
}), nil }
return submatview.NewRPCMaterializer(r.streamClient, deps), nil
} }
func newEventServiceHealthRegister(index uint64, nodeNum int, svc string) *pbsubscribe.Event { func newEventServiceHealthRegister(index uint64, nodeNum int, svc string) *pbsubscribe.Event {

View File

@ -0,0 +1,62 @@
//go:build !consulent
// +build !consulent
package structs
import (
"testing"
)
func TestExportedServicesConfigEntry_OSS(t *testing.T) {
cases := map[string]configEntryTestcase{
"normalize: noop in oss": {
entry: &ExportedServicesConfigEntry{
Name: "default",
Services: []ExportedService{
{
Name: "web",
Consumers: []ServiceConsumer{
{
PeerName: "bar",
},
},
},
},
},
expected: &ExportedServicesConfigEntry{
Name: "default",
Services: []ExportedService{
{
Name: "web",
Namespace: "",
Consumers: []ServiceConsumer{
{
PeerName: "bar",
},
},
},
},
},
},
"validate: empty name": {
entry: &ExportedServicesConfigEntry{
Name: "",
},
validateErr: `exported-services Name must be "default"`,
},
"validate: wildcard name": {
entry: &ExportedServicesConfigEntry{
Name: WildcardSpecifier,
},
validateErr: `exported-services Name must be "default"`,
},
"validate: other name": {
entry: &ExportedServicesConfigEntry{
Name: "foo",
},
validateErr: `exported-services Name must be "default"`,
},
}
testConfigEntryNormalizeAndValidate(t, cases)
}

View File

@ -35,9 +35,14 @@ type ExportedService struct {
} }
// ServiceConsumer represents a downstream consumer of the service to be exported. // ServiceConsumer represents a downstream consumer of the service to be exported.
// At most one of Partition or PeerName must be specified.
type ServiceConsumer struct { type ServiceConsumer struct {
// Partition is the admin partition to export the service to. // Partition is the admin partition to export the service to.
// Deprecated: PeerName should be used for both remote peers and local partitions.
Partition string Partition string
// PeerName is the name of the peer to export the service to.
PeerName string
} }
func (e *ExportedServicesConfigEntry) ToMap() map[string]map[string][]string { func (e *ExportedServicesConfigEntry) ToMap() map[string]map[string][]string {
@ -99,37 +104,40 @@ func (e *ExportedServicesConfigEntry) Normalize() error {
e.EnterpriseMeta.Normalize() e.EnterpriseMeta.Normalize()
for i := range e.Services { for i := range e.Services {
e.Services[i].Namespace = acl.NamespaceOrDefault(e.Services[i].Namespace) e.Services[i].Namespace = acl.NormalizeNamespace(e.Services[i].Namespace)
} }
return nil return nil
} }
func (e *ExportedServicesConfigEntry) Validate() error { func (e *ExportedServicesConfigEntry) Validate() error {
if e.Name == "" { if err := validateExportedServicesName(e.Name); err != nil {
return fmt.Errorf("Name is required")
}
if e.Name == WildcardSpecifier {
return fmt.Errorf("exported-services Name must be the name of a partition, and not a wildcard")
}
if err := requireEnterprise(e.GetKind()); err != nil {
return err return err
} }
if err := validateConfigEntryMeta(e.Meta); err != nil { if err := validateConfigEntryMeta(e.Meta); err != nil {
return err return err
} }
for _, svc := range e.Services { for i, svc := range e.Services {
if svc.Name == "" { if svc.Name == "" {
return fmt.Errorf("service name cannot be empty") return fmt.Errorf("Services[%d]: service name cannot be empty", i)
}
if svc.Namespace == WildcardSpecifier && svc.Name != WildcardSpecifier {
return fmt.Errorf("Services[%d]: service name must be wildcard if namespace is wildcard", i)
} }
if len(svc.Consumers) == 0 { if len(svc.Consumers) == 0 {
return fmt.Errorf("service %q must have at least one consumer", svc.Name) return fmt.Errorf("Services[%d]: must have at least one consumer", i)
} }
for _, consumer := range svc.Consumers { for j, consumer := range svc.Consumers {
if consumer.PeerName != "" && consumer.Partition != "" {
return fmt.Errorf("Services[%d].Consumers[%d]: must define at most one of PeerName or Partition", i, j)
}
if consumer.Partition == WildcardSpecifier { if consumer.Partition == WildcardSpecifier {
return fmt.Errorf("exporting to all partitions (wildcard) is not yet supported") return fmt.Errorf("Services[%d].Consumers[%d]: exporting to all partitions (wildcard) is not supported", i, j)
}
if consumer.PeerName == WildcardSpecifier {
return fmt.Errorf("Services[%d].Consumers[%d]: exporting to all peers (wildcard) is not supported", i, j)
} }
} }
} }

View File

@ -0,0 +1,94 @@
package structs
import (
"testing"
)
func TestExportedServicesConfigEntry(t *testing.T) {
cases := map[string]configEntryTestcase{
"validate: empty service name": {
entry: &ExportedServicesConfigEntry{
Name: "default",
Services: []ExportedService{
{
Name: "",
},
},
},
validateErr: `service name cannot be empty`,
},
"validate: empty consumer list": {
entry: &ExportedServicesConfigEntry{
Name: "default",
Services: []ExportedService{
{
Name: "web",
},
},
},
validateErr: `must have at least one consumer`,
},
"validate: no wildcard in consumer partition": {
entry: &ExportedServicesConfigEntry{
Name: "default",
Services: []ExportedService{
{
Name: "api",
Consumers: []ServiceConsumer{
{
Partition: "foo",
},
},
},
{
Name: "web",
Consumers: []ServiceConsumer{
{
Partition: "*",
},
},
},
},
},
validateErr: `Services[1].Consumers[0]: exporting to all partitions (wildcard) is not supported`,
},
"validate: no wildcard in consumer peername": {
entry: &ExportedServicesConfigEntry{
Name: "default",
Services: []ExportedService{
{
Name: "web",
Consumers: []ServiceConsumer{
{
PeerName: "foo",
},
{
PeerName: "*",
},
},
},
},
},
validateErr: `Services[0].Consumers[1]: exporting to all peers (wildcard) is not supported`,
},
"validate: cannot specify consumer with partition and peername": {
entry: &ExportedServicesConfigEntry{
Name: "default",
Services: []ExportedService{
{
Name: "web",
Consumers: []ServiceConsumer{
{
Partition: "foo",
PeerName: "bar",
},
},
},
},
},
validateErr: `Services[0].Consumers[0]: must define at most one of PeerName or Partition`,
},
}
testConfigEntryNormalizeAndValidate(t, cases)
}

Some files were not shown because too many files have changed in this diff Show More