Merge branch 'main' of github.com:hashicorp/consul into sa-restructure-documentation

This commit is contained in:
trujillo-adam 2022-07-27 11:47:56 -07:00
commit 534f011663
402 changed files with 15694 additions and 7104 deletions

3
.changelog/13722.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:feature
streaming: Added topic that can be used to consume updates about the list of services in a datacenter
```

3
.changelog/13787.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
cli: when `acl token read` is used with the `-self` and `-expanded` flags, return an error instead of panicking
```

6
.changelog/13807.txt Normal file
View File

@ -0,0 +1,6 @@
```release-note: improvement
connect: Add Envoy 1.23.0 to support matrix
```
```release-note: breaking-change
connect: Removes support for Envoy 1.19
```

3
.changelog/13847.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
connect: Fixed a goroutine/memory leak that would occur when using the ingress gateway.
```

View File

@ -24,9 +24,10 @@ references:
VAULT_BINARY_VERSION: 1.9.4 VAULT_BINARY_VERSION: 1.9.4
GO_VERSION: 1.18.1 GO_VERSION: 1.18.1
envoy-versions: &supported_envoy_versions envoy-versions: &supported_envoy_versions
- &default_envoy_version "1.19.5" - &default_envoy_version "1.20.6"
- "1.20.4" - "1.21.4"
- "1.21.3" - "1.22.2"
- "1.23.0"
images: images:
# When updating the Go version, remember to also update the versions in the # When updating the Go version, remember to also update the versions in the
# workflows section for go-test-lib jobs. # workflows section for go-test-lib jobs.
@ -875,8 +876,13 @@ jobs:
environment: environment:
ENVOY_VERSION: << parameters.envoy-version >> ENVOY_VERSION: << parameters.envoy-version >>
XDS_TARGET: << parameters.xds-target >> XDS_TARGET: << parameters.xds-target >>
AWS_LAMBDA_REGION: us-west-2
steps: &ENVOY_INTEGRATION_TEST_STEPS steps: &ENVOY_INTEGRATION_TEST_STEPS
- checkout - checkout
- assume-role:
access-key: AWS_ACCESS_KEY_ID_LAMBDA
secret-key: AWS_SECRET_ACCESS_KEY_LAMBDA
role-arn: ROLE_ARN_LAMBDA
# Get go binary from workspace # Get go binary from workspace
- attach_workspace: - attach_workspace:
at: . at: .

View File

@ -254,8 +254,8 @@ jobs:
docker.io/hashicorppreview/${{ env.repo }}:${{ env.dev_tag }}-${{ github.sha }} docker.io/hashicorppreview/${{ env.repo }}:${{ env.dev_tag }}-${{ github.sha }}
smoke_test: .github/scripts/verify_docker.sh v${{ env.version }} smoke_test: .github/scripts/verify_docker.sh v${{ env.version }}
build-docker-redhat: build-docker-ubi-redhat:
name: Docker Build UBI Image for RedHat name: Docker Build UBI Image for RedHat Registry
needs: needs:
- get-product-version - get-product-version
- build - build
@ -274,6 +274,39 @@ jobs:
redhat_tag: scan.connect.redhat.com/ospid-60f9fdbec3a80eac643abedf/${{env.repo}}:${{env.version}}-ubi redhat_tag: scan.connect.redhat.com/ospid-60f9fdbec3a80eac643abedf/${{env.repo}}:${{env.version}}-ubi
smoke_test: .github/scripts/verify_docker.sh v${{ env.version }} smoke_test: .github/scripts/verify_docker.sh v${{ env.version }}
build-docker-ubi-dockerhub:
name: Docker Build UBI Image for DockerHub
needs:
- get-product-version
- build
runs-on: ubuntu-latest
env:
repo: ${{github.event.repository.name}}
version: ${{needs.get-product-version.outputs.product-version}}
steps:
- uses: actions/checkout@v2
# Strip everything but MAJOR.MINOR from the version string and add a `-dev` suffix
# This naming convention will be used ONLY for per-commit dev images
- name: Set docker dev tag
run: |
version="${{ env.version }}"
echo "dev_tag=${version%.*}-dev" >> $GITHUB_ENV
- uses: hashicorp/actions-docker-build@v1
with:
version: ${{env.version}}
target: ubi
arch: amd64
tags: |
docker.io/hashicorp/${{env.repo}}:${{env.version}}-ubi
public.ecr.aws/hashicorp/${{env.repo}}:${{env.version}}-ubi
dev_tags: |
docker.io/hashicorppreview/${{ env.repo }}:${{ env.dev_tag }}-ubi
docker.io/hashicorppreview/${{ env.repo }}:${{ env.dev_tag }}-ubi-${{ github.sha }}
smoke_test: .github/scripts/verify_docker.sh v${{ env.version }}
verify-linux: verify-linux:
needs: needs:
- get-product-version - get-product-version

1
.gitignore vendored
View File

@ -14,6 +14,7 @@ changelog.tmp
exit-code exit-code
Thumbs.db Thumbs.db
.idea .idea
.vscode
# MacOS # MacOS
.DS_Store .DS_Store

View File

@ -178,6 +178,15 @@ event "promote-dev-docker" {
} }
} }
event "fossa-scan" {
depends = ["promote-dev-docker"]
action "fossa-scan" {
organization = "hashicorp"
repository = "crt-workflows-common"
workflow = "fossa-scan"
}
}
## These are promotion and post-publish events ## These are promotion and post-publish events
## they should be added to the end of the file after the verify event stanza. ## they should be added to the end of the file after the verify event stanza.

View File

@ -27,6 +27,7 @@ func legacyPolicy(policy *Policy) *Policy {
Keyring: policy.Keyring, Keyring: policy.Keyring,
Operator: policy.Operator, Operator: policy.Operator,
Mesh: policy.Mesh, Mesh: policy.Mesh,
Peering: policy.Peering,
}, },
} }
} }
@ -117,6 +118,14 @@ func checkAllowMeshWrite(t *testing.T, authz Authorizer, prefix string, entCtx *
require.Equal(t, Allow, authz.MeshWrite(entCtx)) require.Equal(t, Allow, authz.MeshWrite(entCtx))
} }
func checkAllowPeeringRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Allow, authz.PeeringRead(entCtx))
}
func checkAllowPeeringWrite(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Allow, authz.PeeringWrite(entCtx))
}
func checkAllowOperatorRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) { func checkAllowOperatorRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Allow, authz.OperatorRead(entCtx)) require.Equal(t, Allow, authz.OperatorRead(entCtx))
} }
@ -241,6 +250,14 @@ func checkDenyMeshWrite(t *testing.T, authz Authorizer, prefix string, entCtx *A
require.Equal(t, Deny, authz.MeshWrite(entCtx)) require.Equal(t, Deny, authz.MeshWrite(entCtx))
} }
func checkDenyPeeringRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Deny, authz.PeeringRead(entCtx))
}
func checkDenyPeeringWrite(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Deny, authz.PeeringWrite(entCtx))
}
func checkDenyOperatorRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) { func checkDenyOperatorRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Deny, authz.OperatorRead(entCtx)) require.Equal(t, Deny, authz.OperatorRead(entCtx))
} }
@ -365,6 +382,14 @@ func checkDefaultMeshWrite(t *testing.T, authz Authorizer, prefix string, entCtx
require.Equal(t, Default, authz.MeshWrite(entCtx)) require.Equal(t, Default, authz.MeshWrite(entCtx))
} }
func checkDefaultPeeringRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Default, authz.PeeringRead(entCtx))
}
func checkDefaultPeeringWrite(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Default, authz.PeeringWrite(entCtx))
}
func checkDefaultOperatorRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) { func checkDefaultOperatorRead(t *testing.T, authz Authorizer, prefix string, entCtx *AuthorizerContext) {
require.Equal(t, Default, authz.OperatorRead(entCtx)) require.Equal(t, Default, authz.OperatorRead(entCtx))
} }
@ -446,6 +471,8 @@ func TestACL(t *testing.T) {
{name: "DenyNodeWrite", check: checkDenyNodeWrite}, {name: "DenyNodeWrite", check: checkDenyNodeWrite},
{name: "DenyMeshRead", check: checkDenyMeshRead}, {name: "DenyMeshRead", check: checkDenyMeshRead},
{name: "DenyMeshWrite", check: checkDenyMeshWrite}, {name: "DenyMeshWrite", check: checkDenyMeshWrite},
{name: "DenyPeeringRead", check: checkDenyPeeringRead},
{name: "DenyPeeringWrite", check: checkDenyPeeringWrite},
{name: "DenyOperatorRead", check: checkDenyOperatorRead}, {name: "DenyOperatorRead", check: checkDenyOperatorRead},
{name: "DenyOperatorWrite", check: checkDenyOperatorWrite}, {name: "DenyOperatorWrite", check: checkDenyOperatorWrite},
{name: "DenyPreparedQueryRead", check: checkDenyPreparedQueryRead}, {name: "DenyPreparedQueryRead", check: checkDenyPreparedQueryRead},
@ -480,6 +507,8 @@ func TestACL(t *testing.T) {
{name: "AllowNodeWrite", check: checkAllowNodeWrite}, {name: "AllowNodeWrite", check: checkAllowNodeWrite},
{name: "AllowMeshRead", check: checkAllowMeshRead}, {name: "AllowMeshRead", check: checkAllowMeshRead},
{name: "AllowMeshWrite", check: checkAllowMeshWrite}, {name: "AllowMeshWrite", check: checkAllowMeshWrite},
{name: "AllowPeeringRead", check: checkAllowPeeringRead},
{name: "AllowPeeringWrite", check: checkAllowPeeringWrite},
{name: "AllowOperatorRead", check: checkAllowOperatorRead}, {name: "AllowOperatorRead", check: checkAllowOperatorRead},
{name: "AllowOperatorWrite", check: checkAllowOperatorWrite}, {name: "AllowOperatorWrite", check: checkAllowOperatorWrite},
{name: "AllowPreparedQueryRead", check: checkAllowPreparedQueryRead}, {name: "AllowPreparedQueryRead", check: checkAllowPreparedQueryRead},
@ -514,6 +543,8 @@ func TestACL(t *testing.T) {
{name: "AllowNodeWrite", check: checkAllowNodeWrite}, {name: "AllowNodeWrite", check: checkAllowNodeWrite},
{name: "AllowMeshRead", check: checkAllowMeshRead}, {name: "AllowMeshRead", check: checkAllowMeshRead},
{name: "AllowMeshWrite", check: checkAllowMeshWrite}, {name: "AllowMeshWrite", check: checkAllowMeshWrite},
{name: "AllowPeeringRead", check: checkAllowPeeringRead},
{name: "AllowPeeringWrite", check: checkAllowPeeringWrite},
{name: "AllowOperatorRead", check: checkAllowOperatorRead}, {name: "AllowOperatorRead", check: checkAllowOperatorRead},
{name: "AllowOperatorWrite", check: checkAllowOperatorWrite}, {name: "AllowOperatorWrite", check: checkAllowOperatorWrite},
{name: "AllowPreparedQueryRead", check: checkAllowPreparedQueryRead}, {name: "AllowPreparedQueryRead", check: checkAllowPreparedQueryRead},
@ -1217,6 +1248,319 @@ func TestACL(t *testing.T) {
{name: "WriteAllowed", check: checkAllowMeshWrite}, {name: "WriteAllowed", check: checkAllowMeshWrite},
}, },
}, },
{
name: "PeeringDefaultAllowPolicyDeny",
defaultPolicy: AllowAll(),
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Peering: PolicyDeny,
},
},
},
checks: []aclCheck{
{name: "ReadDenied", check: checkDenyPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
name: "PeeringDefaultAllowPolicyRead",
defaultPolicy: AllowAll(),
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Peering: PolicyRead,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
name: "PeeringDefaultAllowPolicyWrite",
defaultPolicy: AllowAll(),
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Peering: PolicyWrite,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteAllowed", check: checkAllowPeeringWrite},
},
},
{
name: "PeeringDefaultAllowPolicyNone",
defaultPolicy: AllowAll(),
policyStack: []*Policy{
{},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteAllowed", check: checkAllowPeeringWrite},
},
},
{
name: "PeeringDefaultDenyPolicyDeny",
defaultPolicy: DenyAll(),
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Peering: PolicyDeny,
},
},
},
checks: []aclCheck{
{name: "ReadDenied", check: checkDenyPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
name: "PeeringDefaultDenyPolicyRead",
defaultPolicy: DenyAll(),
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Peering: PolicyRead,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
name: "PeeringDefaultDenyPolicyWrite",
defaultPolicy: DenyAll(),
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Peering: PolicyWrite,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteAllowed", check: checkAllowPeeringWrite},
},
},
{
name: "PeeringDefaultDenyPolicyNone",
defaultPolicy: DenyAll(),
policyStack: []*Policy{
{},
},
checks: []aclCheck{
{name: "ReadDenied", check: checkDenyPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:deny, p:deny = deny
name: "PeeringOperatorDenyPolicyDeny",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyDeny,
Peering: PolicyDeny,
},
},
},
checks: []aclCheck{
{name: "ReadDenied", check: checkDenyPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:read, p:deny = deny
name: "PeeringOperatorReadPolicyDeny",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyRead,
Peering: PolicyDeny,
},
},
},
checks: []aclCheck{
{name: "ReadDenied", check: checkDenyPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:write, p:deny = deny
name: "PeeringOperatorWritePolicyDeny",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyWrite,
Peering: PolicyDeny,
},
},
},
checks: []aclCheck{
{name: "ReadDenied", check: checkDenyPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:deny, p:read = read
name: "PeeringOperatorDenyPolicyRead",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyDeny,
Peering: PolicyRead,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:read, p:read = read
name: "PeeringOperatorReadPolicyRead",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyRead,
Peering: PolicyRead,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:write, p:read = read
name: "PeeringOperatorWritePolicyRead",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyWrite,
Peering: PolicyRead,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:deny, p:write = write
name: "PeeringOperatorDenyPolicyWrite",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyDeny,
Peering: PolicyWrite,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteAllowed", check: checkAllowPeeringWrite},
},
},
{
// o:read, p:write = write
name: "PeeringOperatorReadPolicyWrite",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyRead,
Peering: PolicyWrite,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteAllowed", check: checkAllowPeeringWrite},
},
},
{
// o:write, p:write = write
name: "PeeringOperatorWritePolicyWrite",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyWrite,
Peering: PolicyWrite,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteAllowed", check: checkAllowPeeringWrite},
},
},
{
// o:deny, p:<none> = deny
name: "PeeringOperatorDenyPolicyNone",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyDeny,
},
},
},
checks: []aclCheck{
{name: "ReadDenied", check: checkDenyPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:read, p:<none> = read
name: "PeeringOperatorReadPolicyNone",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyRead,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteDenied", check: checkDenyPeeringWrite},
},
},
{
// o:write, p:<none> = write
name: "PeeringOperatorWritePolicyNone",
defaultPolicy: nil, // test both
policyStack: []*Policy{
{
PolicyRules: PolicyRules{
Operator: PolicyWrite,
},
},
},
checks: []aclCheck{
{name: "ReadAllowed", check: checkAllowPeeringRead},
{name: "WriteAllowed", check: checkAllowPeeringWrite},
},
},
{ {
name: "OperatorDefaultAllowPolicyDeny", name: "OperatorDefaultAllowPolicyDeny",
defaultPolicy: AllowAll(), defaultPolicy: AllowAll(),

View File

@ -114,6 +114,14 @@ type Authorizer interface {
// functions can be used. // functions can be used.
MeshWrite(*AuthorizerContext) EnforcementDecision MeshWrite(*AuthorizerContext) EnforcementDecision
// PeeringRead determines if the read-only Consul peering functions
// can be used.
PeeringRead(*AuthorizerContext) EnforcementDecision
// PeeringWrite determines if the stage-changing Consul peering
// functions can be used.
PeeringWrite(*AuthorizerContext) EnforcementDecision
// NodeRead checks for permission to read (discover) a given node. // NodeRead checks for permission to read (discover) a given node.
NodeRead(string, *AuthorizerContext) EnforcementDecision NodeRead(string, *AuthorizerContext) EnforcementDecision
@ -327,6 +335,24 @@ func (a AllowAuthorizer) MeshWriteAllowed(ctx *AuthorizerContext) error {
return nil return nil
} }
// PeeringReadAllowed determines if the read-only Consul peering functions
// can be used.
func (a AllowAuthorizer) PeeringReadAllowed(ctx *AuthorizerContext) error {
if a.Authorizer.PeeringRead(ctx) != Allow {
return PermissionDeniedByACLUnnamed(a, ctx, ResourcePeering, AccessRead)
}
return nil
}
// PeeringWriteAllowed determines if the state-changing Consul peering
// functions can be used.
func (a AllowAuthorizer) PeeringWriteAllowed(ctx *AuthorizerContext) error {
if a.Authorizer.PeeringWrite(ctx) != Allow {
return PermissionDeniedByACLUnnamed(a, ctx, ResourcePeering, AccessWrite)
}
return nil
}
// NodeReadAllowed checks for permission to read (discover) a given node. // NodeReadAllowed checks for permission to read (discover) a given node.
func (a AllowAuthorizer) NodeReadAllowed(name string, ctx *AuthorizerContext) error { func (a AllowAuthorizer) NodeReadAllowed(name string, ctx *AuthorizerContext) error {
if a.Authorizer.NodeRead(name, ctx) != Allow { if a.Authorizer.NodeRead(name, ctx) != Allow {
@ -542,12 +568,11 @@ func Enforce(authz Authorizer, rsc Resource, segment string, access string, ctx
return authz.SessionWrite(segment, ctx), nil return authz.SessionWrite(segment, ctx), nil
} }
case ResourcePeering: case ResourcePeering:
// TODO (peering) switch this over to using PeeringRead & PeeringWrite methods once implemented
switch lowerAccess { switch lowerAccess {
case "read": case "read":
return authz.OperatorRead(ctx), nil return authz.PeeringRead(ctx), nil
case "write": case "write":
return authz.OperatorWrite(ctx), nil return authz.PeeringWrite(ctx), nil
} }
default: default:
if processed, decision, err := enforceEnterprise(authz, rsc, segment, lowerAccess, ctx); processed { if processed, decision, err := enforceEnterprise(authz, rsc, segment, lowerAccess, ctx); processed {
@ -561,6 +586,7 @@ func Enforce(authz Authorizer, rsc Resource, segment string, access string, ctx
// NewAuthorizerFromRules is a convenience function to invoke NewPolicyFromSource followed by NewPolicyAuthorizer with // NewAuthorizerFromRules is a convenience function to invoke NewPolicyFromSource followed by NewPolicyAuthorizer with
// the parse policy. // the parse policy.
// TODO(ACL-Legacy-Compat): remove syntax arg after removing SyntaxLegacy
func NewAuthorizerFromRules(rules string, syntax SyntaxVersion, conf *Config, meta *EnterprisePolicyMeta) (Authorizer, error) { func NewAuthorizerFromRules(rules string, syntax SyntaxVersion, conf *Config, meta *EnterprisePolicyMeta) (Authorizer, error) {
policy, err := NewPolicyFromSource(rules, syntax, conf, meta) policy, err := NewPolicyFromSource(rules, syntax, conf, meta)
if err != nil { if err != nil {

View File

@ -139,6 +139,20 @@ func (m *mockAuthorizer) MeshWrite(ctx *AuthorizerContext) EnforcementDecision {
return ret.Get(0).(EnforcementDecision) return ret.Get(0).(EnforcementDecision)
} }
// PeeringRead determines if the read-only Consul peering functions
// can be used.
func (m *mockAuthorizer) PeeringRead(ctx *AuthorizerContext) EnforcementDecision {
ret := m.Called(ctx)
return ret.Get(0).(EnforcementDecision)
}
// PeeringWrite determines if the state-changing Consul peering
// functions can be used.
func (m *mockAuthorizer) PeeringWrite(ctx *AuthorizerContext) EnforcementDecision {
ret := m.Called(ctx)
return ret.Get(0).(EnforcementDecision)
}
// OperatorRead determines if the read-only Consul operator functions // OperatorRead determines if the read-only Consul operator functions
// can be used. ret := m.Called(segment, ctx) // can be used. ret := m.Called(segment, ctx)
func (m *mockAuthorizer) OperatorRead(ctx *AuthorizerContext) EnforcementDecision { func (m *mockAuthorizer) OperatorRead(ctx *AuthorizerContext) EnforcementDecision {
@ -463,29 +477,25 @@ func TestACL_Enforce(t *testing.T) {
err: "Invalid access level", err: "Invalid access level",
}, },
{ {
// TODO (peering) Update to use PeeringRead method: "PeeringRead",
method: "OperatorRead",
resource: ResourcePeering, resource: ResourcePeering,
access: "read", access: "read",
ret: Allow, ret: Allow,
}, },
{ {
// TODO (peering) Update to use PeeringRead method: "PeeringRead",
method: "OperatorRead",
resource: ResourcePeering, resource: ResourcePeering,
access: "read", access: "read",
ret: Deny, ret: Deny,
}, },
{ {
// TODO (peering) Update to use PeeringWrite method: "PeeringWrite",
method: "OperatorWrite",
resource: ResourcePeering, resource: ResourcePeering,
access: "write", access: "write",
ret: Allow, ret: Allow,
}, },
{ {
// TODO (peering) Update to use PeeringWrite method: "PeeringWrite",
method: "OperatorWrite",
resource: ResourcePeering, resource: ResourcePeering,
access: "write", access: "write",
ret: Deny, ret: Deny,

View File

@ -161,6 +161,22 @@ func (c *ChainedAuthorizer) MeshWrite(entCtx *AuthorizerContext) EnforcementDeci
}) })
} }
// PeeringRead determines if the read-only Consul peering functions
// can be used.
func (c *ChainedAuthorizer) PeeringRead(entCtx *AuthorizerContext) EnforcementDecision {
return c.executeChain(func(authz Authorizer) EnforcementDecision {
return authz.PeeringRead(entCtx)
})
}
// PeeringWrite determines if the state-changing Consul peering
// functions can be used.
func (c *ChainedAuthorizer) PeeringWrite(entCtx *AuthorizerContext) EnforcementDecision {
return c.executeChain(func(authz Authorizer) EnforcementDecision {
return authz.PeeringWrite(entCtx)
})
}
// NodeRead checks for permission to read (discover) a given node. // NodeRead checks for permission to read (discover) a given node.
func (c *ChainedAuthorizer) NodeRead(node string, entCtx *AuthorizerContext) EnforcementDecision { func (c *ChainedAuthorizer) NodeRead(node string, entCtx *AuthorizerContext) EnforcementDecision {
return c.executeChain(func(authz Authorizer) EnforcementDecision { return c.executeChain(func(authz Authorizer) EnforcementDecision {

View File

@ -68,6 +68,12 @@ func (authz testAuthorizer) MeshRead(*AuthorizerContext) EnforcementDecision {
func (authz testAuthorizer) MeshWrite(*AuthorizerContext) EnforcementDecision { func (authz testAuthorizer) MeshWrite(*AuthorizerContext) EnforcementDecision {
return EnforcementDecision(authz) return EnforcementDecision(authz)
} }
func (authz testAuthorizer) PeeringRead(*AuthorizerContext) EnforcementDecision {
return EnforcementDecision(authz)
}
func (authz testAuthorizer) PeeringWrite(*AuthorizerContext) EnforcementDecision {
return EnforcementDecision(authz)
}
func (authz testAuthorizer) OperatorRead(*AuthorizerContext) EnforcementDecision { func (authz testAuthorizer) OperatorRead(*AuthorizerContext) EnforcementDecision {
return EnforcementDecision(authz) return EnforcementDecision(authz)
} }
@ -128,6 +134,8 @@ func TestChainedAuthorizer(t *testing.T) {
checkDenyNodeWrite(t, authz, "foo", nil) checkDenyNodeWrite(t, authz, "foo", nil)
checkDenyMeshRead(t, authz, "foo", nil) checkDenyMeshRead(t, authz, "foo", nil)
checkDenyMeshWrite(t, authz, "foo", nil) checkDenyMeshWrite(t, authz, "foo", nil)
checkDenyPeeringRead(t, authz, "foo", nil)
checkDenyPeeringWrite(t, authz, "foo", nil)
checkDenyOperatorRead(t, authz, "foo", nil) checkDenyOperatorRead(t, authz, "foo", nil)
checkDenyOperatorWrite(t, authz, "foo", nil) checkDenyOperatorWrite(t, authz, "foo", nil)
checkDenyPreparedQueryRead(t, authz, "foo", nil) checkDenyPreparedQueryRead(t, authz, "foo", nil)
@ -160,6 +168,8 @@ func TestChainedAuthorizer(t *testing.T) {
checkDenyNodeWrite(t, authz, "foo", nil) checkDenyNodeWrite(t, authz, "foo", nil)
checkDenyMeshRead(t, authz, "foo", nil) checkDenyMeshRead(t, authz, "foo", nil)
checkDenyMeshWrite(t, authz, "foo", nil) checkDenyMeshWrite(t, authz, "foo", nil)
checkDenyPeeringRead(t, authz, "foo", nil)
checkDenyPeeringWrite(t, authz, "foo", nil)
checkDenyOperatorRead(t, authz, "foo", nil) checkDenyOperatorRead(t, authz, "foo", nil)
checkDenyOperatorWrite(t, authz, "foo", nil) checkDenyOperatorWrite(t, authz, "foo", nil)
checkDenyPreparedQueryRead(t, authz, "foo", nil) checkDenyPreparedQueryRead(t, authz, "foo", nil)
@ -192,6 +202,8 @@ func TestChainedAuthorizer(t *testing.T) {
checkAllowNodeWrite(t, authz, "foo", nil) checkAllowNodeWrite(t, authz, "foo", nil)
checkAllowMeshRead(t, authz, "foo", nil) checkAllowMeshRead(t, authz, "foo", nil)
checkAllowMeshWrite(t, authz, "foo", nil) checkAllowMeshWrite(t, authz, "foo", nil)
checkAllowPeeringRead(t, authz, "foo", nil)
checkAllowPeeringWrite(t, authz, "foo", nil)
checkAllowOperatorRead(t, authz, "foo", nil) checkAllowOperatorRead(t, authz, "foo", nil)
checkAllowOperatorWrite(t, authz, "foo", nil) checkAllowOperatorWrite(t, authz, "foo", nil)
checkAllowPreparedQueryRead(t, authz, "foo", nil) checkAllowPreparedQueryRead(t, authz, "foo", nil)
@ -224,6 +236,8 @@ func TestChainedAuthorizer(t *testing.T) {
checkDenyNodeWrite(t, authz, "foo", nil) checkDenyNodeWrite(t, authz, "foo", nil)
checkDenyMeshRead(t, authz, "foo", nil) checkDenyMeshRead(t, authz, "foo", nil)
checkDenyMeshWrite(t, authz, "foo", nil) checkDenyMeshWrite(t, authz, "foo", nil)
checkDenyPeeringRead(t, authz, "foo", nil)
checkDenyPeeringWrite(t, authz, "foo", nil)
checkDenyOperatorRead(t, authz, "foo", nil) checkDenyOperatorRead(t, authz, "foo", nil)
checkDenyOperatorWrite(t, authz, "foo", nil) checkDenyOperatorWrite(t, authz, "foo", nil)
checkDenyPreparedQueryRead(t, authz, "foo", nil) checkDenyPreparedQueryRead(t, authz, "foo", nil)
@ -254,6 +268,8 @@ func TestChainedAuthorizer(t *testing.T) {
checkAllowNodeWrite(t, authz, "foo", nil) checkAllowNodeWrite(t, authz, "foo", nil)
checkAllowMeshRead(t, authz, "foo", nil) checkAllowMeshRead(t, authz, "foo", nil)
checkAllowMeshWrite(t, authz, "foo", nil) checkAllowMeshWrite(t, authz, "foo", nil)
checkAllowPeeringRead(t, authz, "foo", nil)
checkAllowPeeringWrite(t, authz, "foo", nil)
checkAllowOperatorRead(t, authz, "foo", nil) checkAllowOperatorRead(t, authz, "foo", nil)
checkAllowOperatorWrite(t, authz, "foo", nil) checkAllowOperatorWrite(t, authz, "foo", nil)
checkAllowPreparedQueryRead(t, authz, "foo", nil) checkAllowPreparedQueryRead(t, authz, "foo", nil)

View File

@ -85,6 +85,7 @@ type PolicyRules struct {
Keyring string `hcl:"keyring"` Keyring string `hcl:"keyring"`
Operator string `hcl:"operator"` Operator string `hcl:"operator"`
Mesh string `hcl:"mesh"` Mesh string `hcl:"mesh"`
Peering string `hcl:"peering"`
} }
// Policy is used to represent the policy specified by an ACL configuration. // Policy is used to represent the policy specified by an ACL configuration.
@ -289,6 +290,10 @@ func (pr *PolicyRules) Validate(conf *Config) error {
return fmt.Errorf("Invalid mesh policy: %#v", pr.Mesh) return fmt.Errorf("Invalid mesh policy: %#v", pr.Mesh)
} }
// Validate the peering policy - this one is allowed to be empty
if pr.Peering != "" && !isPolicyValid(pr.Peering, false) {
return fmt.Errorf("Invalid peering policy: %#v", pr.Peering)
}
return nil return nil
} }
@ -309,6 +314,7 @@ func parseCurrent(rules string, conf *Config, meta *EnterprisePolicyMeta) (*Poli
return p, nil return p, nil
} }
// TODO(ACL-Legacy-Compat): remove in phase 2
func parseLegacy(rules string, conf *Config) (*Policy, error) { func parseLegacy(rules string, conf *Config) (*Policy, error) {
p := &Policy{} p := &Policy{}
@ -436,6 +442,7 @@ func NewPolicyFromSource(rules string, syntax SyntaxVersion, conf *Config, meta
var policy *Policy var policy *Policy
var err error var err error
switch syntax { switch syntax {
// TODO(ACL-Legacy-Compat): remove and remove as argument from function
case SyntaxLegacy: case SyntaxLegacy:
policy, err = parseLegacy(rules, conf) policy, err = parseLegacy(rules, conf)
case SyntaxCurrent: case SyntaxCurrent:

View File

@ -43,6 +43,9 @@ type policyAuthorizer struct {
// meshRule contains the mesh policies. // meshRule contains the mesh policies.
meshRule *policyAuthorizerRule meshRule *policyAuthorizerRule
// peeringRule contains the peering policies.
peeringRule *policyAuthorizerRule
// embedded enterprise policy authorizer // embedded enterprise policy authorizer
enterprisePolicyAuthorizer enterprisePolicyAuthorizer
} }
@ -322,6 +325,15 @@ func (p *policyAuthorizer) loadRules(policy *PolicyRules) error {
p.meshRule = &policyAuthorizerRule{access: access} p.meshRule = &policyAuthorizerRule{access: access}
} }
// Load the peering policy
if policy.Peering != "" {
access, err := AccessLevelFromString(policy.Peering)
if err != nil {
return err
}
p.peeringRule = &policyAuthorizerRule{access: access}
}
return nil return nil
} }
@ -692,6 +704,25 @@ func (p *policyAuthorizer) MeshWrite(ctx *AuthorizerContext) EnforcementDecision
return p.OperatorWrite(ctx) return p.OperatorWrite(ctx)
} }
// PeeringRead determines if the read-only peering functions are allowed.
func (p *policyAuthorizer) PeeringRead(ctx *AuthorizerContext) EnforcementDecision {
if p.peeringRule != nil {
return enforce(p.peeringRule.access, AccessRead)
}
// default to OperatorRead access
return p.OperatorRead(ctx)
}
// PeeringWrite determines if the state-changing peering functions are
// allowed.
func (p *policyAuthorizer) PeeringWrite(ctx *AuthorizerContext) EnforcementDecision {
if p.peeringRule != nil {
return enforce(p.peeringRule.access, AccessWrite)
}
// default to OperatorWrite access
return p.OperatorWrite(ctx)
}
// OperatorRead determines if the read-only operator functions are allowed. // OperatorRead determines if the read-only operator functions are allowed.
func (p *policyAuthorizer) OperatorRead(*AuthorizerContext) EnforcementDecision { func (p *policyAuthorizer) OperatorRead(*AuthorizerContext) EnforcementDecision {
if p.operatorRule != nil { if p.operatorRule != nil {

View File

@ -50,6 +50,8 @@ func TestPolicyAuthorizer(t *testing.T) {
{name: "DefaultNodeWrite", prefix: "foo", check: checkDefaultNodeWrite}, {name: "DefaultNodeWrite", prefix: "foo", check: checkDefaultNodeWrite},
{name: "DefaultMeshRead", prefix: "foo", check: checkDefaultMeshRead}, {name: "DefaultMeshRead", prefix: "foo", check: checkDefaultMeshRead},
{name: "DefaultMeshWrite", prefix: "foo", check: checkDefaultMeshWrite}, {name: "DefaultMeshWrite", prefix: "foo", check: checkDefaultMeshWrite},
{name: "DefaultPeeringRead", prefix: "foo", check: checkDefaultPeeringRead},
{name: "DefaultPeeringWrite", prefix: "foo", check: checkDefaultPeeringWrite},
{name: "DefaultOperatorRead", prefix: "foo", check: checkDefaultOperatorRead}, {name: "DefaultOperatorRead", prefix: "foo", check: checkDefaultOperatorRead},
{name: "DefaultOperatorWrite", prefix: "foo", check: checkDefaultOperatorWrite}, {name: "DefaultOperatorWrite", prefix: "foo", check: checkDefaultOperatorWrite},
{name: "DefaultPreparedQueryRead", prefix: "foo", check: checkDefaultPreparedQueryRead}, {name: "DefaultPreparedQueryRead", prefix: "foo", check: checkDefaultPreparedQueryRead},

View File

@ -10,6 +10,7 @@ type policyRulesMergeContext struct {
keyRules map[string]*KeyRule keyRules map[string]*KeyRule
keyPrefixRules map[string]*KeyRule keyPrefixRules map[string]*KeyRule
meshRule string meshRule string
peeringRule string
nodeRules map[string]*NodeRule nodeRules map[string]*NodeRule
nodePrefixRules map[string]*NodeRule nodePrefixRules map[string]*NodeRule
operatorRule string operatorRule string
@ -33,6 +34,7 @@ func (p *policyRulesMergeContext) init() {
p.keyRules = make(map[string]*KeyRule) p.keyRules = make(map[string]*KeyRule)
p.keyPrefixRules = make(map[string]*KeyRule) p.keyPrefixRules = make(map[string]*KeyRule)
p.meshRule = "" p.meshRule = ""
p.peeringRule = ""
p.nodeRules = make(map[string]*NodeRule) p.nodeRules = make(map[string]*NodeRule)
p.nodePrefixRules = make(map[string]*NodeRule) p.nodePrefixRules = make(map[string]*NodeRule)
p.operatorRule = "" p.operatorRule = ""
@ -119,10 +121,6 @@ func (p *policyRulesMergeContext) merge(policy *PolicyRules) {
} }
} }
if takesPrecedenceOver(policy.Mesh, p.meshRule) {
p.meshRule = policy.Mesh
}
for _, np := range policy.Nodes { for _, np := range policy.Nodes {
update := true update := true
if permission, found := p.nodeRules[np.Name]; found { if permission, found := p.nodeRules[np.Name]; found {
@ -145,6 +143,14 @@ func (p *policyRulesMergeContext) merge(policy *PolicyRules) {
} }
} }
if takesPrecedenceOver(policy.Mesh, p.meshRule) {
p.meshRule = policy.Mesh
}
if takesPrecedenceOver(policy.Peering, p.peeringRule) {
p.peeringRule = policy.Peering
}
if takesPrecedenceOver(policy.Operator, p.operatorRule) { if takesPrecedenceOver(policy.Operator, p.operatorRule) {
p.operatorRule = policy.Operator p.operatorRule = policy.Operator
} }
@ -235,6 +241,7 @@ func (p *policyRulesMergeContext) fill(merged *PolicyRules) {
merged.Keyring = p.keyringRule merged.Keyring = p.keyringRule
merged.Operator = p.operatorRule merged.Operator = p.operatorRule
merged.Mesh = p.meshRule merged.Mesh = p.meshRule
merged.Peering = p.peeringRule
// All the for loop appends are ugly but Go doesn't have a way to get // All the for loop appends are ugly but Go doesn't have a way to get
// a slice of all values within a map so this is necessary // a slice of all values within a map so this is necessary

View File

@ -65,6 +65,7 @@ func TestPolicySourceParse(t *testing.T) {
} }
operator = "deny" operator = "deny"
mesh = "deny" mesh = "deny"
peering = "deny"
service_prefix "" { service_prefix "" {
policy = "write" policy = "write"
} }
@ -147,6 +148,7 @@ func TestPolicySourceParse(t *testing.T) {
}, },
"operator": "deny", "operator": "deny",
"mesh": "deny", "mesh": "deny",
"peering": "deny",
"service_prefix": { "service_prefix": {
"": { "": {
"policy": "write" "policy": "write"
@ -253,6 +255,7 @@ func TestPolicySourceParse(t *testing.T) {
}, },
Operator: PolicyDeny, Operator: PolicyDeny,
Mesh: PolicyDeny, Mesh: PolicyDeny,
Peering: PolicyDeny,
PreparedQueryPrefixes: []*PreparedQueryRule{ PreparedQueryPrefixes: []*PreparedQueryRule{
{ {
Prefix: "", Prefix: "",
@ -743,6 +746,13 @@ func TestPolicySourceParse(t *testing.T) {
RulesJSON: `{ "mesh": "nope" }`, RulesJSON: `{ "mesh": "nope" }`,
Err: "Invalid mesh policy", Err: "Invalid mesh policy",
}, },
{
Name: "Bad Policy - Peering",
Syntax: SyntaxCurrent,
Rules: `peering = "nope"`,
RulesJSON: `{ "peering": "nope" }`,
Err: "Invalid peering policy",
},
{ {
Name: "Keyring Empty", Name: "Keyring Empty",
Syntax: SyntaxCurrent, Syntax: SyntaxCurrent,
@ -764,6 +774,13 @@ func TestPolicySourceParse(t *testing.T) {
RulesJSON: `{ "mesh": "" }`, RulesJSON: `{ "mesh": "" }`,
Expected: &Policy{PolicyRules: PolicyRules{Mesh: ""}}, Expected: &Policy{PolicyRules: PolicyRules{Mesh: ""}},
}, },
{
Name: "Peering Empty",
Syntax: SyntaxCurrent,
Rules: `peering = ""`,
RulesJSON: `{ "peering": "" }`,
Expected: &Policy{PolicyRules: PolicyRules{Peering: ""}},
},
} }
for _, tc := range cases { for _, tc := range cases {
@ -1453,66 +1470,90 @@ func TestMergePolicies(t *testing.T) {
{ {
name: "Write Precedence", name: "Write Precedence",
input: []*Policy{ input: []*Policy{
{PolicyRules: PolicyRules{ {
PolicyRules: PolicyRules{
ACL: PolicyRead, ACL: PolicyRead,
Keyring: PolicyRead, Keyring: PolicyRead,
Operator: PolicyRead, Operator: PolicyRead,
Mesh: PolicyRead, Mesh: PolicyRead,
}}, Peering: PolicyRead,
{PolicyRules: PolicyRules{
ACL: PolicyWrite,
Keyring: PolicyWrite,
Operator: PolicyWrite,
Mesh: PolicyWrite,
}},
}, },
expected: &Policy{PolicyRules: PolicyRules{ },
{
PolicyRules: PolicyRules{
ACL: PolicyWrite, ACL: PolicyWrite,
Keyring: PolicyWrite, Keyring: PolicyWrite,
Operator: PolicyWrite, Operator: PolicyWrite,
Mesh: PolicyWrite, Mesh: PolicyWrite,
}}, Peering: PolicyWrite,
},
},
},
expected: &Policy{
PolicyRules: PolicyRules{
ACL: PolicyWrite,
Keyring: PolicyWrite,
Operator: PolicyWrite,
Mesh: PolicyWrite,
Peering: PolicyWrite,
},
},
}, },
{ {
name: "Deny Precedence", name: "Deny Precedence",
input: []*Policy{ input: []*Policy{
{PolicyRules: PolicyRules{ {
PolicyRules: PolicyRules{
ACL: PolicyWrite, ACL: PolicyWrite,
Keyring: PolicyWrite, Keyring: PolicyWrite,
Operator: PolicyWrite, Operator: PolicyWrite,
Mesh: PolicyWrite, Mesh: PolicyWrite,
}}, Peering: PolicyWrite,
{PolicyRules: PolicyRules{
ACL: PolicyDeny,
Keyring: PolicyDeny,
Operator: PolicyDeny,
Mesh: PolicyDeny,
}},
}, },
expected: &Policy{PolicyRules: PolicyRules{ },
{
PolicyRules: PolicyRules{
ACL: PolicyDeny, ACL: PolicyDeny,
Keyring: PolicyDeny, Keyring: PolicyDeny,
Operator: PolicyDeny, Operator: PolicyDeny,
Mesh: PolicyDeny, Mesh: PolicyDeny,
}}, Peering: PolicyDeny,
},
},
},
expected: &Policy{
PolicyRules: PolicyRules{
ACL: PolicyDeny,
Keyring: PolicyDeny,
Operator: PolicyDeny,
Mesh: PolicyDeny,
Peering: PolicyDeny,
},
},
}, },
{ {
name: "Read Precedence", name: "Read Precedence",
input: []*Policy{ input: []*Policy{
{PolicyRules: PolicyRules{ {
PolicyRules: PolicyRules{
ACL: PolicyRead, ACL: PolicyRead,
Keyring: PolicyRead, Keyring: PolicyRead,
Operator: PolicyRead, Operator: PolicyRead,
Mesh: PolicyRead, Mesh: PolicyRead,
}}, Peering: PolicyRead,
},
},
{}, {},
}, },
expected: &Policy{PolicyRules: PolicyRules{ expected: &Policy{
PolicyRules: PolicyRules{
ACL: PolicyRead, ACL: PolicyRead,
Keyring: PolicyRead, Keyring: PolicyRead,
Operator: PolicyRead, Operator: PolicyRead,
Mesh: PolicyRead, Mesh: PolicyRead,
}}, Peering: PolicyRead,
},
},
}, },
} }
@ -1524,6 +1565,7 @@ func TestMergePolicies(t *testing.T) {
require.Equal(t, exp.Keyring, act.Keyring) require.Equal(t, exp.Keyring, act.Keyring)
require.Equal(t, exp.Operator, act.Operator) require.Equal(t, exp.Operator, act.Operator)
require.Equal(t, exp.Mesh, act.Mesh) require.Equal(t, exp.Mesh, act.Mesh)
require.Equal(t, exp.Peering, act.Peering)
require.ElementsMatch(t, exp.Agents, act.Agents) require.ElementsMatch(t, exp.Agents, act.Agents)
require.ElementsMatch(t, exp.AgentPrefixes, act.AgentPrefixes) require.ElementsMatch(t, exp.AgentPrefixes, act.AgentPrefixes)
require.ElementsMatch(t, exp.Events, act.Events) require.ElementsMatch(t, exp.Events, act.Events)
@ -1597,6 +1639,9 @@ operator = "write"
# comment # comment
mesh = "write" mesh = "write"
# comment
peering = "write"
` `
expected := ` expected := `
@ -1652,6 +1697,9 @@ operator = "write"
# comment # comment
mesh = "write" mesh = "write"
# comment
peering = "write"
` `
output, err := TranslateLegacyRules([]byte(input)) output, err := TranslateLegacyRules([]byte(input))

View File

@ -170,6 +170,20 @@ func (s *staticAuthorizer) MeshWrite(*AuthorizerContext) EnforcementDecision {
return Deny return Deny
} }
func (s *staticAuthorizer) PeeringRead(*AuthorizerContext) EnforcementDecision {
if s.defaultAllow {
return Allow
}
return Deny
}
func (s *staticAuthorizer) PeeringWrite(*AuthorizerContext) EnforcementDecision {
if s.defaultAllow {
return Allow
}
return Deny
}
func (s *staticAuthorizer) OperatorRead(*AuthorizerContext) EnforcementDecision { func (s *staticAuthorizer) OperatorRead(*AuthorizerContext) EnforcementDecision {
if s.defaultAllow { if s.defaultAllow {
return Allow return Allow

View File

@ -2044,6 +2044,14 @@ func TestACL_Authorize(t *testing.T) {
Resource: "mesh", Resource: "mesh",
Access: "write", Access: "write",
}, },
{
Resource: "peering",
Access: "read",
},
{
Resource: "peering",
Access: "write",
},
{ {
Resource: "query", Resource: "query",
Segment: "foo", Segment: "foo",
@ -2186,6 +2194,14 @@ func TestACL_Authorize(t *testing.T) {
Resource: "mesh", Resource: "mesh",
Access: "write", Access: "write",
}, },
{
Resource: "peering",
Access: "read",
},
{
Resource: "peering",
Access: "write",
},
{ {
Resource: "query", Resource: "query",
Segment: "foo", Segment: "foo",
@ -2238,6 +2254,8 @@ func TestACL_Authorize(t *testing.T) {
true, // operator:write true, // operator:write
true, // mesh:read true, // mesh:read
true, // mesh:write true, // mesh:write
true, // peering:read
true, // peering:write
false, // query:read false, // query:read
false, // query:write false, // query:write
true, // service:read true, // service:read

View File

@ -274,10 +274,10 @@ func TestACL_vetServiceRegister(t *testing.T) {
// Try to register over a service without write privs to the existing // Try to register over a service without write privs to the existing
// service. // service.
a.State.AddService(&structs.NodeService{ a.State.AddServiceWithChecks(&structs.NodeService{
ID: "my-service", ID: "my-service",
Service: "other", Service: "other",
}, "") }, nil, "")
err = a.vetServiceRegister(serviceRWSecret, &structs.NodeService{ err = a.vetServiceRegister(serviceRWSecret, &structs.NodeService{
ID: "my-service", ID: "my-service",
Service: "service", Service: "service",
@ -304,10 +304,10 @@ func TestACL_vetServiceUpdateWithAuthorizer(t *testing.T) {
require.Contains(t, err.Error(), "Unknown service") require.Contains(t, err.Error(), "Unknown service")
// Update with write privs. // Update with write privs.
a.State.AddService(&structs.NodeService{ a.State.AddServiceWithChecks(&structs.NodeService{
ID: "my-service", ID: "my-service",
Service: "service", Service: "service",
}, "") }, nil, "")
err = vetServiceUpdate(serviceRWSecret, structs.NewServiceID("my-service", nil)) err = vetServiceUpdate(serviceRWSecret, structs.NewServiceID("my-service", nil))
require.NoError(t, err) require.NoError(t, err)
@ -361,10 +361,10 @@ func TestACL_vetCheckRegisterWithAuthorizer(t *testing.T) {
// Try to register over a service check without write privs to the // Try to register over a service check without write privs to the
// existing service. // existing service.
a.State.AddService(&structs.NodeService{ a.State.AddServiceWithChecks(&structs.NodeService{
ID: "my-service", ID: "my-service",
Service: "service", Service: "service",
}, "") }, nil, "")
a.State.AddCheck(&structs.HealthCheck{ a.State.AddCheck(&structs.HealthCheck{
CheckID: types.CheckID("my-check"), CheckID: types.CheckID("my-check"),
ServiceID: "my-service", ServiceID: "my-service",
@ -410,10 +410,10 @@ func TestACL_vetCheckUpdateWithAuthorizer(t *testing.T) {
require.Contains(t, err.Error(), "Unknown check") require.Contains(t, err.Error(), "Unknown check")
// Update service check with write privs. // Update service check with write privs.
a.State.AddService(&structs.NodeService{ a.State.AddServiceWithChecks(&structs.NodeService{
ID: "my-service", ID: "my-service",
Service: "service", Service: "service",
}, "") }, nil, "")
a.State.AddCheck(&structs.HealthCheck{ a.State.AddCheck(&structs.HealthCheck{
CheckID: types.CheckID("my-service-check"), CheckID: types.CheckID("my-service-check"),
ServiceID: "my-service", ServiceID: "my-service",

View File

@ -761,12 +761,7 @@ func (a *Agent) Failed() <-chan struct{} {
} }
func (a *Agent) buildExternalGRPCServer() { func (a *Agent) buildExternalGRPCServer() {
// TLS is only enabled on the gRPC server if there's an HTTPS port configured. a.externalGRPCServer = external.NewServer(a.logger.Named("grpc.external"), a.tlsConfigurator)
var tls *tlsutil.Configurator
if a.config.HTTPSPort > 0 {
tls = a.tlsConfigurator
}
a.externalGRPCServer = external.NewServer(a.logger.Named("grpc.external"), tls)
} }
func (a *Agent) listenAndServeGRPC() error { func (a *Agent) listenAndServeGRPC() error {
@ -1346,6 +1341,8 @@ func newConsulConfig(runtimeCfg *config.RuntimeConfig, logger hclog.Logger) (*co
// function does not drift. // function does not drift.
cfg.SerfLANConfig = consul.CloneSerfLANConfig(cfg.SerfLANConfig) cfg.SerfLANConfig = consul.CloneSerfLANConfig(cfg.SerfLANConfig)
cfg.PeeringEnabled = runtimeCfg.PeeringEnabled
enterpriseConsulConfig(cfg, runtimeCfg) enterpriseConsulConfig(cfg, runtimeCfg)
return cfg, nil return cfg, nil
} }
@ -4075,6 +4072,7 @@ func (a *Agent) registerCache() {
a.cache.RegisterType(cachetype.IntentionMatchName, &cachetype.IntentionMatch{RPC: a}) a.cache.RegisterType(cachetype.IntentionMatchName, &cachetype.IntentionMatch{RPC: a})
a.cache.RegisterType(cachetype.IntentionUpstreamsName, &cachetype.IntentionUpstreams{RPC: a}) a.cache.RegisterType(cachetype.IntentionUpstreamsName, &cachetype.IntentionUpstreams{RPC: a})
a.cache.RegisterType(cachetype.IntentionUpstreamsDestinationName, &cachetype.IntentionUpstreamsDestination{RPC: a})
a.cache.RegisterType(cachetype.CatalogServicesName, &cachetype.CatalogServices{RPC: a}) a.cache.RegisterType(cachetype.CatalogServicesName, &cachetype.CatalogServices{RPC: a})
@ -4097,6 +4095,7 @@ func (a *Agent) registerCache() {
a.cache.RegisterType(cachetype.CompiledDiscoveryChainName, &cachetype.CompiledDiscoveryChain{RPC: a}) a.cache.RegisterType(cachetype.CompiledDiscoveryChainName, &cachetype.CompiledDiscoveryChain{RPC: a})
a.cache.RegisterType(cachetype.GatewayServicesName, &cachetype.GatewayServices{RPC: a}) a.cache.RegisterType(cachetype.GatewayServicesName, &cachetype.GatewayServices{RPC: a})
a.cache.RegisterType(cachetype.ServiceGatewaysName, &cachetype.ServiceGateways{RPC: a})
a.cache.RegisterType(cachetype.ConfigEntryListName, &cachetype.ConfigEntryList{RPC: a}) a.cache.RegisterType(cachetype.ConfigEntryListName, &cachetype.ConfigEntryList{RPC: a})
@ -4220,10 +4219,12 @@ func (a *Agent) proxyDataSources() proxycfg.DataSources {
Datacenters: proxycfgglue.CacheDatacenters(a.cache), Datacenters: proxycfgglue.CacheDatacenters(a.cache),
FederationStateListMeshGateways: proxycfgglue.CacheFederationStateListMeshGateways(a.cache), FederationStateListMeshGateways: proxycfgglue.CacheFederationStateListMeshGateways(a.cache),
GatewayServices: proxycfgglue.CacheGatewayServices(a.cache), GatewayServices: proxycfgglue.CacheGatewayServices(a.cache),
Health: proxycfgglue.Health(a.rpcClientHealth), ServiceGateways: proxycfgglue.CacheServiceGateways(a.cache),
Health: proxycfgglue.ClientHealth(a.rpcClientHealth),
HTTPChecks: proxycfgglue.CacheHTTPChecks(a.cache), HTTPChecks: proxycfgglue.CacheHTTPChecks(a.cache),
Intentions: proxycfgglue.CacheIntentions(a.cache), Intentions: proxycfgglue.CacheIntentions(a.cache),
IntentionUpstreams: proxycfgglue.CacheIntentionUpstreams(a.cache), IntentionUpstreams: proxycfgglue.CacheIntentionUpstreams(a.cache),
IntentionUpstreamsDestination: proxycfgglue.CacheIntentionUpstreamsDestination(a.cache),
InternalServiceDump: proxycfgglue.CacheInternalServiceDump(a.cache), InternalServiceDump: proxycfgglue.CacheInternalServiceDump(a.cache),
LeafCertificate: proxycfgglue.CacheLeafCertificate(a.cache), LeafCertificate: proxycfgglue.CacheLeafCertificate(a.cache),
PeeredUpstreams: proxycfgglue.CachePeeredUpstreams(a.cache), PeeredUpstreams: proxycfgglue.CachePeeredUpstreams(a.cache),
@ -4237,6 +4238,7 @@ func (a *Agent) proxyDataSources() proxycfg.DataSources {
if server, ok := a.delegate.(*consul.Server); ok { if server, ok := a.delegate.(*consul.Server); ok {
deps := proxycfgglue.ServerDataSourceDeps{ deps := proxycfgglue.ServerDataSourceDeps{
Datacenter: a.config.Datacenter,
EventPublisher: a.baseDeps.EventPublisher, EventPublisher: a.baseDeps.EventPublisher,
ViewStore: a.baseDeps.ViewStore, ViewStore: a.baseDeps.ViewStore,
Logger: a.logger.Named("proxycfg.server-data-sources"), Logger: a.logger.Named("proxycfg.server-data-sources"),
@ -4245,8 +4247,17 @@ func (a *Agent) proxyDataSources() proxycfg.DataSources {
} }
sources.ConfigEntry = proxycfgglue.ServerConfigEntry(deps) sources.ConfigEntry = proxycfgglue.ServerConfigEntry(deps)
sources.ConfigEntryList = proxycfgglue.ServerConfigEntryList(deps) sources.ConfigEntryList = proxycfgglue.ServerConfigEntryList(deps)
sources.CompiledDiscoveryChain = proxycfgglue.ServerCompiledDiscoveryChain(deps, proxycfgglue.CacheCompiledDiscoveryChain(a.cache))
sources.ExportedPeeredServices = proxycfgglue.ServerExportedPeeredServices(deps)
sources.FederationStateListMeshGateways = proxycfgglue.ServerFederationStateListMeshGateways(deps)
sources.GatewayServices = proxycfgglue.ServerGatewayServices(deps)
sources.Health = proxycfgglue.ServerHealth(deps, proxycfgglue.ClientHealth(a.rpcClientHealth))
sources.Intentions = proxycfgglue.ServerIntentions(deps) sources.Intentions = proxycfgglue.ServerIntentions(deps)
sources.IntentionUpstreams = proxycfgglue.ServerIntentionUpstreams(deps) sources.IntentionUpstreams = proxycfgglue.ServerIntentionUpstreams(deps)
sources.PeeredUpstreams = proxycfgglue.ServerPeeredUpstreams(deps)
sources.ServiceList = proxycfgglue.ServerServiceList(deps, proxycfgglue.CacheServiceList(a.cache))
sources.TrustBundle = proxycfgglue.ServerTrustBundle(deps)
sources.TrustBundleList = proxycfgglue.ServerTrustBundleList(deps)
} }
a.fillEnterpriseProxyDataSources(&sources) a.fillEnterpriseProxyDataSources(&sources)

View File

@ -93,7 +93,7 @@ func TestAgent_Services(t *testing.T) {
}, },
Port: 5000, Port: 5000,
} }
require.NoError(t, a.State.AddService(srv1, "")) require.NoError(t, a.State.AddServiceWithChecks(srv1, nil, ""))
req, _ := http.NewRequest("GET", "/v1/agent/services", nil) req, _ := http.NewRequest("GET", "/v1/agent/services", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
@ -128,7 +128,7 @@ func TestAgent_ServicesFiltered(t *testing.T) {
}, },
Port: 5000, Port: 5000,
} }
require.NoError(t, a.State.AddService(srv1, "")) require.NoError(t, a.State.AddServiceWithChecks(srv1, nil, ""))
// Add another service // Add another service
srv2 := &structs.NodeService{ srv2 := &structs.NodeService{
@ -140,7 +140,7 @@ func TestAgent_ServicesFiltered(t *testing.T) {
}, },
Port: 1234, Port: 1234,
} }
require.NoError(t, a.State.AddService(srv2, "")) require.NoError(t, a.State.AddServiceWithChecks(srv2, nil, ""))
req, _ := http.NewRequest("GET", "/v1/agent/services?filter="+url.QueryEscape("foo in Meta"), nil) req, _ := http.NewRequest("GET", "/v1/agent/services?filter="+url.QueryEscape("foo in Meta"), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
@ -188,7 +188,7 @@ func TestAgent_Services_ExternalConnectProxy(t *testing.T) {
Upstreams: structs.TestUpstreams(t), Upstreams: structs.TestUpstreams(t),
}, },
} }
a.State.AddService(srv1, "") a.State.AddServiceWithChecks(srv1, nil, "")
req, _ := http.NewRequest("GET", "/v1/agent/services", nil) req, _ := http.NewRequest("GET", "/v1/agent/services", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
@ -232,7 +232,7 @@ func TestAgent_Services_Sidecar(t *testing.T) {
}, },
}, },
} }
a.State.AddService(srv1, "") a.State.AddServiceWithChecks(srv1, nil, "")
req, _ := http.NewRequest("GET", "/v1/agent/services", nil) req, _ := http.NewRequest("GET", "/v1/agent/services", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
@ -281,7 +281,7 @@ func TestAgent_Services_MeshGateway(t *testing.T) {
}, },
}, },
} }
a.State.AddService(srv1, "") a.State.AddServiceWithChecks(srv1, nil, "")
req, _ := http.NewRequest("GET", "/v1/agent/services", nil) req, _ := http.NewRequest("GET", "/v1/agent/services", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
@ -325,7 +325,7 @@ func TestAgent_Services_TerminatingGateway(t *testing.T) {
}, },
}, },
} }
require.NoError(t, a.State.AddService(srv1, "")) require.NoError(t, a.State.AddServiceWithChecks(srv1, nil, ""))
req, _ := http.NewRequest("GET", "/v1/agent/services", nil) req, _ := http.NewRequest("GET", "/v1/agent/services", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
@ -370,7 +370,7 @@ func TestAgent_Services_ACLFilter(t *testing.T) {
}, },
} }
for _, s := range services { for _, s := range services {
a.State.AddService(s, "") a.State.AddServiceWithChecks(s, nil, "")
} }
t.Run("no token", func(t *testing.T) { t.Run("no token", func(t *testing.T) {
@ -7994,7 +7994,7 @@ func TestAgent_Services_ExposeConfig(t *testing.T) {
}, },
}, },
} }
a.State.AddService(srv1, "") a.State.AddServiceWithChecks(srv1, nil, "")
req, _ := http.NewRequest("GET", "/v1/agent/services", nil) req, _ := http.NewRequest("GET", "/v1/agent/services", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()

View File

@ -0,0 +1,52 @@
package cachetype
import (
"fmt"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
)
// Recommended name for registration.
const ServiceGatewaysName = "service-gateways"
// GatewayUpstreams supports fetching upstreams for a given gateway name.
type ServiceGateways struct {
RegisterOptionsBlockingRefresh
RPC RPC
}
func (g *ServiceGateways) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
var result cache.FetchResult
// The request should be a ServiceSpecificRequest.
reqReal, ok := req.(*structs.ServiceSpecificRequest)
if !ok {
return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req)
}
// Lightweight copy this object so that manipulating QueryOptions doesn't race.
dup := *reqReal
reqReal = &dup
// Set the minimum query index to our current index so we block
reqReal.QueryOptions.MinQueryIndex = opts.MinIndex
reqReal.QueryOptions.MaxQueryTime = opts.Timeout
// Always allow stale - there's no point in hitting leader if the request is
// going to be served from cache and end up arbitrarily stale anyway. This
// allows cached service-discover to automatically read scale across all
// servers too.
reqReal.AllowStale = true
// Fetch
var reply structs.IndexedCheckServiceNodes
if err := g.RPC.RPC("Internal.ServiceGateways", reqReal, &reply); err != nil {
return result, err
}
result.Value = &reply
result.Index = reply.QueryMeta.Index
return result, nil
}

View File

@ -0,0 +1,57 @@
package cachetype
import (
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestServiceGateways(t *testing.T) {
rpc := TestRPC(t)
typ := &ServiceGateways{RPC: rpc}
// Expect the proper RPC call. This also sets the expected value
// since that is return-by-pointer in the arguments.
var resp *structs.IndexedCheckServiceNodes
rpc.On("RPC", "Internal.ServiceGateways", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.ServiceSpecificRequest)
require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.True(t, req.AllowStale)
require.Equal(t, "foo", req.ServiceName)
nodes := []structs.CheckServiceNode{
{
Service: &structs.NodeService{
Tags: req.ServiceTags,
},
},
}
reply := args.Get(2).(*structs.IndexedCheckServiceNodes)
reply.Nodes = nodes
reply.QueryMeta.Index = 48
resp = reply
})
// Fetch
resultA, err := typ.Fetch(cache.FetchOptions{
MinIndex: 24,
Timeout: 1 * time.Second,
}, &structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "foo",
})
require.NoError(t, err)
require.Equal(t, cache.FetchResult{
Value: resp,
Index: 48,
}, resultA)
rpc.AssertExpectations(t)
}

View File

@ -3,16 +3,53 @@ package cachetype
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"time"
"github.com/mitchellh/hashstructure"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
external "github.com/hashicorp/consul/agent/grpc-external"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering" "github.com/hashicorp/consul/proto/pbpeering"
) )
// Recommended name for registration. // Recommended name for registration.
const TrustBundleReadName = "peer-trust-bundle" const TrustBundleReadName = "peer-trust-bundle"
type TrustBundleReadRequest struct {
Request *pbpeering.TrustBundleReadRequest
structs.QueryOptions
}
func (r *TrustBundleReadRequest) CacheInfo() cache.RequestInfo {
info := cache.RequestInfo{
Token: r.Token,
Datacenter: "",
MinIndex: 0,
Timeout: 0,
MustRevalidate: false,
// OPTIMIZE(peering): Cache.notifyPollingQuery polls at this interval. We need to revisit how that polling works.
// Using an exponential backoff when the result hasn't changed may be preferable.
MaxAge: 1 * time.Second,
}
v, err := hashstructure.Hash([]interface{}{
r.Request.Partition,
r.Request.Name,
}, nil)
if err == nil {
// If there is an error, we don't set the key. A blank key forces
// no cache for this request so the request is forwarded directly
// to the server.
info.Key = strconv.FormatUint(v, 10)
}
return info
}
// TrustBundle supports fetching discovering service instances via prepared // TrustBundle supports fetching discovering service instances via prepared
// queries. // queries.
type TrustBundle struct { type TrustBundle struct {
@ -33,14 +70,20 @@ func (t *TrustBundle) Fetch(_ cache.FetchOptions, req cache.Request) (cache.Fetc
// The request should be a TrustBundleReadRequest. // The request should be a TrustBundleReadRequest.
// We do not need to make a copy of this request type like in other cache types // We do not need to make a copy of this request type like in other cache types
// because the RequestInfo is synthetic. // because the RequestInfo is synthetic.
reqReal, ok := req.(*pbpeering.TrustBundleReadRequest) reqReal, ok := req.(*TrustBundleReadRequest)
if !ok { if !ok {
return result, fmt.Errorf( return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req) "Internal cache failure: request wrong type: %T", req)
} }
// Always allow stale - there's no point in hitting leader if the request is
// going to be served from cache and end up arbitrarily stale anyway. This
// allows cached service-discover to automatically read scale across all
// servers too.
reqReal.QueryOptions.SetAllowStale(true)
// Fetch // Fetch
reply, err := t.Client.TrustBundleRead(context.Background(), reqReal) reply, err := t.Client.TrustBundleRead(external.ContextWithToken(context.Background(), reqReal.Token), reqReal.Request)
if err != nil { if err != nil {
return result, err return result, err
} }

View File

@ -33,8 +33,10 @@ func TestTrustBundle(t *testing.T) {
Return(resp, nil) Return(resp, nil)
// Fetch and assert against the result. // Fetch and assert against the result.
result, err := typ.Fetch(cache.FetchOptions{}, &pbpeering.TrustBundleReadRequest{ result, err := typ.Fetch(cache.FetchOptions{}, &TrustBundleReadRequest{
Request: &pbpeering.TrustBundleReadRequest{
Name: "foo", Name: "foo",
},
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, cache.FetchResult{ require.Equal(t, cache.FetchResult{
@ -82,7 +84,9 @@ func TestTrustBundle_MultipleUpdates(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel) t.Cleanup(cancel)
err := c.Notify(ctx, TrustBundleReadName, &pbpeering.TrustBundleReadRequest{Name: "foo"}, "updates", ch) err := c.Notify(ctx, TrustBundleReadName, &TrustBundleReadRequest{
Request: &pbpeering.TrustBundleReadRequest{Name: "foo"},
}, "updates", ch)
require.NoError(t, err) require.NoError(t, err)
i := uint64(1) i := uint64(1)

View File

@ -3,16 +3,55 @@ package cachetype
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"time"
"github.com/mitchellh/hashstructure"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
external "github.com/hashicorp/consul/agent/grpc-external"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering" "github.com/hashicorp/consul/proto/pbpeering"
) )
// Recommended name for registration. // Recommended name for registration.
const TrustBundleListName = "trust-bundles" const TrustBundleListName = "trust-bundles"
type TrustBundleListRequest struct {
Request *pbpeering.TrustBundleListByServiceRequest
structs.QueryOptions
}
func (r *TrustBundleListRequest) CacheInfo() cache.RequestInfo {
info := cache.RequestInfo{
Token: r.Token,
Datacenter: "",
MinIndex: 0,
Timeout: 0,
MustRevalidate: false,
// OPTIMIZE(peering): Cache.notifyPollingQuery polls at this interval. We need to revisit how that polling works.
// Using an exponential backoff when the result hasn't changed may be preferable.
MaxAge: 1 * time.Second,
}
v, err := hashstructure.Hash([]interface{}{
r.Request.Partition,
r.Request.Namespace,
r.Request.ServiceName,
r.Request.Kind,
}, nil)
if err == nil {
// If there is an error, we don't set the key. A blank key forces
// no cache for this request so the request is forwarded directly
// to the server.
info.Key = strconv.FormatUint(v, 10)
}
return info
}
// TrustBundles supports fetching discovering service instances via prepared // TrustBundles supports fetching discovering service instances via prepared
// queries. // queries.
type TrustBundles struct { type TrustBundles struct {
@ -30,17 +69,23 @@ type TrustBundleLister interface {
func (t *TrustBundles) Fetch(_ cache.FetchOptions, req cache.Request) (cache.FetchResult, error) { func (t *TrustBundles) Fetch(_ cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
var result cache.FetchResult var result cache.FetchResult
// The request should be a TrustBundleListByServiceRequest. // The request should be a TrustBundleListRequest.
// We do not need to make a copy of this request type like in other cache types // We do not need to make a copy of this request type like in other cache types
// because the RequestInfo is synthetic. // because the RequestInfo is synthetic.
reqReal, ok := req.(*pbpeering.TrustBundleListByServiceRequest) reqReal, ok := req.(*TrustBundleListRequest)
if !ok { if !ok {
return result, fmt.Errorf( return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req) "Internal cache failure: request wrong type: %T", req)
} }
// Always allow stale - there's no point in hitting leader if the request is
// going to be served from cache and end up arbitrarily stale anyway. This
// allows cached service-discover to automatically read scale across all
// servers too.
reqReal.QueryOptions.SetAllowStale(true)
// Fetch // Fetch
reply, err := t.Client.TrustBundleListByService(context.Background(), reqReal) reply, err := t.Client.TrustBundleListByService(external.ContextWithToken(context.Background(), reqReal.Token), reqReal.Request)
if err != nil { if err != nil {
return result, err return result, err
} }

View File

@ -36,8 +36,10 @@ func TestTrustBundles(t *testing.T) {
Return(resp, nil) Return(resp, nil)
// Fetch and assert against the result. // Fetch and assert against the result.
result, err := typ.Fetch(cache.FetchOptions{}, &pbpeering.TrustBundleListByServiceRequest{ result, err := typ.Fetch(cache.FetchOptions{}, &TrustBundleListRequest{
Request: &pbpeering.TrustBundleListByServiceRequest{
ServiceName: "foo", ServiceName: "foo",
},
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, cache.FetchResult{ require.Equal(t, cache.FetchResult{
@ -85,7 +87,9 @@ func TestTrustBundles_MultipleUpdates(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel) t.Cleanup(cancel)
err := c.Notify(ctx, TrustBundleListName, &pbpeering.TrustBundleListByServiceRequest{ServiceName: "foo"}, "updates", ch) err := c.Notify(ctx, TrustBundleListName, &TrustBundleListRequest{
Request: &pbpeering.TrustBundleListByServiceRequest{ServiceName: "foo"},
}, "updates", ch)
require.NoError(t, err) require.NoError(t, err)
i := uint64(1) i := uint64(1)

View File

@ -1014,6 +1014,7 @@ func (b *builder) build() (rt RuntimeConfig, err error) {
NodeMeta: c.NodeMeta, NodeMeta: c.NodeMeta,
NodeName: b.nodeName(c.NodeName), NodeName: b.nodeName(c.NodeName),
ReadReplica: boolVal(c.ReadReplica), ReadReplica: boolVal(c.ReadReplica),
PeeringEnabled: boolVal(c.Peering.Enabled),
PidFile: stringVal(c.PidFile), PidFile: stringVal(c.PidFile),
PrimaryDatacenter: primaryDatacenter, PrimaryDatacenter: primaryDatacenter,
PrimaryGateways: b.expandAllOptionalAddrs("primary_gateways", c.PrimaryGateways), PrimaryGateways: b.expandAllOptionalAddrs("primary_gateways", c.PrimaryGateways),

View File

@ -197,6 +197,7 @@ type Config struct {
NodeID *string `mapstructure:"node_id"` NodeID *string `mapstructure:"node_id"`
NodeMeta map[string]string `mapstructure:"node_meta"` NodeMeta map[string]string `mapstructure:"node_meta"`
NodeName *string `mapstructure:"node_name"` NodeName *string `mapstructure:"node_name"`
Peering Peering `mapstructure:"peering"`
Performance Performance `mapstructure:"performance"` Performance Performance `mapstructure:"performance"`
PidFile *string `mapstructure:"pid_file"` PidFile *string `mapstructure:"pid_file"`
Ports Ports `mapstructure:"ports"` Ports Ports `mapstructure:"ports"`
@ -887,3 +888,7 @@ type TLS struct {
// config merging logic. // config merging logic.
GRPCModifiedByDeprecatedConfig *struct{} `mapstructure:"-"` GRPCModifiedByDeprecatedConfig *struct{} `mapstructure:"-"`
} }
type Peering struct {
Enabled *bool `mapstructure:"enabled"`
}

View File

@ -104,6 +104,9 @@ func DefaultSource() Source {
kv_max_value_size = ` + strconv.FormatInt(raft.SuggestedMaxDataSize, 10) + ` kv_max_value_size = ` + strconv.FormatInt(raft.SuggestedMaxDataSize, 10) + `
txn_max_req_len = ` + strconv.FormatInt(raft.SuggestedMaxDataSize, 10) + ` txn_max_req_len = ` + strconv.FormatInt(raft.SuggestedMaxDataSize, 10) + `
} }
peering = {
enabled = true
}
performance = { performance = {
leave_drain_time = "5s" leave_drain_time = "5s"
raft_multiplier = ` + strconv.Itoa(int(consul.DefaultRaftMultiplier)) + ` raft_multiplier = ` + strconv.Itoa(int(consul.DefaultRaftMultiplier)) + `

View File

@ -810,6 +810,14 @@ type RuntimeConfig struct {
// flag: -non-voting-server // flag: -non-voting-server
ReadReplica bool ReadReplica bool
// PeeringEnabled enables cluster peering. This setting only applies for servers.
// When disabled, all peering RPC endpoints will return errors,
// peering requests from other clusters will receive errors, and any peerings already stored in this server's
// state will be ignored.
//
// hcl: peering { enabled = (true|false) }
PeeringEnabled bool
// PidFile is the file to store our PID in. // PidFile is the file to store our PID in.
// //
// hcl: pid_file = string // hcl: pid_file = string

View File

@ -5548,6 +5548,16 @@ func TestLoad_IntegrationWithFlags(t *testing.T) {
"tls.grpc was provided but TLS will NOT be enabled on the gRPC listener without an HTTPS listener configured (e.g. via ports.https)", "tls.grpc was provided but TLS will NOT be enabled on the gRPC listener without an HTTPS listener configured (e.g. via ports.https)",
}, },
}) })
run(t, testCase{
desc: "peering.enabled defaults to true",
args: []string{
`-data-dir=` + dataDir,
},
expected: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.PeeringEnabled = true
},
})
} }
func (tc testCase) run(format string, dataDir string) func(t *testing.T) { func (tc testCase) run(format string, dataDir string) func(t *testing.T) {
@ -5955,6 +5965,7 @@ func TestLoad_FullConfig(t *testing.T) {
NodeMeta: map[string]string{"5mgGQMBk": "mJLtVMSG", "A7ynFMJB": "0Nx6RGab"}, NodeMeta: map[string]string{"5mgGQMBk": "mJLtVMSG", "A7ynFMJB": "0Nx6RGab"},
NodeName: "otlLxGaI", NodeName: "otlLxGaI",
ReadReplica: true, ReadReplica: true,
PeeringEnabled: true,
PidFile: "43xN80Km", PidFile: "43xN80Km",
PrimaryGateways: []string{"aej8eeZo", "roh2KahS"}, PrimaryGateways: []string{"aej8eeZo", "roh2KahS"},
PrimaryGatewaysInterval: 18866 * time.Second, PrimaryGatewaysInterval: 18866 * time.Second,

View File

@ -235,6 +235,7 @@
"NodeID": "", "NodeID": "",
"NodeMeta": {}, "NodeMeta": {},
"NodeName": "", "NodeName": "",
"PeeringEnabled": false,
"PidFile": "", "PidFile": "",
"PrimaryDatacenter": "", "PrimaryDatacenter": "",
"PrimaryGateways": [ "PrimaryGateways": [

View File

@ -305,6 +305,9 @@ node_meta {
node_name = "otlLxGaI" node_name = "otlLxGaI"
non_voting_server = true non_voting_server = true
partition = "" partition = ""
peering {
enabled = true
}
performance { performance {
leave_drain_time = "8265s" leave_drain_time = "8265s"
raft_multiplier = 5 raft_multiplier = 5

View File

@ -305,6 +305,9 @@
"node_name": "otlLxGaI", "node_name": "otlLxGaI",
"non_voting_server": true, "non_voting_server": true,
"partition": "", "partition": "",
"peering": {
"enabled": true
},
"performance": { "performance": {
"leave_drain_time": "8265s", "leave_drain_time": "8265s",
"raft_multiplier": 5, "raft_multiplier": 5,

View File

@ -4,9 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
iamauth "github.com/hashicorp/consul-awsauth"
"github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/authmethod"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/iamauth"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
) )

View File

@ -8,10 +8,10 @@ import (
"testing" "testing"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
iamauth "github.com/hashicorp/consul-awsauth"
"github.com/hashicorp/consul-awsauth/iamauthtest"
"github.com/hashicorp/consul/agent/consul/authmethod" "github.com/hashicorp/consul/agent/consul/authmethod"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/iamauth"
"github.com/hashicorp/consul/internal/iamauth/iamauthtest"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )

View File

@ -176,7 +176,7 @@ func servicePreApply(service *structs.NodeService, authz resolver.Result, authzC
// Verify ServiceName provided if ID. // Verify ServiceName provided if ID.
if service.ID != "" && service.Service == "" { if service.ID != "" && service.Service == "" {
return fmt.Errorf("Must provide service name with ID") return fmt.Errorf("Must provide service name (Service.Service) when service ID is provided")
} }
// Check the service address here and in the agent endpoint // Check the service address here and in the agent endpoint

View File

@ -396,6 +396,9 @@ type Config struct {
RaftBoltDBConfig RaftBoltDBConfig RaftBoltDBConfig RaftBoltDBConfig
// PeeringEnabled enables cluster peering.
PeeringEnabled bool
// Embedded Consul Enterprise specific configuration // Embedded Consul Enterprise specific configuration
*EnterpriseConfig *EnterpriseConfig
} }
@ -512,6 +515,8 @@ func DefaultConfig() *Config {
DefaultQueryTime: 300 * time.Second, DefaultQueryTime: 300 * time.Second,
MaxQueryTime: 600 * time.Second, MaxQueryTime: 600 * time.Second,
PeeringEnabled: true,
EnterpriseConfig: DefaultEnterpriseConfig(), EnterpriseConfig: DefaultEnterpriseConfig(),
} }

View File

@ -1141,7 +1141,7 @@ func TestConfigEntry_ResolveServiceConfig_TransparentProxy(t *testing.T) {
Name: "foo", Name: "foo",
Mode: structs.ProxyModeTransparent, Mode: structs.ProxyModeTransparent,
Destination: &structs.DestinationConfig{ Destination: &structs.DestinationConfig{
Address: "hello.world.com", Addresses: []string{"hello.world.com"},
Port: 443, Port: 443,
}, },
}, },
@ -1153,7 +1153,7 @@ func TestConfigEntry_ResolveServiceConfig_TransparentProxy(t *testing.T) {
expect: structs.ServiceConfigResponse{ expect: structs.ServiceConfigResponse{
Mode: structs.ProxyModeTransparent, Mode: structs.ProxyModeTransparent,
Destination: structs.DestinationConfig{ Destination: structs.DestinationConfig{
Address: "hello.world.com", Addresses: []string{"hello.world.com"},
Port: 443, Port: 443,
}, },
}, },

View File

@ -324,4 +324,11 @@ func (c *FSM) registerStreamSnapshotHandlers() {
if err != nil { if err != nil {
panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err)) panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err))
} }
err = c.deps.Publisher.RegisterHandler(state.EventTopicServiceList, func(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) {
return c.State().ServiceListSnapshot(req, buf)
}, true)
if err != nil {
panic(fmt.Errorf("fatal error encountered registering streaming snapshot handlers: %w", err))
}
} }

View File

@ -1213,10 +1213,12 @@ func registerTestRoutingConfigTopologyEntries(t *testing.T, codec rpc.ClientCode
func registerLocalAndRemoteServicesVIPEnabled(t *testing.T, state *state.Store) { func registerLocalAndRemoteServicesVIPEnabled(t *testing.T, state *state.Store) {
t.Helper() t.Helper()
retry.Run(t, func(r *retry.R) {
_, entry, err := state.SystemMetadataGet(nil, structs.SystemMetadataVirtualIPsEnabled) _, entry, err := state.SystemMetadataGet(nil, structs.SystemMetadataVirtualIPsEnabled)
require.NoError(t, err) require.NoError(r, err)
require.NotNil(t, entry) require.NotNil(r, entry)
require.Equal(t, "true", entry.Value) require.Equal(r, "true", entry.Value)
})
// Register a local connect-native service // Register a local connect-native service
require.NoError(t, state.EnsureRegistration(10, &structs.RegisterRequest{ require.NoError(t, state.EnsureRegistration(10, &structs.RegisterRequest{
@ -1462,7 +1464,7 @@ func registerIntentionUpstreamEntries(t *testing.T, codec rpc.ClientCodec, token
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "api.example.com", Name: "api.example.com",
Destination: &structs.DestinationConfig{ Destination: &structs.DestinationConfig{
Address: "api.example.com", Addresses: []string{"api.example.com"},
Port: 443, Port: 443,
}, },
}, },
@ -1474,7 +1476,7 @@ func registerIntentionUpstreamEntries(t *testing.T, codec rpc.ClientCodec, token
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "kafka.store.com", Name: "kafka.store.com",
Destination: &structs.DestinationConfig{ Destination: &structs.DestinationConfig{
Address: "172.168.2.1", Addresses: []string{"172.168.2.1"},
Port: 9003, Port: 9003,
}, },
}, },

View File

@ -453,6 +453,56 @@ func (m *Internal) GatewayServiceDump(args *structs.ServiceSpecificRequest, repl
return err return err
} }
// ServiceGateways returns all the nodes for services associated with a gateway along with their gateway config
func (m *Internal) ServiceGateways(args *structs.ServiceSpecificRequest, reply *structs.IndexedCheckServiceNodes) error {
if done, err := m.srv.ForwardRPC("Internal.ServiceGateways", args, reply); done {
return err
}
// Verify the arguments
if args.ServiceName == "" {
return fmt.Errorf("Must provide gateway name")
}
var authzContext acl.AuthorizerContext
authz, err := m.srv.ResolveTokenAndDefaultMeta(args.Token, &args.EnterpriseMeta, &authzContext)
if err != nil {
return err
}
if err := m.srv.validateEnterpriseRequest(&args.EnterpriseMeta, false); err != nil {
return err
}
// We need read access to the service we're trying to find gateways for, so check that first.
if err := authz.ToAllowAuthorizer().ServiceReadAllowed(args.ServiceName, &authzContext); err != nil {
return err
}
err = m.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
var maxIdx uint64
idx, gateways, err := state.ServiceGateways(ws, args.ServiceName, args.ServiceKind, args.EnterpriseMeta)
if err != nil {
return err
}
if idx > maxIdx {
maxIdx = idx
}
reply.Index, reply.Nodes = maxIdx, gateways
if err := m.srv.filterACL(args.Token, reply); err != nil {
return err
}
return nil
})
return err
}
// GatewayIntentions Match returns the set of intentions that match the given source/destination. // GatewayIntentions Match returns the set of intentions that match the given source/destination.
func (m *Internal) GatewayIntentions(args *structs.IntentionQueryRequest, reply *structs.IndexedIntentions) error { func (m *Internal) GatewayIntentions(args *structs.IntentionQueryRequest, reply *structs.IndexedIntentions) error {
// Forward if necessary // Forward if necessary

View File

@ -2782,6 +2782,10 @@ func TestInternal_PeeredUpstreams(t *testing.T) {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
orig := virtualIPVersionCheckInterval
virtualIPVersionCheckInterval = 50 * time.Millisecond
t.Cleanup(func() { virtualIPVersionCheckInterval = orig })
t.Parallel() t.Parallel()
_, s1 := testServerWithConfig(t) _, s1 := testServerWithConfig(t)
@ -2811,3 +2815,479 @@ func TestInternal_PeeredUpstreams(t *testing.T) {
} }
require.Equal(t, expect, out.Services) require.Equal(t, expect, out.Services)
} }
func TestInternal_ServiceGatewayService_Terminating(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
db := structs.NodeService{
ID: "db2",
Service: "db",
}
redis := structs.NodeService{
ID: "redis",
Service: "redis",
}
// Register gateway and two service instances that will be associated with it
{
arg := structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "10.1.2.2",
Service: &structs.NodeService{
ID: "terminating-gateway-01",
Service: "terminating-gateway",
Kind: structs.ServiceKindTerminatingGateway,
Port: 443,
Address: "198.18.1.3",
},
Check: &structs.HealthCheck{
Name: "terminating connect",
Status: api.HealthPassing,
ServiceID: "terminating-gateway-01",
},
}
var out struct{}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
arg = structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "127.0.0.2",
Service: &structs.NodeService{
ID: "db",
Service: "db",
},
Check: &structs.HealthCheck{
Name: "db-warning",
Status: api.HealthWarning,
ServiceID: "db",
},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
arg = structs.RegisterRequest{
Datacenter: "dc1",
Node: "baz",
Address: "127.0.0.3",
Service: &db,
Check: &structs.HealthCheck{
Name: "db2-passing",
Status: api.HealthPassing,
ServiceID: "db2",
},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
}
// Register terminating-gateway config entry, linking it to db and redis (dne)
{
args := &structs.TerminatingGatewayConfigEntry{
Name: "terminating-gateway",
Kind: structs.TerminatingGateway,
Services: []structs.LinkedService{
{
Name: "db",
},
{
Name: "redis",
CAFile: "/etc/certs/ca.pem",
CertFile: "/etc/certs/cert.pem",
KeyFile: "/etc/certs/key.pem",
},
},
}
req := structs.ConfigEntryRequest{
Op: structs.ConfigEntryUpsert,
Datacenter: "dc1",
Entry: args,
}
var configOutput bool
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &req, &configOutput))
require.True(t, configOutput)
}
var out structs.IndexedCheckServiceNodes
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "db",
ServiceKind: structs.ServiceKindTerminatingGateway,
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Internal.ServiceGateways", &req, &out))
for _, n := range out.Nodes {
n.Node.RaftIndex = structs.RaftIndex{}
n.Service.RaftIndex = structs.RaftIndex{}
for _, m := range n.Checks {
m.RaftIndex = structs.RaftIndex{}
}
}
expect := structs.CheckServiceNodes{
structs.CheckServiceNode{
Node: &structs.Node{
Node: "foo",
RaftIndex: structs.RaftIndex{},
Address: "10.1.2.2",
Datacenter: "dc1",
Partition: acl.DefaultPartitionName,
},
Service: &structs.NodeService{
Kind: structs.ServiceKindTerminatingGateway,
ID: "terminating-gateway-01",
Service: "terminating-gateway",
TaggedAddresses: map[string]structs.ServiceAddress{
"consul-virtual:" + db.CompoundServiceName().String(): {Address: "240.0.0.1"},
"consul-virtual:" + redis.CompoundServiceName().String(): {Address: "240.0.0.2"},
},
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
Tags: []string{},
Meta: map[string]string{},
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
RaftIndex: structs.RaftIndex{},
Address: "198.18.1.3",
},
Checks: structs.HealthChecks{
&structs.HealthCheck{
Name: "terminating connect",
Node: "foo",
CheckID: "terminating connect",
Status: api.HealthPassing,
ServiceID: "terminating-gateway-01",
ServiceName: "terminating-gateway",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
},
},
}
assert.Equal(t, expect, out.Nodes)
}
func TestInternal_ServiceGatewayService_Terminating_ACL(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLInitialManagementToken = "root"
c.ACLResolverSettings.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1", testrpc.WithToken("root"))
// Create the ACL.
token, err := upsertTestTokenWithPolicyRules(codec, "root", "dc1", `
service "db" { policy = "read" }
service "terminating-gateway" { policy = "read" }
node_prefix "" { policy = "read" }`)
require.NoError(t, err)
// Register gateway and two service instances that will be associated with it
{
arg := structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
ID: "terminating-gateway",
Service: "terminating-gateway",
Kind: structs.ServiceKindTerminatingGateway,
Port: 443,
},
Check: &structs.HealthCheck{
Name: "terminating connect",
Status: api.HealthPassing,
ServiceID: "terminating-gateway",
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var out struct{}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
{
arg := structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
ID: "terminating-gateway2",
Service: "terminating-gateway2",
Kind: structs.ServiceKindTerminatingGateway,
Port: 444,
},
Check: &structs.HealthCheck{
Name: "terminating connect",
Status: api.HealthPassing,
ServiceID: "terminating-gateway2",
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var out struct{}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
}
arg = structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "127.0.0.2",
Service: &structs.NodeService{
ID: "db",
Service: "db",
},
Check: &structs.HealthCheck{
Name: "db-warning",
Status: api.HealthWarning,
ServiceID: "db",
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
arg = structs.RegisterRequest{
Datacenter: "dc1",
Node: "baz",
Address: "127.0.0.3",
Service: &structs.NodeService{
ID: "api",
Service: "api",
},
Check: &structs.HealthCheck{
Name: "api-passing",
Status: api.HealthPassing,
ServiceID: "api",
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
}
// Register terminating-gateway config entry, linking it to db and api
{
args := &structs.TerminatingGatewayConfigEntry{
Name: "terminating-gateway",
Kind: structs.TerminatingGateway,
Services: []structs.LinkedService{
{Name: "db"},
{Name: "api"},
},
}
req := structs.ConfigEntryRequest{
Op: structs.ConfigEntryUpsert,
Datacenter: "dc1",
Entry: args,
WriteRequest: structs.WriteRequest{Token: "root"},
}
var out bool
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &req, &out))
require.True(t, out)
}
// Register terminating-gateway config entry, linking it to db and api
{
args := &structs.TerminatingGatewayConfigEntry{
Name: "terminating-gateway2",
Kind: structs.TerminatingGateway,
Services: []structs.LinkedService{
{Name: "db"},
{Name: "api"},
},
}
req := structs.ConfigEntryRequest{
Op: structs.ConfigEntryUpsert,
Datacenter: "dc1",
Entry: args,
WriteRequest: structs.WriteRequest{Token: "root"},
}
var out bool
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &req, &out))
require.True(t, out)
}
var out structs.IndexedCheckServiceNodes
// Not passing a token with service:read on Gateway leads to PermissionDenied
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "db",
ServiceKind: structs.ServiceKindTerminatingGateway,
}
err = msgpackrpc.CallWithCodec(codec, "Internal.ServiceGateways", &req, &out)
require.Error(t, err, acl.ErrPermissionDenied)
// Passing a token without service:read on api leads to it getting filtered out
req = structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "db",
ServiceKind: structs.ServiceKindTerminatingGateway,
QueryOptions: structs.QueryOptions{Token: token.SecretID},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Internal.ServiceGateways", &req, &out))
nodes := out.Nodes
require.Len(t, nodes, 1)
require.Equal(t, "foo", nodes[0].Node.Node)
require.Equal(t, structs.ServiceKindTerminatingGateway, nodes[0].Service.Kind)
require.Equal(t, "terminating-gateway", nodes[0].Service.Service)
require.Equal(t, "terminating-gateway", nodes[0].Service.ID)
require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
}
func TestInternal_ServiceGatewayService_Terminating_Destination(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
google := structs.NodeService{
ID: "google",
Service: "google",
}
// Register service-default with conflicting destination address
{
arg := structs.ConfigEntryRequest{
Op: structs.ConfigEntryUpsert,
Datacenter: "dc1",
Entry: &structs.ServiceConfigEntry{
Name: "google",
Destination: &structs.DestinationConfig{Addresses: []string{"www.google.com"}, Port: 443},
EnterpriseMeta: *acl.DefaultEnterpriseMeta(),
},
}
var configOutput bool
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &arg, &configOutput))
require.True(t, configOutput)
}
// Register terminating-gateway config entry, linking it to google.com
{
arg := structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
Service: &structs.NodeService{
ID: "terminating-gateway",
Service: "terminating-gateway",
Kind: structs.ServiceKindTerminatingGateway,
Port: 443,
},
Check: &structs.HealthCheck{
Name: "terminating connect",
Status: api.HealthPassing,
ServiceID: "terminating-gateway",
},
}
var out struct{}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &out))
}
{
args := &structs.TerminatingGatewayConfigEntry{
Name: "terminating-gateway",
Kind: structs.TerminatingGateway,
Services: []structs.LinkedService{
{
Name: "google",
},
},
}
req := structs.ConfigEntryRequest{
Op: structs.ConfigEntryUpsert,
Datacenter: "dc1",
Entry: args,
}
var configOutput bool
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &req, &configOutput))
require.True(t, configOutput)
}
var out structs.IndexedCheckServiceNodes
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: "google",
ServiceKind: structs.ServiceKindTerminatingGateway,
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Internal.ServiceGateways", &req, &out))
nodes := out.Nodes
for _, n := range nodes {
n.Node.RaftIndex = structs.RaftIndex{}
n.Service.RaftIndex = structs.RaftIndex{}
for _, m := range n.Checks {
m.RaftIndex = structs.RaftIndex{}
}
}
expect := structs.CheckServiceNodes{
structs.CheckServiceNode{
Node: &structs.Node{
Node: "foo",
RaftIndex: structs.RaftIndex{},
Address: "127.0.0.1",
Datacenter: "dc1",
Partition: acl.DefaultPartitionName,
},
Service: &structs.NodeService{
Kind: structs.ServiceKindTerminatingGateway,
ID: "terminating-gateway",
Service: "terminating-gateway",
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
Tags: []string{},
Meta: map[string]string{},
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
TaggedAddresses: map[string]structs.ServiceAddress{
"consul-virtual:" + google.CompoundServiceName().String(): {Address: "240.0.0.1"},
},
RaftIndex: structs.RaftIndex{},
Address: "",
},
Checks: structs.HealthChecks{
&structs.HealthCheck{
Name: "terminating connect",
Node: "foo",
CheckID: "terminating connect",
Status: api.HealthPassing,
ServiceID: "terminating-gateway",
ServiceName: "terminating-gateway",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
},
},
},
}
assert.Len(t, nodes, 1)
assert.Equal(t, expect, nodes)
}

View File

@ -315,7 +315,9 @@ func (s *Server) establishLeadership(ctx context.Context) error {
s.startFederationStateAntiEntropy(ctx) s.startFederationStateAntiEntropy(ctx)
if s.config.PeeringEnabled {
s.startPeeringStreamSync(ctx) s.startPeeringStreamSync(ctx)
}
s.startDeferredDeletion(ctx) s.startDeferredDeletion(ctx)
@ -758,7 +760,9 @@ func (s *Server) stopACLReplication() {
} }
func (s *Server) startDeferredDeletion(ctx context.Context) { func (s *Server) startDeferredDeletion(ctx context.Context) {
if s.config.PeeringEnabled {
s.startPeeringDeferredDeletion(ctx) s.startPeeringDeferredDeletion(ctx)
}
s.startTenancyDeferredDeletion(ctx) s.startTenancyDeferredDeletion(ctx)
} }

View File

@ -36,7 +36,7 @@ func TestConnectCA_ConfigurationSet_ChangeKeyConfig_Primary(t *testing.T) {
keyBits int keyBits int
}{ }{
{connect.DefaultPrivateKeyType, connect.DefaultPrivateKeyBits}, {connect.DefaultPrivateKeyType, connect.DefaultPrivateKeyBits},
{"ec", 256}, // {"ec", 256}, skip since values are same as Defaults
{"ec", 384}, {"ec", 384},
{"rsa", 2048}, {"rsa", 2048},
{"rsa", 4096}, {"rsa", 4096},
@ -55,7 +55,7 @@ func TestConnectCA_ConfigurationSet_ChangeKeyConfig_Primary(t *testing.T) {
providerState := map[string]string{"foo": "dc1-value"} providerState := map[string]string{"foo": "dc1-value"}
// Initialize primary as the primary DC // Initialize primary as the primary DC
dir1, srv := testServerWithConfig(t, func(c *Config) { _, srv := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc1" c.Datacenter = "dc1"
c.PrimaryDatacenter = "dc1" c.PrimaryDatacenter = "dc1"
c.Build = "1.6.0" c.Build = "1.6.0"
@ -63,12 +63,9 @@ func TestConnectCA_ConfigurationSet_ChangeKeyConfig_Primary(t *testing.T) {
c.CAConfig.Config["PrivateKeyBits"] = src.keyBits c.CAConfig.Config["PrivateKeyBits"] = src.keyBits
c.CAConfig.Config["test_state"] = providerState c.CAConfig.Config["test_state"] = providerState
}) })
defer os.RemoveAll(dir1)
defer srv.Shutdown()
codec := rpcClient(t, srv) codec := rpcClient(t, srv)
defer codec.Close()
testrpc.WaitForLeader(t, srv.RPC, "dc1") waitForLeaderEstablishment(t, srv)
testrpc.WaitForActiveCARoot(t, srv.RPC, "dc1", nil) testrpc.WaitForActiveCARoot(t, srv.RPC, "dc1", nil)
var ( var (

View File

@ -6,7 +6,10 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"time"
"github.com/armon/go-metrics"
"github.com/armon/go-metrics/prometheus"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@ -14,6 +17,7 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
@ -25,8 +29,72 @@ import (
"github.com/hashicorp/consul/proto/pbpeerstream" "github.com/hashicorp/consul/proto/pbpeerstream"
) )
var leaderExportedServicesCountKey = []string{"consul", "peering", "exported_services"}
var LeaderPeeringMetrics = []prometheus.GaugeDefinition{
{
Name: leaderExportedServicesCountKey,
Help: "A gauge that tracks how many services are exported for the peering. " +
"The labels are \"peering\" and, for enterprise, \"partition\". " +
"We emit this metric every 9 seconds",
},
}
func (s *Server) startPeeringStreamSync(ctx context.Context) { func (s *Server) startPeeringStreamSync(ctx context.Context) {
s.leaderRoutineManager.Start(ctx, peeringStreamsRoutineName, s.runPeeringSync) s.leaderRoutineManager.Start(ctx, peeringStreamsRoutineName, s.runPeeringSync)
s.leaderRoutineManager.Start(ctx, peeringStreamsMetricsRoutineName, s.runPeeringMetrics)
}
func (s *Server) runPeeringMetrics(ctx context.Context) error {
ticker := time.NewTicker(s.config.MetricsReportingInterval)
defer ticker.Stop()
logger := s.logger.Named(logging.PeeringMetrics)
defaultMetrics := metrics.Default
for {
select {
case <-ctx.Done():
logger.Info("stopping peering metrics")
// "Zero-out" the metric on exit so that when prometheus scrapes this
// metric from a non-leader, it does not get a stale value.
metrics.SetGauge(leaderExportedServicesCountKey, float32(0))
return nil
case <-ticker.C:
if err := s.emitPeeringMetricsOnce(logger, defaultMetrics()); err != nil {
s.logger.Error("error emitting peering stream metrics", "error", err)
}
}
}
}
func (s *Server) emitPeeringMetricsOnce(logger hclog.Logger, metricsImpl *metrics.Metrics) error {
_, peers, err := s.fsm.State().PeeringList(nil, *structs.NodeEnterpriseMetaInPartition(structs.WildcardSpecifier))
if err != nil {
return err
}
for _, peer := range peers {
status, found := s.peerStreamServer.StreamStatus(peer.ID)
if !found {
logger.Trace("did not find status for", "peer_name", peer.Name)
continue
}
esc := status.GetExportedServicesCount()
part := peer.Partition
labels := []metrics.Label{
{Name: "peer_name", Value: peer.Name},
{Name: "peer_id", Value: peer.ID},
}
if part != "" {
labels = append(labels, metrics.Label{Name: "partition", Value: part})
}
metricsImpl.SetGaugeWithLabels(leaderExportedServicesCountKey, float32(esc), labels)
}
return nil
} }
func (s *Server) runPeeringSync(ctx context.Context) error { func (s *Server) runPeeringSync(ctx context.Context) error {
@ -49,6 +117,7 @@ func (s *Server) runPeeringSync(ctx context.Context) error {
func (s *Server) stopPeeringStreamSync() { func (s *Server) stopPeeringStreamSync() {
// will be a no-op when not started // will be a no-op when not started
s.leaderRoutineManager.Stop(peeringStreamsRoutineName) s.leaderRoutineManager.Stop(peeringStreamsRoutineName)
s.leaderRoutineManager.Stop(peeringStreamsMetricsRoutineName)
} }
// syncPeeringsAndBlock is a long-running goroutine that is responsible for watching // syncPeeringsAndBlock is a long-running goroutine that is responsible for watching
@ -225,6 +294,11 @@ func (s *Server) establishStream(ctx context.Context, logger hclog.Logger, peer
retryCtx, cancel := context.WithCancel(ctx) retryCtx, cancel := context.WithCancel(ctx)
cancelFns[peer.ID] = cancel cancelFns[peer.ID] = cancel
streamStatus, err := s.peerStreamTracker.Register(peer.ID)
if err != nil {
return fmt.Errorf("failed to register stream: %v", err)
}
// Establish a stream-specific retry so that retrying stream/conn errors isn't dependent on state store changes. // Establish a stream-specific retry so that retrying stream/conn errors isn't dependent on state store changes.
go retryLoopBackoff(retryCtx, func() error { go retryLoopBackoff(retryCtx, func() error {
// Try a new address on each iteration by advancing the ring buffer on errors. // Try a new address on each iteration by advancing the ring buffer on errors.
@ -238,8 +312,15 @@ func (s *Server) establishStream(ctx context.Context, logger hclog.Logger, peer
logger.Trace("dialing peer", "addr", addr) logger.Trace("dialing peer", "addr", addr)
conn, err := grpc.DialContext(retryCtx, addr, conn, err := grpc.DialContext(retryCtx, addr,
grpc.WithBlock(), // TODO(peering): use a grpc.WithStatsHandler here?)
tlsOption, tlsOption,
// For keep alive parameters there is a larger comment in ClientConnPool.dial about that.
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
// send keepalive pings even if there is no active streams
PermitWithoutStream: true,
}),
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to dial: %w", err) return fmt.Errorf("failed to dial: %w", err)
@ -277,8 +358,7 @@ func (s *Server) establishStream(ctx context.Context, logger hclog.Logger, peer
return err return err
}, func(err error) { }, func(err error) {
// TODO(peering): These errors should be reported in the peer status, otherwise they're only in the logs. streamStatus.TrackSendError(err.Error())
// Lockable status isn't available here though. Could report it via the peering.Service?
logger.Error("error managing peering stream", "peer_id", peer.ID, "error", err) logger.Error("error managing peering stream", "peer_id", peer.ID, "error", err)
}) })

View File

@ -4,9 +4,12 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt"
"io/ioutil"
"testing" "testing"
"time" "time"
"github.com/armon/go-metrics"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -15,20 +18,34 @@ import (
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/pbpeering" "github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/types"
) )
func TestLeader_PeeringSync_Lifecycle_ClientDeletion(t *testing.T) { func TestLeader_PeeringSync_Lifecycle_ClientDeletion(t *testing.T) {
t.Run("without-tls", func(t *testing.T) {
testLeader_PeeringSync_Lifecycle_ClientDeletion(t, false)
})
t.Run("with-tls", func(t *testing.T) {
testLeader_PeeringSync_Lifecycle_ClientDeletion(t, true)
})
}
func testLeader_PeeringSync_Lifecycle_ClientDeletion(t *testing.T, enableTLS bool) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
// TODO(peering): Configure with TLS
_, s1 := testServerWithConfig(t, func(c *Config) { _, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s1.dc1" c.NodeName = "bob"
c.Datacenter = "dc1" c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul" c.TLSConfig.Domain = "consul"
if enableTLS {
c.TLSConfig.GRPC.CAFile = "../../test/hostname/CertAuth.crt"
c.TLSConfig.GRPC.CertFile = "../../test/hostname/Bob.crt"
c.TLSConfig.GRPC.KeyFile = "../../test/hostname/Bob.key"
}
}) })
testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc1")
@ -68,9 +85,14 @@ func TestLeader_PeeringSync_Lifecycle_ClientDeletion(t *testing.T) {
// Bring up s2 and store s1's token so that it attempts to dial. // Bring up s2 and store s1's token so that it attempts to dial.
_, s2 := testServerWithConfig(t, func(c *Config) { _, s2 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s2.dc2" c.NodeName = "betty"
c.Datacenter = "dc2" c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc2" c.PrimaryDatacenter = "dc2"
if enableTLS {
c.TLSConfig.GRPC.CAFile = "../../test/hostname/CertAuth.crt"
c.TLSConfig.GRPC.CertFile = "../../test/hostname/Betty.crt"
c.TLSConfig.GRPC.KeyFile = "../../test/hostname/Betty.key"
}
}) })
testrpc.WaitForLeader(t, s2.RPC, "dc2") testrpc.WaitForLeader(t, s2.RPC, "dc2")
@ -120,15 +142,27 @@ func TestLeader_PeeringSync_Lifecycle_ClientDeletion(t *testing.T) {
} }
func TestLeader_PeeringSync_Lifecycle_ServerDeletion(t *testing.T) { func TestLeader_PeeringSync_Lifecycle_ServerDeletion(t *testing.T) {
t.Run("without-tls", func(t *testing.T) {
testLeader_PeeringSync_Lifecycle_ServerDeletion(t, false)
})
t.Run("with-tls", func(t *testing.T) {
testLeader_PeeringSync_Lifecycle_ServerDeletion(t, true)
})
}
func testLeader_PeeringSync_Lifecycle_ServerDeletion(t *testing.T, enableTLS bool) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
} }
// TODO(peering): Configure with TLS
_, s1 := testServerWithConfig(t, func(c *Config) { _, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s1.dc1" c.NodeName = "bob"
c.Datacenter = "dc1" c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul" c.TLSConfig.Domain = "consul"
if enableTLS {
c.TLSConfig.GRPC.CAFile = "../../test/hostname/CertAuth.crt"
c.TLSConfig.GRPC.CertFile = "../../test/hostname/Bob.crt"
c.TLSConfig.GRPC.KeyFile = "../../test/hostname/Bob.key"
}
}) })
testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc1")
@ -164,9 +198,14 @@ func TestLeader_PeeringSync_Lifecycle_ServerDeletion(t *testing.T) {
// Bring up s2 and store s1's token so that it attempts to dial. // Bring up s2 and store s1's token so that it attempts to dial.
_, s2 := testServerWithConfig(t, func(c *Config) { _, s2 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s2.dc2" c.NodeName = "betty"
c.Datacenter = "dc2" c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc2" c.PrimaryDatacenter = "dc2"
if enableTLS {
c.TLSConfig.GRPC.CAFile = "../../test/hostname/CertAuth.crt"
c.TLSConfig.GRPC.CertFile = "../../test/hostname/Betty.crt"
c.TLSConfig.GRPC.KeyFile = "../../test/hostname/Betty.key"
}
}) })
testrpc.WaitForLeader(t, s2.RPC, "dc2") testrpc.WaitForLeader(t, s2.RPC, "dc2")
@ -215,6 +254,111 @@ func TestLeader_PeeringSync_Lifecycle_ServerDeletion(t *testing.T) {
}) })
} }
func TestLeader_PeeringSync_FailsForTLSError(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Run("server-name-validation", func(t *testing.T) {
testLeader_PeeringSync_failsForTLSError(t, func(p *pbpeering.Peering) {
p.PeerServerName = "wrong.name"
}, `transport: authentication handshake failed: x509: certificate is valid for server.dc1.consul, bob.server.dc1.consul, not wrong.name`)
})
t.Run("bad-ca-roots", func(t *testing.T) {
wrongRoot, err := ioutil.ReadFile("../../test/client_certs/rootca.crt")
require.NoError(t, err)
testLeader_PeeringSync_failsForTLSError(t, func(p *pbpeering.Peering) {
p.PeerCAPems = []string{string(wrongRoot)}
}, `transport: authentication handshake failed: x509: certificate signed by unknown authority`)
})
}
func testLeader_PeeringSync_failsForTLSError(t *testing.T, peerMutateFn func(p *pbpeering.Peering), expectErr string) {
require.NotNil(t, peerMutateFn)
_, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "bob"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
c.TLSConfig.GRPC.CAFile = "../../test/hostname/CertAuth.crt"
c.TLSConfig.GRPC.CertFile = "../../test/hostname/Bob.crt"
c.TLSConfig.GRPC.KeyFile = "../../test/hostname/Bob.key"
})
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create a peering by generating a token
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
conn, err := grpc.DialContext(ctx, s1.config.RPCAddr.String(),
grpc.WithContextDialer(newServerDialer(s1.config.RPCAddr.String())),
grpc.WithInsecure(),
grpc.WithBlock())
require.NoError(t, err)
defer conn.Close()
peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{
PeerName: "my-peer-s2",
}
resp, err := peeringClient.GenerateToken(ctx, &req)
require.NoError(t, err)
tokenJSON, err := base64.StdEncoding.DecodeString(resp.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
// S1 should not have a stream tracked for dc2 because s1 generated a token
// for baz, and therefore needs to wait to be dialed.
time.Sleep(1 * time.Second)
_, found := s1.peerStreamServer.StreamStatus(token.PeerID)
require.False(t, found)
var (
s2PeerID = "cc56f0b8-3885-4e78-8d7b-614a0c45712d"
)
// Bring up s2 and store s1's token so that it attempts to dial.
_, s2 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "betty"
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc2"
c.TLSConfig.GRPC.CAFile = "../../test/hostname/CertAuth.crt"
c.TLSConfig.GRPC.CertFile = "../../test/hostname/Betty.crt"
c.TLSConfig.GRPC.KeyFile = "../../test/hostname/Betty.key"
})
testrpc.WaitForLeader(t, s2.RPC, "dc2")
// Simulate a peering initiation event by writing a peering with data from a peering token.
// Eventually the leader in dc2 should dial and connect to the leader in dc1.
p := &pbpeering.Peering{
ID: s2PeerID,
Name: "my-peer-s1",
PeerID: token.PeerID,
PeerCAPems: token.CA,
PeerServerName: token.ServerName,
PeerServerAddresses: token.ServerAddresses,
}
peerMutateFn(p)
require.True(t, p.ShouldDial())
// We maintain a pointer to the peering on the write so that we can get the ID without needing to re-query the state store.
require.NoError(t, s2.fsm.State().PeeringWrite(1000, p))
retry.Run(t, func(r *retry.R) {
status, found := s2.peerStreamTracker.StreamStatus(p.ID)
require.True(r, found)
require.False(r, status.Connected)
require.Contains(r, status.LastSendErrorMessage, expectErr)
})
}
func TestLeader_Peering_DeferredDeletion(t *testing.T) { func TestLeader_Peering_DeferredDeletion(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -282,6 +426,120 @@ func TestLeader_Peering_DeferredDeletion(t *testing.T) {
}) })
} }
// Test that the dialing peer attempts to reestablish connections when the accepting peer
// shuts down without sending a Terminated message.
//
// To test this, we start the two peer servers (accepting and dialing), set up peering, and then shut down
// the accepting peer. This terminates the connection without sending a Terminated message.
// We then restart the accepting peer (we actually spin up a new server with the same config and port) and then
// assert that the dialing peer reestablishes the connection.
func TestLeader_Peering_DialerReestablishesConnectionOnError(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
// Reserve a gRPC port so we can restart the accepting server with the same port.
ports := freeport.GetN(t, 1)
acceptingServerPort := ports[0]
_, acceptingServer := testServerWithConfig(t, func(c *Config) {
c.NodeName = "acceptingServer.dc1"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
c.GRPCPort = acceptingServerPort
})
testrpc.WaitForLeader(t, acceptingServer.RPC, "dc1")
// Create a peering by generating a token.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
conn, err := grpc.DialContext(ctx, acceptingServer.config.RPCAddr.String(),
grpc.WithContextDialer(newServerDialer(acceptingServer.config.RPCAddr.String())),
grpc.WithInsecure(),
grpc.WithBlock())
require.NoError(t, err)
defer conn.Close()
peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{
PeerName: "my-peer-dialing-server",
}
resp, err := peeringClient.GenerateToken(ctx, &req)
require.NoError(t, err)
tokenJSON, err := base64.StdEncoding.DecodeString(resp.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
var (
dialingServerPeerID = token.PeerID
acceptingServerPeerID = "cc56f0b8-3885-4e78-8d7b-614a0c45712d"
)
// Bring up dialingServer and store acceptingServer's token so that it attempts to dial.
_, dialingServer := testServerWithConfig(t, func(c *Config) {
c.NodeName = "dialing-server.dc2"
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc2"
})
testrpc.WaitForLeader(t, dialingServer.RPC, "dc2")
p := &pbpeering.Peering{
ID: acceptingServerPeerID,
Name: "my-peer-accepting-server",
PeerID: token.PeerID,
PeerCAPems: token.CA,
PeerServerName: token.ServerName,
PeerServerAddresses: token.ServerAddresses,
}
require.True(t, p.ShouldDial())
require.NoError(t, dialingServer.fsm.State().PeeringWrite(1000, p))
// Wait for the stream to be connected.
retry.Run(t, func(r *retry.R) {
status, found := dialingServer.peerStreamServer.StreamStatus(p.ID)
require.True(r, found)
require.True(r, status.Connected)
})
// Wait until the dialing server has sent its roots over. This avoids a race condition where the accepting server
// shuts down, but the dialing server is still sending messages to the stream. When this happens, an error is raised
// which causes the stream to restart.
// In this test, we want to test what happens when the stream is closed when there are _no_ messages being sent.
retry.Run(t, func(r *retry.R) {
_, bundle, err := acceptingServer.fsm.State().PeeringTrustBundleRead(nil, state.Query{Value: "my-peer-dialing-server"})
require.NoError(r, err)
require.NotNil(r, bundle)
})
// Shutdown the accepting server.
require.NoError(t, acceptingServer.Shutdown())
// Have to manually shut down the gRPC server otherwise it stays bound to the port.
acceptingServer.externalGRPCServer.Stop()
// Mimic the server restarting by starting a new server with the same config.
_, acceptingServerRestart := testServerWithConfig(t, func(c *Config) {
c.NodeName = "acceptingServer.dc1"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
c.GRPCPort = acceptingServerPort
})
testrpc.WaitForLeader(t, acceptingServerRestart.RPC, "dc1")
// Re-insert the peering state.
require.NoError(t, acceptingServerRestart.fsm.State().PeeringWrite(2000, &pbpeering.Peering{
ID: dialingServerPeerID,
Name: "my-peer-dialing-server",
State: pbpeering.PeeringState_PENDING,
}))
// The dialing peer should eventually reconnect.
retry.Run(t, func(r *retry.R) {
connStreams := acceptingServerRestart.peerStreamServer.ConnectedStreams()
require.Contains(r, connStreams, dialingServerPeerID)
})
}
func insertTestPeeringData(t *testing.T, store *state.Store, peer string, lastIdx uint64) uint64 { func insertTestPeeringData(t *testing.T, store *state.Store, peer string, lastIdx uint64) uint64 {
lastIdx++ lastIdx++
require.NoError(t, store.PeeringTrustBundleWrite(lastIdx, &pbpeering.PeeringTrustBundle{ require.NoError(t, store.PeeringTrustBundleWrite(lastIdx, &pbpeering.PeeringTrustBundle{
@ -309,11 +567,6 @@ func insertTestPeeringData(t *testing.T, store *state.Store, peer string, lastId
Node: "aaa", Node: "aaa",
PeerName: peer, PeerName: peer,
}, },
{
CheckID: structs.SerfCheckID,
Node: "aaa",
PeerName: peer,
},
}, },
})) }))
@ -336,11 +589,6 @@ func insertTestPeeringData(t *testing.T, store *state.Store, peer string, lastId
Node: "bbb", Node: "bbb",
PeerName: peer, PeerName: peer,
}, },
{
CheckID: structs.SerfCheckID,
Node: "bbb",
PeerName: peer,
},
}, },
})) }))
@ -363,13 +611,514 @@ func insertTestPeeringData(t *testing.T, store *state.Store, peer string, lastId
Node: "ccc", Node: "ccc",
PeerName: peer, PeerName: peer,
}, },
{
CheckID: structs.SerfCheckID,
Node: "ccc",
PeerName: peer,
},
}, },
})) }))
return lastIdx return lastIdx
} }
// TODO(peering): once we move away from keeping state in stream tracker only on leaders, move this test to consul/server_test maybe
func TestLeader_Peering_ImportedExportedServicesCount(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
// TODO(peering): Configure with TLS
_, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s1.dc1"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
})
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create a peering by generating a token
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
t.Cleanup(cancel)
conn, err := grpc.DialContext(ctx, s1.config.RPCAddr.String(),
grpc.WithContextDialer(newServerDialer(s1.config.RPCAddr.String())),
grpc.WithInsecure(),
grpc.WithBlock())
require.NoError(t, err)
defer conn.Close()
peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{
PeerName: "my-peer-s2",
}
resp, err := peeringClient.GenerateToken(ctx, &req)
require.NoError(t, err)
tokenJSON, err := base64.StdEncoding.DecodeString(resp.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
var (
s2PeerID = "cc56f0b8-3885-4e78-8d7b-614a0c45712d"
lastIdx = uint64(0)
)
// Bring up s2 and store s1's token so that it attempts to dial.
_, s2 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s2.dc2"
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc2"
})
testrpc.WaitForLeader(t, s2.RPC, "dc2")
// Simulate a peering initiation event by writing a peering with data from a peering token.
// Eventually the leader in dc2 should dial and connect to the leader in dc1.
p := &pbpeering.Peering{
ID: s2PeerID,
Name: "my-peer-s1",
PeerID: token.PeerID,
PeerCAPems: token.CA,
PeerServerName: token.ServerName,
PeerServerAddresses: token.ServerAddresses,
}
require.True(t, p.ShouldDial())
lastIdx++
require.NoError(t, s2.fsm.State().PeeringWrite(lastIdx, p))
/// add services to S1 to be synced to S2
lastIdx++
require.NoError(t, s1.FSM().State().EnsureRegistration(lastIdx, &structs.RegisterRequest{
ID: types.NodeID(generateUUID()),
Node: "aaa",
Address: "10.0.0.1",
Service: &structs.NodeService{
Service: "a-service",
ID: "a-service-1",
Port: 8080,
},
Checks: structs.HealthChecks{
{
CheckID: "a-service-1-check",
ServiceName: "a-service",
ServiceID: "a-service-1",
Node: "aaa",
},
},
}))
lastIdx++
require.NoError(t, s1.FSM().State().EnsureRegistration(lastIdx, &structs.RegisterRequest{
ID: types.NodeID(generateUUID()),
Node: "bbb",
Address: "10.0.0.2",
Service: &structs.NodeService{
Service: "b-service",
ID: "b-service-1",
Port: 8080,
},
Checks: structs.HealthChecks{
{
CheckID: "b-service-1-check",
ServiceName: "b-service",
ServiceID: "b-service-1",
Node: "bbb",
},
},
}))
lastIdx++
require.NoError(t, s1.FSM().State().EnsureRegistration(lastIdx, &structs.RegisterRequest{
ID: types.NodeID(generateUUID()),
Node: "ccc",
Address: "10.0.0.3",
Service: &structs.NodeService{
Service: "c-service",
ID: "c-service-1",
Port: 8080,
},
Checks: structs.HealthChecks{
{
CheckID: "c-service-1-check",
ServiceName: "c-service",
ServiceID: "c-service-1",
Node: "ccc",
},
},
}))
/// finished adding services
type testCase struct {
name string
description string
exportedService structs.ExportedServicesConfigEntry
expectedImportedServsCount uint64
expectedExportedServsCount uint64
}
testCases := []testCase{
{
name: "wildcard",
description: "for a wildcard exported services, we want to see all services synced",
exportedService: structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: structs.WildcardSpecifier,
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peer-s2",
},
},
},
},
},
expectedImportedServsCount: 4, // 3 services from above + the "consul" service
expectedExportedServsCount: 4, // 3 services from above + the "consul" service
},
{
name: "no sync",
description: "update the config entry to allow no service sync",
exportedService: structs.ExportedServicesConfigEntry{
Name: "default",
},
expectedImportedServsCount: 0, // we want to see this decremented from 4 --> 0
expectedExportedServsCount: 0, // we want to see this decremented from 4 --> 0
},
{
name: "just a, b services",
description: "export just two services",
exportedService: structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "a-service",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peer-s2",
},
},
},
{
Name: "b-service",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peer-s2",
},
},
},
},
},
expectedImportedServsCount: 2,
expectedExportedServsCount: 2,
},
{
name: "unexport b service",
description: "by unexporting b we want to see the count decrement eventually",
exportedService: structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "a-service",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peer-s2",
},
},
},
},
},
expectedImportedServsCount: 1,
expectedExportedServsCount: 1,
},
{
name: "export c service",
description: "now export the c service and expect the count to increment",
exportedService: structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "a-service",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peer-s2",
},
},
},
{
Name: "c-service",
Consumers: []structs.ServiceConsumer{
{
PeerName: "my-peer-s2",
},
},
},
},
},
expectedImportedServsCount: 2,
expectedExportedServsCount: 2,
},
}
conn2, err := grpc.DialContext(ctx, s2.config.RPCAddr.String(),
grpc.WithContextDialer(newServerDialer(s2.config.RPCAddr.String())),
grpc.WithInsecure(),
grpc.WithBlock())
require.NoError(t, err)
defer conn2.Close()
peeringClient2 := pbpeering.NewPeeringServiceClient(conn2)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
lastIdx++
require.NoError(t, s1.fsm.State().EnsureConfigEntry(lastIdx, &tc.exportedService))
// Check that imported services count on S2 are what we expect
retry.Run(t, func(r *retry.R) {
// on Read
resp, err := peeringClient2.PeeringRead(ctx, &pbpeering.PeeringReadRequest{Name: "my-peer-s1"})
require.NoError(r, err)
require.NotNil(r, resp.Peering)
require.Equal(r, tc.expectedImportedServsCount, resp.Peering.ImportedServiceCount)
// on List
resp2, err2 := peeringClient2.PeeringList(ctx, &pbpeering.PeeringListRequest{})
require.NoError(r, err2)
require.NotEmpty(r, resp2.Peerings)
require.Equal(r, tc.expectedExportedServsCount, resp2.Peerings[0].ImportedServiceCount)
})
// Check that exported services count on S1 are what we expect
retry.Run(t, func(r *retry.R) {
// on Read
resp, err := peeringClient.PeeringRead(ctx, &pbpeering.PeeringReadRequest{Name: "my-peer-s2"})
require.NoError(r, err)
require.NotNil(r, resp.Peering)
require.Equal(r, tc.expectedImportedServsCount, resp.Peering.ExportedServiceCount)
// on List
resp2, err2 := peeringClient.PeeringList(ctx, &pbpeering.PeeringListRequest{})
require.NoError(r, err2)
require.NotEmpty(r, resp2.Peerings)
require.Equal(r, tc.expectedExportedServsCount, resp2.Peerings[0].ExportedServiceCount)
})
})
}
}
// TODO(peering): once we move away from keeping state in stream tracker only on leaders, move this test to consul/server_test maybe
func TestLeader_PeeringMetrics_emitPeeringMetrics(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
var (
s2PeerID1 = generateUUID()
s2PeerID2 = generateUUID()
testContextTimeout = 60 * time.Second
lastIdx = uint64(0)
)
// TODO(peering): Configure with TLS
_, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s1.dc1"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
})
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create a peering by generating a token
ctx, cancel := context.WithTimeout(context.Background(), testContextTimeout)
t.Cleanup(cancel)
conn, err := grpc.DialContext(ctx, s1.config.RPCAddr.String(),
grpc.WithContextDialer(newServerDialer(s1.config.RPCAddr.String())),
grpc.WithInsecure(),
grpc.WithBlock())
require.NoError(t, err)
defer conn.Close()
peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{
PeerName: "my-peer-s2",
}
resp, err := peeringClient.GenerateToken(ctx, &req)
require.NoError(t, err)
tokenJSON, err := base64.StdEncoding.DecodeString(resp.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
// Bring up s2 and store s1's token so that it attempts to dial.
_, s2 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s2.dc2"
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc2"
})
testrpc.WaitForLeader(t, s2.RPC, "dc2")
// Simulate exporting services in the tracker
{
// Simulate a peering initiation event by writing a peering with data from a peering token.
// Eventually the leader in dc2 should dial and connect to the leader in dc1.
p := &pbpeering.Peering{
ID: s2PeerID1,
Name: "my-peer-s1",
PeerID: token.PeerID,
PeerCAPems: token.CA,
PeerServerName: token.ServerName,
PeerServerAddresses: token.ServerAddresses,
}
require.True(t, p.ShouldDial())
lastIdx++
require.NoError(t, s2.fsm.State().PeeringWrite(lastIdx, p))
p2 := &pbpeering.Peering{
ID: s2PeerID2,
Name: "my-peer-s3",
PeerID: token.PeerID, // doesn't much matter what these values are
PeerCAPems: token.CA,
PeerServerName: token.ServerName,
PeerServerAddresses: token.ServerAddresses,
}
require.True(t, p2.ShouldDial())
lastIdx++
require.NoError(t, s2.fsm.State().PeeringWrite(lastIdx, p2))
// connect the stream
mst1, err := s2.peeringServer.Tracker.Connected(s2PeerID1)
require.NoError(t, err)
// mimic tracking exported services
mst1.TrackExportedService(structs.ServiceName{Name: "a-service"})
mst1.TrackExportedService(structs.ServiceName{Name: "b-service"})
mst1.TrackExportedService(structs.ServiceName{Name: "c-service"})
// connect the stream
mst2, err := s2.peeringServer.Tracker.Connected(s2PeerID2)
require.NoError(t, err)
// mimic tracking exported services
mst2.TrackExportedService(structs.ServiceName{Name: "d-service"})
mst2.TrackExportedService(structs.ServiceName{Name: "e-service"})
}
// set up a metrics sink
sink := metrics.NewInmemSink(testContextTimeout, testContextTimeout)
cfg := metrics.DefaultConfig("us-west")
cfg.EnableHostname = false
met, err := metrics.New(cfg, sink)
require.NoError(t, err)
errM := s2.emitPeeringMetricsOnce(s2.logger, met)
require.NoError(t, errM)
retry.Run(t, func(r *retry.R) {
intervals := sink.Data()
require.Len(r, intervals, 1)
intv := intervals[0]
// the keys for a Gauge value look like: {serviceName}.{prefix}.{key_name};{label=value};...
keyMetric1 := fmt.Sprintf("us-west.consul.peering.exported_services;peer_name=my-peer-s1;peer_id=%s", s2PeerID1)
metric1, ok := intv.Gauges[keyMetric1]
require.True(r, ok, fmt.Sprintf("did not find the key %q", keyMetric1))
require.Equal(r, float32(3), metric1.Value) // for a, b, c services
keyMetric2 := fmt.Sprintf("us-west.consul.peering.exported_services;peer_name=my-peer-s3;peer_id=%s", s2PeerID2)
metric2, ok := intv.Gauges[keyMetric2]
require.True(r, ok, fmt.Sprintf("did not find the key %q", keyMetric2))
require.Equal(r, float32(2), metric2.Value) // for d, e services
})
}
// Test that the leader doesn't start its peering deletion routing when
// peering is disabled.
func TestLeader_Peering_NoDeletionWhenPeeringDisabled(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
_, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s1.dc1"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
c.PeeringEnabled = false
})
testrpc.WaitForLeader(t, s1.RPC, "dc1")
var (
peerID = "cc56f0b8-3885-4e78-8d7b-614a0c45712d"
peerName = "my-peer-s2"
lastIdx = uint64(0)
)
// Simulate a peering initiation event by writing a peering to the state store.
lastIdx++
require.NoError(t, s1.fsm.State().PeeringWrite(lastIdx, &pbpeering.Peering{
ID: peerID,
Name: peerName,
}))
// Mark the peering for deletion to trigger the termination sequence.
lastIdx++
require.NoError(t, s1.fsm.State().PeeringWrite(lastIdx, &pbpeering.Peering{
ID: peerID,
Name: peerName,
DeletedAt: structs.TimeToProto(time.Now()),
}))
// The leader routine shouldn't be running so the peering should never get deleted.
require.Never(t, func() bool {
_, peering, err := s1.fsm.State().PeeringRead(nil, state.Query{
Value: peerName,
})
if err != nil {
t.Logf("unexpected err: %s", err)
return true
}
if peering == nil {
return true
}
return false
}, 7*time.Second, 1*time.Second, "peering should not have been deleted")
}
// Test that the leader doesn't start its peering establishment routine
// when peering is disabled.
func TestLeader_Peering_NoEstablishmentWhenPeeringDisabled(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
_, s1 := testServerWithConfig(t, func(c *Config) {
c.NodeName = "s1.dc1"
c.Datacenter = "dc1"
c.TLSConfig.Domain = "consul"
c.PeeringEnabled = false
})
testrpc.WaitForLeader(t, s1.RPC, "dc1")
var (
peerID = "cc56f0b8-3885-4e78-8d7b-614a0c45712d"
peerName = "my-peer-s2"
lastIdx = uint64(0)
)
// Simulate a peering initiation event by writing a peering to the state store.
require.NoError(t, s1.fsm.State().PeeringWrite(lastIdx, &pbpeering.Peering{
ID: peerID,
Name: peerName,
PeerServerAddresses: []string{"1.2.3.4"},
}))
require.Never(t, func() bool {
_, found := s1.peerStreamTracker.StreamStatus(peerID)
return found
}, 7*time.Second, 1*time.Second, "peering should not have been established")
}

View File

@ -7,6 +7,8 @@ import (
"strconv" "strconv"
"sync" "sync"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/acl/resolver"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/grpc-external/services/peerstream" "github.com/hashicorp/consul/agent/grpc-external/services/peerstream"
"github.com/hashicorp/consul/agent/rpc/peering" "github.com/hashicorp/consul/agent/rpc/peering"
@ -52,7 +54,7 @@ func (b *PeeringBackend) GetLeaderAddress() string {
// GetAgentCACertificates gets the server's raw CA data from its TLS Configurator. // GetAgentCACertificates gets the server's raw CA data from its TLS Configurator.
func (b *PeeringBackend) GetAgentCACertificates() ([]string, error) { func (b *PeeringBackend) GetAgentCACertificates() ([]string, error) {
// TODO(peering): handle empty CA pems // TODO(peering): handle empty CA pems
return b.srv.tlsConfigurator.ManualCAPems(), nil return b.srv.tlsConfigurator.GRPCManualCAPems(), nil
} }
// GetServerAddresses looks up server node addresses from the state store. // GetServerAddresses looks up server node addresses from the state store.
@ -160,3 +162,7 @@ func (b *PeeringBackend) CatalogDeregister(req *structs.DeregisterRequest) error
_, err := b.srv.leaderRaftApply("Catalog.Deregister", structs.DeregisterRequestType, req) _, err := b.srv.leaderRaftApply("Catalog.Deregister", structs.DeregisterRequestType, req)
return err return err
} }
func (b *PeeringBackend) ResolveTokenAndDefaultMeta(token string, entMeta *acl.EnterpriseMeta, authzCtx *acl.AuthorizerContext) (resolver.Result, error) {
return b.srv.ResolveTokenAndDefaultMeta(token, entMeta, authzCtx)
}

View File

@ -42,7 +42,6 @@ func TestPeeringBackend_RejectsPartition(t *testing.T) {
peeringClient := pbpeering.NewPeeringServiceClient(conn) peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{ req := pbpeering.GenerateTokenRequest{
Datacenter: "dc1",
Partition: "test", Partition: "test",
} }
_, err = peeringClient.GenerateToken(ctx, &req) _, err = peeringClient.GenerateToken(ctx, &req)
@ -77,7 +76,6 @@ func TestPeeringBackend_IgnoresDefaultPartition(t *testing.T) {
peeringClient := pbpeering.NewPeeringServiceClient(conn) peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{ req := pbpeering.GenerateTokenRequest{
Datacenter: "dc1",
PeerName: "my-peer", PeerName: "my-peer",
Partition: "DeFaUlT", Partition: "DeFaUlT",
} }

View File

@ -15,43 +15,6 @@ import (
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
) )
func TestPeeringBackend_DoesNotForwardToDifferentDC(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
_, s1 := testServerDC(t, "dc1")
_, s2 := testServerDC(t, "dc2")
joinWAN(t, s2, s1)
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForLeader(t, s2.RPC, "dc2")
// make a grpc client to dial s2 directly
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
conn, err := gogrpc.DialContext(ctx, s2.config.RPCAddr.String(),
gogrpc.WithContextDialer(newServerDialer(s2.config.RPCAddr.String())),
gogrpc.WithInsecure(),
gogrpc.WithBlock())
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
peeringClient := pbpeering.NewPeeringServiceClient(conn)
// GenerateToken request should fail against dc1, because we are dialing dc2. The GenerateToken request should never be forwarded across datacenters.
req := pbpeering.GenerateTokenRequest{
PeerName: "peer1-usw1",
Datacenter: "dc1",
}
_, err = peeringClient.GenerateToken(ctx, &req)
require.Error(t, err)
require.Contains(t, err.Error(), "requests to generate peering tokens cannot be forwarded to remote datacenters")
}
func TestPeeringBackend_ForwardToLeader(t *testing.T) { func TestPeeringBackend_ForwardToLeader(t *testing.T) {
t.Parallel() t.Parallel()
@ -86,7 +49,6 @@ func TestPeeringBackend_ForwardToLeader(t *testing.T) {
testutil.RunStep(t, "forward a write", func(t *testing.T) { testutil.RunStep(t, "forward a write", func(t *testing.T) {
// Do the grpc Write call to server2 // Do the grpc Write call to server2
req := pbpeering.GenerateTokenRequest{ req := pbpeering.GenerateTokenRequest{
Datacenter: "dc1",
PeerName: "foo", PeerName: "foo",
} }
_, err := peeringClient.GenerateToken(ctx, &req) _, err := peeringClient.GenerateToken(ctx, &req)

View File

@ -22,7 +22,7 @@ var (
}, },
Service: structs.ServiceQuery{ Service: structs.ServiceQuery{
Service: "${name.full}", Service: "${name.full}",
Failover: structs.QueryDatacenterOptions{ Failover: structs.QueryFailoverOptions{
Datacenters: []string{ Datacenters: []string{
"${name.full}", "${name.full}",
"${name.prefix}", "${name.prefix}",
@ -69,7 +69,7 @@ var (
}, },
Service: structs.ServiceQuery{ Service: structs.ServiceQuery{
Service: "${name.full}", Service: "${name.full}",
Failover: structs.QueryDatacenterOptions{ Failover: structs.QueryFailoverOptions{
Datacenters: []string{ Datacenters: []string{
"dc1", "dc1",
"dc2", "dc2",

View File

@ -20,7 +20,7 @@ func TestWalk_ServiceQuery(t *testing.T) {
service := &structs.ServiceQuery{ service := &structs.ServiceQuery{
Service: "the-service", Service: "the-service",
Failover: structs.QueryDatacenterOptions{ Failover: structs.QueryFailoverOptions{
Datacenters: []string{"dc1", "dc2"}, Datacenters: []string{"dc1", "dc2"},
}, },
Near: "_agent", Near: "_agent",

View File

@ -187,11 +187,16 @@ func parseService(svc *structs.ServiceQuery) error {
return fmt.Errorf("Must provide a Service name to query") return fmt.Errorf("Must provide a Service name to query")
} }
failover := svc.Failover
// NearestN can be 0 which means "don't fail over by RTT". // NearestN can be 0 which means "don't fail over by RTT".
if svc.Failover.NearestN < 0 { if failover.NearestN < 0 {
return fmt.Errorf("Bad NearestN '%d', must be >= 0", svc.Failover.NearestN) return fmt.Errorf("Bad NearestN '%d', must be >= 0", svc.Failover.NearestN)
} }
if (failover.NearestN != 0 || len(failover.Datacenters) != 0) && len(failover.Targets) != 0 {
return fmt.Errorf("Targets cannot be populated with NearestN or Datacenters")
}
// Make sure the metadata filters are valid // Make sure the metadata filters are valid
if err := structs.ValidateNodeMetadata(svc.NodeMeta, true); err != nil { if err := structs.ValidateNodeMetadata(svc.NodeMeta, true); err != nil {
return err return err
@ -462,7 +467,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
// and bail out. Otherwise, we fail over and try remote DCs, as allowed // and bail out. Otherwise, we fail over and try remote DCs, as allowed
// by the query setup. // by the query setup.
if len(reply.Nodes) == 0 { if len(reply.Nodes) == 0 {
wrapper := &queryServerWrapper{p.srv} wrapper := &queryServerWrapper{srv: p.srv, executeRemote: p.ExecuteRemote}
if err := queryFailover(wrapper, query, args, reply); err != nil { if err := queryFailover(wrapper, query, args, reply); err != nil {
return err return err
} }
@ -565,8 +570,13 @@ func (p *PreparedQuery) execute(query *structs.PreparedQuery,
reply.Nodes = nodes reply.Nodes = nodes
reply.DNS = query.DNS reply.DNS = query.DNS
// Stamp the result for this datacenter. // Stamp the result with its this datacenter or peer.
if peerName := query.Service.PeerName; peerName != "" {
reply.PeerName = peerName
reply.Datacenter = ""
} else {
reply.Datacenter = p.srv.config.Datacenter reply.Datacenter = p.srv.config.Datacenter
}
return nil return nil
} }
@ -651,12 +661,24 @@ func serviceMetaFilter(filters map[string]string, nodes structs.CheckServiceNode
type queryServer interface { type queryServer interface {
GetLogger() hclog.Logger GetLogger() hclog.Logger
GetOtherDatacentersByDistance() ([]string, error) GetOtherDatacentersByDistance() ([]string, error)
ForwardDC(method, dc string, args interface{}, reply interface{}) error GetLocalDC() string
ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error
} }
// queryServerWrapper applies the queryServer interface to a Server. // queryServerWrapper applies the queryServer interface to a Server.
type queryServerWrapper struct { type queryServerWrapper struct {
srv *Server srv *Server
executeRemote func(args *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error
}
// GetLocalDC returns the name of the local datacenter.
func (q *queryServerWrapper) GetLocalDC() string {
return q.srv.config.Datacenter
}
// ExecuteRemote calls ExecuteRemote on PreparedQuery.
func (q *queryServerWrapper) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
return q.executeRemote(args, reply)
} }
// GetLogger returns the server's logger. // GetLogger returns the server's logger.
@ -683,11 +705,6 @@ func (q *queryServerWrapper) GetOtherDatacentersByDistance() ([]string, error) {
return result, nil return result, nil
} }
// ForwardDC calls into the server's RPC forwarder.
func (q *queryServerWrapper) ForwardDC(method, dc string, args interface{}, reply interface{}) error {
return q.srv.forwardDC(method, dc, args, reply)
}
// queryFailover runs an algorithm to determine which DCs to try and then calls // queryFailover runs an algorithm to determine which DCs to try and then calls
// them to try to locate alternative services. // them to try to locate alternative services.
func queryFailover(q queryServer, query *structs.PreparedQuery, func queryFailover(q queryServer, query *structs.PreparedQuery,
@ -709,7 +726,7 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
// Build a candidate list of DCs to try, starting with the nearest N // Build a candidate list of DCs to try, starting with the nearest N
// from RTTs. // from RTTs.
var dcs []string var targets []structs.QueryFailoverTarget
index := make(map[string]struct{}) index := make(map[string]struct{})
if query.Service.Failover.NearestN > 0 { if query.Service.Failover.NearestN > 0 {
for i, dc := range nearest { for i, dc := range nearest {
@ -717,15 +734,16 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
break break
} }
dcs = append(dcs, dc) targets = append(targets, structs.QueryFailoverTarget{Datacenter: dc})
index[dc] = struct{}{} index[dc] = struct{}{}
} }
} }
// Then add any DCs explicitly listed that weren't selected above. // Then add any DCs explicitly listed that weren't selected above.
for _, dc := range query.Service.Failover.Datacenters { for _, target := range query.Service.Failover.AsTargets() {
// This will prevent a log of other log spammage if we do not // This will prevent a log of other log spammage if we do not
// attempt to talk to datacenters we don't know about. // attempt to talk to datacenters we don't know about.
if dc := target.Datacenter; dc != "" {
if _, ok := known[dc]; !ok { if _, ok := known[dc]; !ok {
q.GetLogger().Debug("Skipping unknown datacenter in prepared query", "datacenter", dc) q.GetLogger().Debug("Skipping unknown datacenter in prepared query", "datacenter", dc)
continue continue
@ -734,13 +752,18 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
// This will make sure we don't re-try something that fails // This will make sure we don't re-try something that fails
// from the NearestN list. // from the NearestN list.
if _, ok := index[dc]; !ok { if _, ok := index[dc]; !ok {
dcs = append(dcs, dc) targets = append(targets, target)
}
}
if target.PeerName != "" {
targets = append(targets, target)
} }
} }
// Now try the selected DCs in priority order. // Now try the selected DCs in priority order.
failovers := 0 failovers := 0
for _, dc := range dcs { for _, target := range targets {
// This keeps track of how many iterations we actually run. // This keeps track of how many iterations we actually run.
failovers++ failovers++
@ -752,7 +775,15 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
// through this slice across successive RPC calls. // through this slice across successive RPC calls.
reply.Nodes = nil reply.Nodes = nil
// Note that we pass along the limit since it can be applied // Reset PeerName because it may have been set by a previous failover
// target.
query.Service.PeerName = target.PeerName
dc := target.Datacenter
if target.PeerName != "" {
dc = q.GetLocalDC()
}
// Note that we pass along the limit since may be applied
// remotely to save bandwidth. We also pass along the consistency // remotely to save bandwidth. We also pass along the consistency
// mode information and token we were given, so that applies to // mode information and token we were given, so that applies to
// the remote query as well. // the remote query as well.
@ -763,9 +794,11 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
QueryOptions: args.QueryOptions, QueryOptions: args.QueryOptions,
Connect: args.Connect, Connect: args.Connect,
} }
if err := q.ForwardDC("PreparedQuery.ExecuteRemote", dc, remote, reply); err != nil {
if err = q.ExecuteRemote(remote, reply); err != nil {
q.GetLogger().Warn("Failed querying for service in datacenter", q.GetLogger().Warn("Failed querying for service in datacenter",
"service", query.Service.Service, "service", query.Service.Service,
"peerName", query.Service.PeerName,
"datacenter", dc, "datacenter", dc,
"error", err, "error", err,
) )

View File

@ -2,6 +2,9 @@ package consul
import ( import (
"bytes" "bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -14,6 +17,7 @@ import (
"github.com/hashicorp/serf/coordinate" "github.com/hashicorp/serf/coordinate"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc"
msgpackrpc "github.com/hashicorp/consul-net-rpc/net-rpc-msgpackrpc" msgpackrpc "github.com/hashicorp/consul-net-rpc/net-rpc-msgpackrpc"
"github.com/hashicorp/consul-net-rpc/net/rpc" "github.com/hashicorp/consul-net-rpc/net/rpc"
@ -23,6 +27,7 @@ import (
"github.com/hashicorp/consul/agent/structs/aclfilter" "github.com/hashicorp/consul/agent/structs/aclfilter"
tokenStore "github.com/hashicorp/consul/agent/token" tokenStore "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/types" "github.com/hashicorp/consul/types"
@ -82,8 +87,25 @@ func TestPreparedQuery_Apply(t *testing.T) {
t.Fatalf("bad: %v", err) t.Fatalf("bad: %v", err)
} }
// Fix that and make sure it propagates an error from the Raft apply. // Fix that and ensure Targets and NearestN cannot be set at the same time.
query.Query.Service.Failover.NearestN = 1
query.Query.Service.Failover.Targets = []structs.QueryFailoverTarget{{PeerName: "peer"}}
err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Apply", &query, &reply)
if err == nil || !strings.Contains(err.Error(), "Targets cannot be populated with") {
t.Fatalf("bad: %v", err)
}
// Fix that and ensure Targets and Datacenters cannot be set at the same time.
query.Query.Service.Failover.NearestN = 0 query.Query.Service.Failover.NearestN = 0
query.Query.Service.Failover.Datacenters = []string{"dc2"}
query.Query.Service.Failover.Targets = []structs.QueryFailoverTarget{{PeerName: "peer"}}
err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Apply", &query, &reply)
if err == nil || !strings.Contains(err.Error(), "Targets cannot be populated with") {
t.Fatalf("bad: %v", err)
}
// Fix that and make sure it propagates an error from the Raft apply.
query.Query.Service.Failover.Targets = nil
query.Query.Session = "nope" query.Query.Session = "nope"
err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Apply", &query, &reply) err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Apply", &query, &reply)
if err == nil || !strings.Contains(err.Error(), "invalid session") { if err == nil || !strings.Contains(err.Error(), "invalid session") {
@ -1442,6 +1464,17 @@ func TestPreparedQuery_Execute(t *testing.T) {
s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig) s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig)
dir3, s3 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc3"
c.PrimaryDatacenter = "dc3"
c.NodeName = "acceptingServer.dc3"
})
defer os.RemoveAll(dir3)
defer s3.Shutdown()
waitForLeaderEstablishment(t, s3)
codec3 := rpcClient(t, s3)
defer codec3.Close()
// Try to WAN join. // Try to WAN join.
joinWAN(t, s2, s1) joinWAN(t, s2, s1)
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
@ -1456,6 +1489,70 @@ func TestPreparedQuery_Execute(t *testing.T) {
// check for RPC forwarding // check for RPC forwarding
testrpc.WaitForLeader(t, s1.RPC, "dc1", testrpc.WithToken("root")) testrpc.WaitForLeader(t, s1.RPC, "dc1", testrpc.WithToken("root"))
testrpc.WaitForLeader(t, s1.RPC, "dc2", testrpc.WithToken("root")) testrpc.WaitForLeader(t, s1.RPC, "dc2", testrpc.WithToken("root"))
testrpc.WaitForLeader(t, s3.RPC, "dc3")
acceptingPeerName := "my-peer-accepting-server"
dialingPeerName := "my-peer-dialing-server"
// Set up peering between dc1 (dailing) and dc3 (accepting) and export the foo service
{
// Create a peering by generating a token.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
conn, err := grpc.DialContext(ctx, s3.config.RPCAddr.String(),
grpc.WithContextDialer(newServerDialer(s3.config.RPCAddr.String())),
grpc.WithInsecure(),
grpc.WithBlock())
require.NoError(t, err)
defer conn.Close()
peeringClient := pbpeering.NewPeeringServiceClient(conn)
req := pbpeering.GenerateTokenRequest{
PeerName: dialingPeerName,
}
resp, err := peeringClient.GenerateToken(ctx, &req)
require.NoError(t, err)
tokenJSON, err := base64.StdEncoding.DecodeString(resp.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
p := &pbpeering.Peering{
ID: "cc56f0b8-3885-4e78-8d7b-614a0c45712d",
Name: acceptingPeerName,
PeerID: token.PeerID,
PeerCAPems: token.CA,
PeerServerName: token.ServerName,
PeerServerAddresses: token.ServerAddresses,
}
require.True(t, p.ShouldDial())
require.NoError(t, s1.fsm.State().PeeringWrite(1000, p))
// Wait for the stream to be connected.
retry.Run(t, func(r *retry.R) {
status, found := s1.peerStreamServer.StreamStatus(p.ID)
require.True(r, found)
require.True(r, status.Connected)
})
exportedServices := structs.ConfigEntryRequest{
Op: structs.ConfigEntryUpsert,
Datacenter: "dc3",
Entry: &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "foo",
Consumers: []structs.ServiceConsumer{{PeerName: dialingPeerName}},
},
},
},
}
var configOutput bool
require.NoError(t, msgpackrpc.CallWithCodec(codec3, "ConfigEntry.Apply", &exportedServices, &configOutput))
require.True(t, configOutput)
}
execNoNodesToken := createTokenWithPolicyName(t, codec1, "no-nodes", `service_prefix "foo" { policy = "read" }`, "root") execNoNodesToken := createTokenWithPolicyName(t, codec1, "no-nodes", `service_prefix "foo" { policy = "read" }`, "root")
rules := ` rules := `
@ -1485,9 +1582,16 @@ func TestPreparedQuery_Execute(t *testing.T) {
// Set up some nodes in each DC that host the service. // Set up some nodes in each DC that host the service.
{ {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
for _, dc := range []string{"dc1", "dc2"} { for _, d := range []struct {
codec rpc.ClientCodec
dc string
}{
{codec1, "dc1"},
{codec2, "dc2"},
{codec3, "dc3"},
} {
req := structs.RegisterRequest{ req := structs.RegisterRequest{
Datacenter: dc, Datacenter: d.dc,
Node: fmt.Sprintf("node%d", i+1), Node: fmt.Sprintf("node%d", i+1),
Address: fmt.Sprintf("127.0.0.%d", i+1), Address: fmt.Sprintf("127.0.0.%d", i+1),
NodeMeta: map[string]string{ NodeMeta: map[string]string{
@ -1497,7 +1601,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
Service: &structs.NodeService{ Service: &structs.NodeService{
Service: "foo", Service: "foo",
Port: 8000, Port: 8000,
Tags: []string{dc, fmt.Sprintf("tag%d", i+1)}, Tags: []string{d.dc, fmt.Sprintf("tag%d", i+1)},
Meta: map[string]string{ Meta: map[string]string{
"svc-group": fmt.Sprintf("%d", i%2), "svc-group": fmt.Sprintf("%d", i%2),
"foo": "true", "foo": "true",
@ -1510,15 +1614,8 @@ func TestPreparedQuery_Execute(t *testing.T) {
req.Service.Meta["unique"] = "true" req.Service.Meta["unique"] = "true"
} }
var codec rpc.ClientCodec
if dc == "dc1" {
codec = codec1
} else {
codec = codec2
}
var reply struct{} var reply struct{}
if err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply); err != nil { if err := msgpackrpc.CallWithCodec(d.codec, "Catalog.Register", &req, &reply); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
} }
@ -1576,6 +1673,17 @@ func TestPreparedQuery_Execute(t *testing.T) {
assert.True(t, reply.QueryMeta.KnownLeader) assert.True(t, reply.QueryMeta.KnownLeader)
} }
expectFailoverPeerNodes := func(t *testing.T, query *structs.PreparedQueryRequest, reply *structs.PreparedQueryExecuteResponse, n int) {
t.Helper()
assert.Len(t, reply.Nodes, n)
assert.Equal(t, "", reply.Datacenter)
assert.Equal(t, acceptingPeerName, reply.PeerName)
assert.Equal(t, 2, reply.Failovers)
assert.Equal(t, query.Query.Service.Service, reply.Service)
assert.Equal(t, query.Query.DNS, reply.DNS)
assert.True(t, reply.QueryMeta.KnownLeader)
}
t.Run("run the registered query", func(t *testing.T) { t.Run("run the registered query", func(t *testing.T) {
req := structs.PreparedQueryExecuteRequest{ req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1", Datacenter: "dc1",
@ -1962,10 +2070,10 @@ func TestPreparedQuery_Execute(t *testing.T) {
require.NoError(t, msgpackrpc.CallWithCodec(codec1, "PreparedQuery.Apply", &query, &query.Query.ID)) require.NoError(t, msgpackrpc.CallWithCodec(codec1, "PreparedQuery.Apply", &query, &query.Query.ID))
// Update the health of a node to mark it critical. // Update the health of a node to mark it critical.
setHealth := func(t *testing.T, node string, health string) { setHealth := func(t *testing.T, codec rpc.ClientCodec, dc string, node string, health string) {
t.Helper() t.Helper()
req := structs.RegisterRequest{ req := structs.RegisterRequest{
Datacenter: "dc1", Datacenter: dc,
Node: node, Node: node,
Address: "127.0.0.1", Address: "127.0.0.1",
Service: &structs.NodeService{ Service: &structs.NodeService{
@ -1981,9 +2089,9 @@ func TestPreparedQuery_Execute(t *testing.T) {
WriteRequest: structs.WriteRequest{Token: "root"}, WriteRequest: structs.WriteRequest{Token: "root"},
} }
var reply struct{} var reply struct{}
require.NoError(t, msgpackrpc.CallWithCodec(codec1, "Catalog.Register", &req, &reply)) require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply))
} }
setHealth(t, "node1", api.HealthCritical) setHealth(t, codec1, "dc1", "node1", api.HealthCritical)
// The failing node should be filtered. // The failing node should be filtered.
t.Run("failing node filtered", func(t *testing.T) { t.Run("failing node filtered", func(t *testing.T) {
@ -2003,7 +2111,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
}) })
// Upgrade it to a warning and re-query, should be 10 nodes again. // Upgrade it to a warning and re-query, should be 10 nodes again.
setHealth(t, "node1", api.HealthWarning) setHealth(t, codec1, "dc1", "node1", api.HealthWarning)
t.Run("warning nodes are included", func(t *testing.T) { t.Run("warning nodes are included", func(t *testing.T) {
req := structs.PreparedQueryExecuteRequest{ req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1", Datacenter: "dc1",
@ -2173,7 +2281,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
// Now fail everything in dc1 and we should get an empty list back. // Now fail everything in dc1 and we should get an empty list back.
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
setHealth(t, fmt.Sprintf("node%d", i+1), api.HealthCritical) setHealth(t, codec1, "dc1", fmt.Sprintf("node%d", i+1), api.HealthCritical)
} }
t.Run("everything is failing so should get empty list", func(t *testing.T) { t.Run("everything is failing so should get empty list", func(t *testing.T) {
req := structs.PreparedQueryExecuteRequest{ req := structs.PreparedQueryExecuteRequest{
@ -2308,6 +2416,61 @@ func TestPreparedQuery_Execute(t *testing.T) {
assert.NotEqual(t, "node3", node.Node.Node) assert.NotEqual(t, "node3", node.Node.Node)
} }
}) })
// Modify the query to have it fail over to a bogus DC and then dc2.
query.Query.Service.Failover = structs.QueryFailoverOptions{
Targets: []structs.QueryFailoverTarget{
{Datacenter: "dc2"},
{PeerName: acceptingPeerName},
},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec1, "PreparedQuery.Apply", &query, &query.Query.ID))
// Ensure the foo service has fully replicated.
retry.Run(t, func(r *retry.R) {
_, nodes, err := s1.fsm.State().CheckServiceNodes(nil, "foo", nil, acceptingPeerName)
require.NoError(r, err)
require.Len(r, nodes, 10)
})
// Now we should see 9 nodes from dc2
t.Run("failing over to cluster peers", func(t *testing.T) {
req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1",
QueryIDOrName: query.Query.ID,
QueryOptions: structs.QueryOptions{Token: execToken},
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(t, msgpackrpc.CallWithCodec(codec1, "PreparedQuery.Execute", &req, &reply))
for _, node := range reply.Nodes {
assert.NotEqual(t, "node3", node.Node.Node)
}
expectFailoverNodes(t, &query, &reply, 9)
})
// Set all checks in dc2 as critical
for i := 0; i < 10; i++ {
setHealth(t, codec2, "dc2", fmt.Sprintf("node%d", i+1), api.HealthCritical)
}
// Now we should see 9 nodes from dc3 (we have the tag filter still)
t.Run("failing over to cluster peers", func(t *testing.T) {
req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1",
QueryIDOrName: query.Query.ID,
QueryOptions: structs.QueryOptions{Token: execToken},
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(t, msgpackrpc.CallWithCodec(codec1, "PreparedQuery.Execute", &req, &reply))
for _, node := range reply.Nodes {
assert.NotEqual(t, "node3", node.Node.Node)
}
expectFailoverPeerNodes(t, &query, &reply, 9)
})
} }
func TestPreparedQuery_Execute_ForwardLeader(t *testing.T) { func TestPreparedQuery_Execute_ForwardLeader(t *testing.T) {
@ -2724,7 +2887,9 @@ func TestPreparedQuery_Wrapper(t *testing.T) {
joinWAN(t, s2, s1) joinWAN(t, s2, s1)
// Try all the operations on a real server via the wrapper. // Try all the operations on a real server via the wrapper.
wrapper := &queryServerWrapper{s1} wrapper := &queryServerWrapper{srv: s1, executeRemote: func(args *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
return nil
}}
wrapper.GetLogger().Debug("Test") wrapper.GetLogger().Debug("Test")
ret, err := wrapper.GetOtherDatacentersByDistance() ret, err := wrapper.GetOtherDatacentersByDistance()
@ -2746,7 +2911,7 @@ type mockQueryServer struct {
Datacenters []string Datacenters []string
DatacentersError error DatacentersError error
QueryLog []string QueryLog []string
QueryFn func(dc string, args interface{}, reply interface{}) error QueryFn func(args *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error
Logger hclog.Logger Logger hclog.Logger
LogBuffer *bytes.Buffer LogBuffer *bytes.Buffer
} }
@ -2768,17 +2933,27 @@ func (m *mockQueryServer) GetLogger() hclog.Logger {
return m.Logger return m.Logger
} }
func (m *mockQueryServer) GetLocalDC() string {
return "dc1"
}
func (m *mockQueryServer) GetOtherDatacentersByDistance() ([]string, error) { func (m *mockQueryServer) GetOtherDatacentersByDistance() ([]string, error) {
return m.Datacenters, m.DatacentersError return m.Datacenters, m.DatacentersError
} }
func (m *mockQueryServer) ForwardDC(method, dc string, args interface{}, reply interface{}) error { func (m *mockQueryServer) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
m.QueryLog = append(m.QueryLog, fmt.Sprintf("%s:%s", dc, method)) peerName := args.Query.Service.PeerName
if ret, ok := reply.(*structs.PreparedQueryExecuteResponse); ok { dc := args.Datacenter
ret.Datacenter = dc if peerName != "" {
m.QueryLog = append(m.QueryLog, fmt.Sprintf("peer:%s", peerName))
} else {
m.QueryLog = append(m.QueryLog, fmt.Sprintf("%s:%s", dc, "PreparedQuery.ExecuteRemote"))
} }
reply.PeerName = peerName
reply.Datacenter = dc
if m.QueryFn != nil { if m.QueryFn != nil {
return m.QueryFn(dc, args, reply) return m.QueryFn(args, reply)
} }
return nil return nil
} }
@ -2788,7 +2963,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
query := &structs.PreparedQuery{ query := &structs.PreparedQuery{
Name: "test", Name: "test",
Service: structs.ServiceQuery{ Service: structs.ServiceQuery{
Failover: structs.QueryDatacenterOptions{ Failover: structs.QueryFailoverOptions{
NearestN: 0, NearestN: 0,
Datacenters: []string{""}, Datacenters: []string{""},
}, },
@ -2862,10 +3037,9 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
{ {
mock := &mockQueryServer{ mock := &mockQueryServer{
Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"}, Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"},
QueryFn: func(dc string, _ interface{}, reply interface{}) error { QueryFn: func(req *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
ret := reply.(*structs.PreparedQueryExecuteResponse) if req.Datacenter == "dc1" {
if dc == "dc1" { reply.Nodes = nodes()
ret.Nodes = nodes()
} }
return nil return nil
}, },
@ -2890,10 +3064,9 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
{ {
mock := &mockQueryServer{ mock := &mockQueryServer{
Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"}, Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"},
QueryFn: func(dc string, _ interface{}, reply interface{}) error { QueryFn: func(req *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
ret := reply.(*structs.PreparedQueryExecuteResponse) if req.Datacenter == "dc3" {
if dc == "dc3" { reply.Nodes = nodes()
ret.Nodes = nodes()
} }
return nil return nil
}, },
@ -2926,7 +3099,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
} }
if len(reply.Nodes) != 0 || if len(reply.Nodes) != 0 ||
reply.Datacenter != "xxx" || reply.Failovers != 4 { reply.Datacenter != "xxx" || reply.Failovers != 4 {
t.Fatalf("bad: %v", reply) t.Fatalf("bad: %+v", reply)
} }
if queries := mock.JoinQueryLog(); queries != "dc1:PreparedQuery.ExecuteRemote|dc2:PreparedQuery.ExecuteRemote|dc3:PreparedQuery.ExecuteRemote|xxx:PreparedQuery.ExecuteRemote" { if queries := mock.JoinQueryLog(); queries != "dc1:PreparedQuery.ExecuteRemote|dc2:PreparedQuery.ExecuteRemote|dc3:PreparedQuery.ExecuteRemote|xxx:PreparedQuery.ExecuteRemote" {
t.Fatalf("bad: %s", queries) t.Fatalf("bad: %s", queries)
@ -2940,10 +3113,9 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
{ {
mock := &mockQueryServer{ mock := &mockQueryServer{
Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"}, Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"},
QueryFn: func(dc string, _ interface{}, reply interface{}) error { QueryFn: func(req *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
ret := reply.(*structs.PreparedQueryExecuteResponse) if req.Datacenter == "dc4" {
if dc == "dc4" { reply.Nodes = nodes()
ret.Nodes = nodes()
} }
return nil return nil
}, },
@ -2969,10 +3141,9 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
{ {
mock := &mockQueryServer{ mock := &mockQueryServer{
Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"}, Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"},
QueryFn: func(dc string, _ interface{}, reply interface{}) error { QueryFn: func(req *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
ret := reply.(*structs.PreparedQueryExecuteResponse) if req.Datacenter == "dc4" {
if dc == "dc4" { reply.Nodes = nodes()
ret.Nodes = nodes()
} }
return nil return nil
}, },
@ -2998,10 +3169,9 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
{ {
mock := &mockQueryServer{ mock := &mockQueryServer{
Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"}, Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"},
QueryFn: func(dc string, _ interface{}, reply interface{}) error { QueryFn: func(req *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
ret := reply.(*structs.PreparedQueryExecuteResponse) if req.Datacenter == "dc4" {
if dc == "dc4" { reply.Nodes = nodes()
ret.Nodes = nodes()
} }
return nil return nil
}, },
@ -3029,12 +3199,11 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
{ {
mock := &mockQueryServer{ mock := &mockQueryServer{
Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"}, Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"},
QueryFn: func(dc string, _ interface{}, reply interface{}) error { QueryFn: func(req *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
ret := reply.(*structs.PreparedQueryExecuteResponse) if req.Datacenter == "dc1" {
if dc == "dc1" {
return fmt.Errorf("XXX") return fmt.Errorf("XXX")
} else if dc == "dc4" { } else if req.Datacenter == "dc4" {
ret.Nodes = nodes() reply.Nodes = nodes()
} }
return nil return nil
}, },
@ -3063,10 +3232,9 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
{ {
mock := &mockQueryServer{ mock := &mockQueryServer{
Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"}, Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"},
QueryFn: func(dc string, _ interface{}, reply interface{}) error { QueryFn: func(req *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
ret := reply.(*structs.PreparedQueryExecuteResponse) if req.Datacenter == "xxx" {
if dc == "xxx" { reply.Nodes = nodes()
ret.Nodes = nodes()
} }
return nil return nil
}, },
@ -3092,17 +3260,15 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
{ {
mock := &mockQueryServer{ mock := &mockQueryServer{
Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"}, Datacenters: []string{"dc1", "dc2", "dc3", "xxx", "dc4"},
QueryFn: func(dc string, args interface{}, reply interface{}) error { QueryFn: func(req *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
inp := args.(*structs.PreparedQueryExecuteRemoteRequest) if req.Datacenter == "xxx" {
ret := reply.(*structs.PreparedQueryExecuteResponse) if req.Limit != 5 {
if dc == "xxx" { t.Fatalf("bad: %d", req.Limit)
if inp.Limit != 5 {
t.Fatalf("bad: %d", inp.Limit)
} }
if inp.RequireConsistent != true { if req.RequireConsistent != true {
t.Fatalf("bad: %v", inp.RequireConsistent) t.Fatalf("bad: %v", req.RequireConsistent)
} }
ret.Nodes = nodes() reply.Nodes = nodes()
} }
return nil return nil
}, },
@ -3124,4 +3290,32 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
t.Fatalf("bad: %s", queries) t.Fatalf("bad: %s", queries)
} }
} }
// Failover returns data from the first cluster peer with data.
query.Service.Failover.Datacenters = nil
query.Service.Failover.Targets = []structs.QueryFailoverTarget{
{PeerName: "cluster-01"},
{Datacenter: "dc44"},
{PeerName: "cluster-02"},
}
{
mock := &mockQueryServer{
Datacenters: []string{"dc44"},
QueryFn: func(args *structs.PreparedQueryExecuteRemoteRequest, reply *structs.PreparedQueryExecuteResponse) error {
if args.Query.Service.PeerName == "cluster-02" {
reply.Nodes = nodes()
}
return nil
},
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
require.Equal(t, "cluster-02", reply.PeerName)
require.Equal(t, 3, reply.Failovers)
require.Equal(t, nodes(), reply.Nodes)
require.Equal(t, "peer:cluster-01|dc44:PreparedQuery.ExecuteRemote|peer:cluster-02", mock.JoinQueryLog())
}
} }

View File

@ -127,6 +127,7 @@ const (
virtualIPCheckRoutineName = "virtual IP version check" virtualIPCheckRoutineName = "virtual IP version check"
peeringStreamsRoutineName = "streaming peering resources" peeringStreamsRoutineName = "streaming peering resources"
peeringDeletionRoutineName = "peering deferred deletion" peeringDeletionRoutineName = "peering deferred deletion"
peeringStreamsMetricsRoutineName = "metrics for streaming peering resources"
) )
var ( var (
@ -367,8 +368,9 @@ type Server struct {
// peeringBackend is shared between the external and internal gRPC services for peering // peeringBackend is shared between the external and internal gRPC services for peering
peeringBackend *PeeringBackend peeringBackend *PeeringBackend
// peerStreamServer is a server used to handle peering streams // peerStreamServer is a server used to handle peering streams from external clusters.
peerStreamServer *peerstream.Server peerStreamServer *peerstream.Server
// peeringServer handles peering RPC requests internal to this cluster, like generating peering tokens.
peeringServer *peering.Server peeringServer *peering.Server
peerStreamTracker *peerstream.Tracker peerStreamTracker *peerstream.Tracker
@ -792,6 +794,7 @@ func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler
}, },
Datacenter: config.Datacenter, Datacenter: config.Datacenter,
ConnectEnabled: config.ConnectEnabled, ConnectEnabled: config.ConnectEnabled,
PeeringEnabled: config.PeeringEnabled,
}) })
s.peeringServer = p s.peeringServer = p

View File

@ -25,6 +25,7 @@ import (
"github.com/hashicorp/consul-net-rpc/net/rpc" "github.com/hashicorp/consul-net-rpc/net/rpc"
"github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/connect"
external "github.com/hashicorp/consul/agent/grpc-external"
"github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/rpc/middleware" "github.com/hashicorp/consul/agent/rpc/middleware"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
@ -299,8 +300,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
} }
} }
srv, err := NewServer(c, deps, grpc.NewServer()) srv, err := NewServer(c, deps, external.NewServer(deps.Logger.Named("grpc.external"), deps.TLSConfigurator))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1990,7 +1990,7 @@ func (s *Store) deleteServiceTxn(tx WriteTxn, idx uint64, nodeName, serviceID st
} }
} }
psn := structs.PeeredServiceName{Peer: svc.PeerName, ServiceName: name} psn := structs.PeeredServiceName{Peer: svc.PeerName, ServiceName: name}
if err := freeServiceVirtualIP(tx, psn, nil); err != nil { if err := freeServiceVirtualIP(tx, idx, psn, nil); err != nil {
return fmt.Errorf("failed to clean up virtual IP for %q: %v", name.String(), err) return fmt.Errorf("failed to clean up virtual IP for %q: %v", name.String(), err)
} }
if err := cleanupKindServiceName(tx, idx, svc.CompoundServiceName(), svc.ServiceKind); err != nil { if err := cleanupKindServiceName(tx, idx, svc.CompoundServiceName(), svc.ServiceKind); err != nil {
@ -2008,6 +2008,7 @@ func (s *Store) deleteServiceTxn(tx WriteTxn, idx uint64, nodeName, serviceID st
// is removed. // is removed.
func freeServiceVirtualIP( func freeServiceVirtualIP(
tx WriteTxn, tx WriteTxn,
idx uint64,
psn structs.PeeredServiceName, psn structs.PeeredServiceName,
excludeGateway *structs.ServiceName, excludeGateway *structs.ServiceName,
) error { ) error {
@ -2059,6 +2060,10 @@ func freeServiceVirtualIP(
return fmt.Errorf("failed updating freed virtual IP table: %v", err) return fmt.Errorf("failed updating freed virtual IP table: %v", err)
} }
if err := updateVirtualIPMaxIndexes(tx, idx, psn.ServiceName.PartitionOrDefault(), psn.Peer); err != nil {
return err
}
return nil return nil
} }
@ -2907,6 +2912,25 @@ func (s *Store) GatewayServices(ws memdb.WatchSet, gateway string, entMeta *acl.
return lib.MaxUint64(maxIdx, idx), results, nil return lib.MaxUint64(maxIdx, idx), results, nil
} }
// TODO: Find a way to consolidate this with CheckIngressServiceNodes
// ServiceGateways is used to query all gateways associated with a service
func (s *Store) ServiceGateways(ws memdb.WatchSet, service string, kind structs.ServiceKind, entMeta acl.EnterpriseMeta) (uint64, structs.CheckServiceNodes, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// tableGatewayServices is not peer-aware, and the existence of TG/IG gateways is scrubbed during peer replication.
maxIdx, nodes, err := serviceGatewayNodes(tx, ws, service, kind, &entMeta, structs.DefaultPeerKeyword)
// Watch for index changes to the gateway nodes
idx, chans := maxIndexAndWatchChsForServiceNodes(tx, nodes, false)
for _, ch := range chans {
ws.Add(ch)
}
maxIdx = lib.MaxUint64(maxIdx, idx)
return parseCheckServiceNodes(tx, ws, maxIdx, nodes, &entMeta, structs.DefaultPeerKeyword, err)
}
func (s *Store) VirtualIPForService(psn structs.PeeredServiceName) (string, error) { func (s *Store) VirtualIPForService(psn structs.PeeredServiceName) (string, error) {
tx := s.db.Txn(false) tx := s.db.Txn(false)
defer tx.Abort() defer tx.Abort()
@ -3478,7 +3502,7 @@ func updateTerminatingGatewayVirtualIPs(tx WriteTxn, idx uint64, conf *structs.T
} }
if len(nodes) == 0 { if len(nodes) == 0 {
psn := structs.PeeredServiceName{Peer: structs.DefaultPeerKeyword, ServiceName: sn} psn := structs.PeeredServiceName{Peer: structs.DefaultPeerKeyword, ServiceName: sn}
if err := freeServiceVirtualIP(tx, psn, &gatewayName); err != nil { if err := freeServiceVirtualIP(tx, idx, psn, &gatewayName); err != nil {
return err return err
} }
} }
@ -3862,7 +3886,7 @@ func (s *Store) collectGatewayServices(tx ReadTxn, ws memdb.WatchSet, iter memdb
return maxIdx, results, nil return maxIdx, results, nil
} }
// TODO(ingress): How to handle index rolling back when a config entry is // TODO: How to handle index rolling back when a config entry is
// deleted that references a service? // deleted that references a service?
// We might need something like the service_last_extinction index? // We might need something like the service_last_extinction index?
func serviceGatewayNodes(tx ReadTxn, ws memdb.WatchSet, service string, kind structs.ServiceKind, entMeta *acl.EnterpriseMeta, peerName string) (uint64, structs.ServiceNodes, error) { func serviceGatewayNodes(tx ReadTxn, ws memdb.WatchSet, service string, kind structs.ServiceKind, entMeta *acl.EnterpriseMeta, peerName string) (uint64, structs.ServiceNodes, error) {

View File

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbcommon"
"github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/pbsubscribe" "github.com/hashicorp/consul/proto/pbsubscribe"
) )
@ -71,6 +72,39 @@ func (e EventPayloadCheckServiceNode) ToSubscriptionEvent(idx uint64) *pbsubscri
} }
} }
// EventPayloadServiceListUpdate is used as the Payload for a stream.Event when
// services (not service instances) are registered/deregistered. These events
// are used to materialize the list of services in a datacenter.
type EventPayloadServiceListUpdate struct {
Op pbsubscribe.CatalogOp
Name string
EnterpriseMeta acl.EnterpriseMeta
PeerName string
}
func (e *EventPayloadServiceListUpdate) ToSubscriptionEvent(idx uint64) *pbsubscribe.Event {
return &pbsubscribe.Event{
Index: idx,
Payload: &pbsubscribe.Event_Service{
Service: &pbsubscribe.ServiceListUpdate{
Op: e.Op,
Name: e.Name,
EnterpriseMeta: pbcommon.NewEnterpriseMetaFromStructs(e.EnterpriseMeta),
PeerName: e.PeerName,
},
},
}
}
func (e *EventPayloadServiceListUpdate) Subject() stream.Subject { return stream.SubjectNone }
func (e *EventPayloadServiceListUpdate) HasReadPermission(authz acl.Authorizer) bool {
var authzContext acl.AuthorizerContext
e.EnterpriseMeta.FillAuthzContext(&authzContext)
return authz.ServiceRead(e.Name, &authzContext) == acl.Allow
}
// serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot // serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot
// of stream.Events that describe the current state of a service health query. // of stream.Events that describe the current state of a service health query.
func (s *Store) ServiceHealthSnapshot(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) { func (s *Store) ServiceHealthSnapshot(req stream.SubscribeRequest, buf stream.SnapshotAppender) (index uint64, err error) {
@ -156,6 +190,65 @@ type nodeTuple struct {
var serviceChangeIndirect = serviceChange{changeType: changeIndirect} var serviceChangeIndirect = serviceChange{changeType: changeIndirect}
// ServiceListUpdateEventsFromChanges returns events representing changes to
// the list of services from the given set of state store changes.
func ServiceListUpdateEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) {
var events []stream.Event
for _, change := range changes.Changes {
if change.Table != tableKindServiceNames {
continue
}
kindName := changeObject(change).(*KindServiceName)
// TODO(peering): make this peer-aware.
payload := &EventPayloadServiceListUpdate{
Name: kindName.Service.Name,
EnterpriseMeta: kindName.Service.EnterpriseMeta,
}
if change.Deleted() {
payload.Op = pbsubscribe.CatalogOp_Deregister
} else {
payload.Op = pbsubscribe.CatalogOp_Register
}
events = append(events, stream.Event{
Topic: EventTopicServiceList,
Index: changes.Index,
Payload: payload,
})
}
return events, nil
}
// ServiceListSnapshot is a stream.SnapshotFunc that returns a snapshot of
// all service names.
func (s *Store) ServiceListSnapshot(_ stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) {
index, names, err := s.ServiceNamesOfKind(nil, "")
if err != nil {
return 0, err
}
if l := len(names); l > 0 {
events := make([]stream.Event, l)
for idx, name := range names {
events[idx] = stream.Event{
Topic: EventTopicServiceList,
Index: index,
Payload: &EventPayloadServiceListUpdate{
Op: pbsubscribe.CatalogOp_Register,
Name: name.Service.Name,
EnterpriseMeta: name.Service.EnterpriseMeta,
},
}
}
buf.Append(events)
}
return index, nil
}
// ServiceHealthEventsFromChanges returns all the service and Connect health // ServiceHealthEventsFromChanges returns all the service and Connect health
// events that should be emitted given a set of changes to the state store. // events that should be emitted given a set of changes to the state store.
func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) { func ServiceHealthEventsFromChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) {

View File

@ -8,6 +8,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
@ -1674,7 +1675,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
configEntryDest := &structs.ServiceConfigEntry{ configEntryDest := &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "destination1", Name: "destination1",
Destination: &structs.DestinationConfig{Port: 9000, Address: "kafka.test.com"}, Destination: &structs.DestinationConfig{Port: 9000, Addresses: []string{"kafka.test.com"}},
} }
return ensureConfigEntryTxn(tx, tx.Index, configEntryDest) return ensureConfigEntryTxn(tx, tx.Index, configEntryDest)
}, },
@ -1720,7 +1721,7 @@ func TestServiceHealthEventsFromChanges(t *testing.T) {
configEntryDest := &structs.ServiceConfigEntry{ configEntryDest := &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults, Kind: structs.ServiceDefaults,
Name: "destination1", Name: "destination1",
Destination: &structs.DestinationConfig{Port: 9000, Address: "kafka.test.com"}, Destination: &structs.DestinationConfig{Port: 9000, Addresses: []string{"kafka.test.com"}},
} }
return ensureConfigEntryTxn(tx, tx.Index, configEntryDest) return ensureConfigEntryTxn(tx, tx.Index, configEntryDest)
}, },
@ -2543,3 +2544,114 @@ func newPayloadCheckServiceNodeWithOverride(
overrideNamespace: overrideNamespace, overrideNamespace: overrideNamespace,
} }
} }
func TestServiceListUpdateSnapshot(t *testing.T) {
const index uint64 = 123
store := testStateStore(t)
require.NoError(t, store.EnsureRegistration(index, testServiceRegistration(t, "db")))
buf := &snapshotAppender{}
idx, err := store.ServiceListSnapshot(stream.SubscribeRequest{Subject: stream.SubjectNone}, buf)
require.NoError(t, err)
require.NotZero(t, idx)
require.Len(t, buf.events, 1)
require.Len(t, buf.events[0], 1)
payload := buf.events[0][0].Payload.(*EventPayloadServiceListUpdate)
require.Equal(t, pbsubscribe.CatalogOp_Register, payload.Op)
require.Equal(t, "db", payload.Name)
}
func TestServiceListUpdateEventsFromChanges(t *testing.T) {
const changeIndex = 123
testCases := map[string]struct {
setup func(*Store, *txn) error
mutate func(*Store, *txn) error
events []stream.Event
}{
"register new service": {
mutate: func(store *Store, tx *txn) error {
return store.ensureRegistrationTxn(tx, changeIndex, false, testServiceRegistration(t, "db"), false)
},
events: []stream.Event{
{
Topic: EventTopicServiceList,
Index: changeIndex,
Payload: &EventPayloadServiceListUpdate{
Op: pbsubscribe.CatalogOp_Register,
Name: "db",
EnterpriseMeta: *acl.DefaultEnterpriseMeta(),
},
},
},
},
"service already registered": {
setup: func(store *Store, tx *txn) error {
return store.ensureRegistrationTxn(tx, changeIndex, false, testServiceRegistration(t, "db"), false)
},
mutate: func(store *Store, tx *txn) error {
return store.ensureRegistrationTxn(tx, changeIndex, false, testServiceRegistration(t, "db"), false)
},
events: nil,
},
"deregister last instance of service": {
setup: func(store *Store, tx *txn) error {
return store.ensureRegistrationTxn(tx, changeIndex, false, testServiceRegistration(t, "db"), false)
},
mutate: func(store *Store, tx *txn) error {
return store.deleteServiceTxn(tx, tx.Index, "node1", "db", nil, "")
},
events: []stream.Event{
{
Topic: EventTopicServiceList,
Index: changeIndex,
Payload: &EventPayloadServiceListUpdate{
Op: pbsubscribe.CatalogOp_Deregister,
Name: "db",
EnterpriseMeta: *acl.DefaultEnterpriseMeta(),
},
},
},
},
"deregister (not the last) instance of service": {
setup: func(store *Store, tx *txn) error {
if err := store.ensureRegistrationTxn(tx, changeIndex, false, testServiceRegistration(t, "db"), false); err != nil {
return err
}
if err := store.ensureRegistrationTxn(tx, changeIndex, false, testServiceRegistration(t, "db", regNode2), false); err != nil {
return err
}
return nil
},
mutate: func(store *Store, tx *txn) error {
return store.deleteServiceTxn(tx, tx.Index, "node1", "db", nil, "")
},
events: nil,
},
}
for desc, tc := range testCases {
t.Run(desc, func(t *testing.T) {
store := testStateStore(t)
if tc.setup != nil {
tx := store.db.WriteTxn(0)
require.NoError(t, tc.setup(store, tx))
require.NoError(t, tx.Commit())
}
tx := store.db.WriteTxn(0)
t.Cleanup(tx.Abort)
if tc.mutate != nil {
require.NoError(t, tc.mutate(store, tx))
}
events, err := ServiceListUpdateEventsFromChanges(tx, Changes{Index: changeIndex, Changes: tx.Changes()})
require.NoError(t, err)
require.Equal(t, tc.events, events)
})
}
}

View File

@ -34,11 +34,11 @@ func testIndexerTableChecks() map[string]indexerTestCase {
Node: "NoDe", Node: "NoDe",
CheckID: "CheckId", CheckID: "CheckId",
}, },
expected: []byte("internal\x00node\x00checkid\x00"), expected: []byte("~\x00node\x00checkid\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00node\x00checkid\x00"), expected: []byte("~\x00node\x00checkid\x00"),
}, },
prefix: []indexValue{ prefix: []indexValue{
{ {
@ -47,7 +47,7 @@ func testIndexerTableChecks() map[string]indexerTestCase {
}, },
{ {
source: Query{Value: "nOdE"}, source: Query{Value: "nOdE"},
expected: []byte("internal\x00node\x00"), expected: []byte("~\x00node\x00"),
}, },
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
@ -77,11 +77,11 @@ func testIndexerTableChecks() map[string]indexerTestCase {
indexStatus: { indexStatus: {
read: indexValue{ read: indexValue{
source: Query{Value: "PASSING"}, source: Query{Value: "PASSING"},
expected: []byte("internal\x00passing\x00"), expected: []byte("~\x00passing\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00passing\x00"), expected: []byte("~\x00passing\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -99,11 +99,11 @@ func testIndexerTableChecks() map[string]indexerTestCase {
indexService: { indexService: {
read: indexValue{ read: indexValue{
source: Query{Value: "ServiceName"}, source: Query{Value: "ServiceName"},
expected: []byte("internal\x00servicename\x00"), expected: []byte("~\x00servicename\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00servicename\x00"), expected: []byte("~\x00servicename\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -124,11 +124,11 @@ func testIndexerTableChecks() map[string]indexerTestCase {
Node: "NoDe", Node: "NoDe",
Service: "SeRvIcE", Service: "SeRvIcE",
}, },
expected: []byte("internal\x00node\x00service\x00"), expected: []byte("~\x00node\x00service\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00node\x00service\x00"), expected: []byte("~\x00node\x00service\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -152,11 +152,11 @@ func testIndexerTableChecks() map[string]indexerTestCase {
source: Query{ source: Query{
Value: "NoDe", Value: "NoDe",
}, },
expected: []byte("internal\x00node\x00"), expected: []byte("~\x00node\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00node\x00"), expected: []byte("~\x00node\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -272,11 +272,11 @@ func testIndexerTableNodes() map[string]indexerTestCase {
indexID: { indexID: {
read: indexValue{ read: indexValue{
source: Query{Value: "NoDeId"}, source: Query{Value: "NoDeId"},
expected: []byte("internal\x00nodeid\x00"), expected: []byte("~\x00nodeid\x00"),
}, },
write: indexValue{ write: indexValue{
source: &structs.Node{Node: "NoDeId"}, source: &structs.Node{Node: "NoDeId"},
expected: []byte("internal\x00nodeid\x00"), expected: []byte("~\x00nodeid\x00"),
}, },
prefix: []indexValue{ prefix: []indexValue{
{ {
@ -289,11 +289,11 @@ func testIndexerTableNodes() map[string]indexerTestCase {
}, },
{ {
source: Query{Value: "NoDeId"}, source: Query{Value: "NoDeId"},
expected: []byte("internal\x00nodeid\x00"), expected: []byte("~\x00nodeid\x00"),
}, },
{ {
source: Query{}, source: Query{},
expected: []byte("internal\x00"), expected: []byte("~\x00"),
}, },
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
@ -322,27 +322,27 @@ func testIndexerTableNodes() map[string]indexerTestCase {
indexUUID: { indexUUID: {
read: indexValue{ read: indexValue{
source: Query{Value: uuid}, source: Query{Value: uuid},
expected: append([]byte("internal\x00"), uuidBuf...), expected: append([]byte("~\x00"), uuidBuf...),
}, },
write: indexValue{ write: indexValue{
source: &structs.Node{ source: &structs.Node{
ID: types.NodeID(uuid), ID: types.NodeID(uuid),
Node: "NoDeId", Node: "NoDeId",
}, },
expected: append([]byte("internal\x00"), uuidBuf...), expected: append([]byte("~\x00"), uuidBuf...),
}, },
prefix: []indexValue{ prefix: []indexValue{
{ // partial length { // partial length
source: Query{Value: uuid[:6]}, source: Query{Value: uuid[:6]},
expected: append([]byte("internal\x00"), uuidBuf[:3]...), expected: append([]byte("~\x00"), uuidBuf[:3]...),
}, },
{ // full length { // full length
source: Query{Value: uuid}, source: Query{Value: uuid},
expected: append([]byte("internal\x00"), uuidBuf...), expected: append([]byte("~\x00"), uuidBuf...),
}, },
{ {
source: Query{}, source: Query{},
expected: []byte("internal\x00"), expected: []byte("~\x00"),
}, },
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
@ -382,7 +382,7 @@ func testIndexerTableNodes() map[string]indexerTestCase {
Key: "KeY", Key: "KeY",
Value: "VaLuE", Value: "VaLuE",
}, },
expected: []byte("internal\x00KeY\x00VaLuE\x00"), expected: []byte("~\x00KeY\x00VaLuE\x00"),
}, },
writeMulti: indexValueMulti{ writeMulti: indexValueMulti{
source: &structs.Node{ source: &structs.Node{
@ -393,8 +393,8 @@ func testIndexerTableNodes() map[string]indexerTestCase {
}, },
}, },
expected: [][]byte{ expected: [][]byte{
[]byte("internal\x00MaP-kEy-1\x00mAp-VaL-1\x00"), []byte("~\x00MaP-kEy-1\x00mAp-VaL-1\x00"),
[]byte("internal\x00mAp-KeY-2\x00MaP-vAl-2\x00"), []byte("~\x00mAp-KeY-2\x00MaP-vAl-2\x00"),
}, },
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
@ -449,11 +449,11 @@ func testIndexerTableServices() map[string]indexerTestCase {
Node: "NoDeId", Node: "NoDeId",
Service: "SeRvIcE", Service: "SeRvIcE",
}, },
expected: []byte("internal\x00nodeid\x00service\x00"), expected: []byte("~\x00nodeid\x00service\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00nodeid\x00service\x00"), expected: []byte("~\x00nodeid\x00service\x00"),
}, },
prefix: []indexValue{ prefix: []indexValue{
{ {
@ -466,11 +466,11 @@ func testIndexerTableServices() map[string]indexerTestCase {
}, },
{ {
source: Query{}, source: Query{},
expected: []byte("internal\x00"), expected: []byte("~\x00"),
}, },
{ {
source: Query{Value: "NoDeId"}, source: Query{Value: "NoDeId"},
expected: []byte("internal\x00nodeid\x00"), expected: []byte("~\x00nodeid\x00"),
}, },
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
@ -505,11 +505,11 @@ func testIndexerTableServices() map[string]indexerTestCase {
source: Query{ source: Query{
Value: "NoDeId", Value: "NoDeId",
}, },
expected: []byte("internal\x00nodeid\x00"), expected: []byte("~\x00nodeid\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00nodeid\x00"), expected: []byte("~\x00nodeid\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -530,11 +530,11 @@ func testIndexerTableServices() map[string]indexerTestCase {
indexService: { indexService: {
read: indexValue{ read: indexValue{
source: Query{Value: "ServiceName"}, source: Query{Value: "ServiceName"},
expected: []byte("internal\x00servicename\x00"), expected: []byte("~\x00servicename\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00servicename\x00"), expected: []byte("~\x00servicename\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -552,14 +552,14 @@ func testIndexerTableServices() map[string]indexerTestCase {
indexConnect: { indexConnect: {
read: indexValue{ read: indexValue{
source: Query{Value: "ConnectName"}, source: Query{Value: "ConnectName"},
expected: []byte("internal\x00connectname\x00"), expected: []byte("~\x00connectname\x00"),
}, },
write: indexValue{ write: indexValue{
source: &structs.ServiceNode{ source: &structs.ServiceNode{
ServiceName: "ConnectName", ServiceName: "ConnectName",
ServiceConnect: structs.ServiceConnect{Native: true}, ServiceConnect: structs.ServiceConnect{Native: true},
}, },
expected: []byte("internal\x00connectname\x00"), expected: []byte("~\x00connectname\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -571,7 +571,7 @@ func testIndexerTableServices() map[string]indexerTestCase {
DestinationServiceName: "ConnectName", DestinationServiceName: "ConnectName",
}, },
}, },
expected: []byte("internal\x00connectname\x00"), expected: []byte("~\x00connectname\x00"),
}, },
}, },
{ {
@ -621,13 +621,13 @@ func testIndexerTableServices() map[string]indexerTestCase {
indexKind: { indexKind: {
read: indexValue{ read: indexValue{
source: Query{Value: "connect-proxy"}, source: Query{Value: "connect-proxy"},
expected: []byte("internal\x00connect-proxy\x00"), expected: []byte("~\x00connect-proxy\x00"),
}, },
write: indexValue{ write: indexValue{
source: &structs.ServiceNode{ source: &structs.ServiceNode{
ServiceKind: structs.ServiceKindConnectProxy, ServiceKind: structs.ServiceKindConnectProxy,
}, },
expected: []byte("internal\x00connect-proxy\x00"), expected: []byte("~\x00connect-proxy\x00"),
}, },
extra: []indexerTestCase{ extra: []indexerTestCase{
{ {
@ -636,7 +636,7 @@ func testIndexerTableServices() map[string]indexerTestCase {
ServiceName: "ServiceName", ServiceName: "ServiceName",
ServiceKind: structs.ServiceKindTypical, ServiceKind: structs.ServiceKindTypical,
}, },
expected: []byte("internal\x00\x00"), expected: []byte("~\x00\x00"),
}, },
}, },
{ {
@ -694,18 +694,18 @@ func testIndexerTableServiceVirtualIPs() map[string]indexerTestCase {
Name: "foo", Name: "foo",
}, },
}, },
expected: []byte("internal\x00foo\x00"), expected: []byte("~\x00foo\x00"),
}, },
write: indexValue{ write: indexValue{
source: obj, source: obj,
expected: []byte("internal\x00foo\x00"), expected: []byte("~\x00foo\x00"),
}, },
prefix: []indexValue{ prefix: []indexValue{
{ {
source: Query{ source: Query{
Value: "foo", Value: "foo",
}, },
expected: []byte("internal\x00foo\x00"), expected: []byte("~\x00foo\x00"),
}, },
{ {
source: Query{ source: Query{

View File

@ -4,6 +4,7 @@ import (
"context" "context"
crand "crypto/rand" crand "crypto/rand"
"fmt" "fmt"
"github.com/hashicorp/consul/acl"
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
@ -5346,6 +5347,400 @@ func TestStateStore_GatewayServices_Terminating(t *testing.T) {
assert.Len(t, out, 0) assert.Len(t, out, 0)
} }
func TestStateStore_ServiceGateways_Terminating(t *testing.T) {
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.GatewayServices(ws, "db", nil)
assert.Nil(t, err)
assert.Equal(t, uint64(0), idx)
assert.Len(t, nodes, 0)
// Create some nodes
assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(t, s.EnsureNode(12, &structs.Node{Node: "baz", Address: "127.0.0.2"}))
// Typical services and some consul services spread across two nodes
assert.Nil(t, s.EnsureService(13, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(15, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(16, "bar", &structs.NodeService{ID: "consul", Service: "consul", Tags: nil}))
assert.Nil(t, s.EnsureService(17, "bar", &structs.NodeService{ID: "consul", Service: "consul", Tags: nil}))
// Add ingress gateway and a connect proxy, neither should get picked up by terminating gateway
ingressNS := &structs.NodeService{
Kind: structs.ServiceKindIngressGateway,
ID: "ingress",
Service: "ingress",
Port: 8443,
}
assert.Nil(t, s.EnsureService(18, "baz", ingressNS))
proxyNS := &structs.NodeService{
Kind: structs.ServiceKindConnectProxy,
ID: "db proxy",
Service: "db proxy",
Proxy: structs.ConnectProxyConfig{
DestinationServiceName: "db",
},
Port: 8000,
}
assert.Nil(t, s.EnsureService(19, "foo", proxyNS))
// Register a gateway
assert.Nil(t, s.EnsureService(20, "baz", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443}))
// Associate gateway with db and api
assert.Nil(t, s.EnsureConfigEntry(21, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway",
Name: "gateway",
Services: []structs.LinkedService{
{
Name: "db",
},
{
Name: "api",
},
},
}))
assert.True(t, watchFired(ws))
// Read everything back.
ws = memdb.NewWatchSet()
idx, out, err := s.ServiceGateways(ws, "db", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
assert.Equal(t, uint64(21), idx)
assert.Len(t, out, 1)
expect := structs.CheckServiceNodes{
{
Node: &structs.Node{
ID: "",
Address: "127.0.0.2",
Node: "baz",
Partition: acl.DefaultPartitionName,
RaftIndex: structs.RaftIndex{
CreateIndex: 12,
ModifyIndex: 12,
},
},
Service: &structs.NodeService{
Service: "gateway",
Kind: structs.ServiceKindTerminatingGateway,
ID: "gateway",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
RaftIndex: structs.RaftIndex{
CreateIndex: 20,
ModifyIndex: 20,
},
},
},
}
assert.Equal(t, expect, out)
// Check that we don't update on same exact config
assert.Nil(t, s.EnsureConfigEntry(21, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway",
Name: "gateway",
Services: []structs.LinkedService{
{
Name: "db",
},
{
Name: "api",
},
},
}))
assert.False(t, watchFired(ws))
idx, out, err = s.ServiceGateways(ws, "api", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
assert.Equal(t, uint64(21), idx)
assert.Len(t, out, 1)
expect = structs.CheckServiceNodes{
{
Node: &structs.Node{
ID: "",
Address: "127.0.0.2",
Node: "baz",
Partition: acl.DefaultPartitionName,
RaftIndex: structs.RaftIndex{
CreateIndex: 12,
ModifyIndex: 12,
},
},
Service: &structs.NodeService{
Service: "gateway",
Kind: structs.ServiceKindTerminatingGateway,
ID: "gateway",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
RaftIndex: structs.RaftIndex{
CreateIndex: 20,
ModifyIndex: 20,
},
},
},
}
assert.Equal(t, expect, out)
// Associate gateway with a wildcard and add TLS config
assert.Nil(t, s.EnsureConfigEntry(22, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway",
Name: "gateway",
Services: []structs.LinkedService{
{
Name: "api",
CAFile: "api/ca.crt",
CertFile: "api/client.crt",
KeyFile: "api/client.key",
SNI: "my-domain",
},
{
Name: "db",
},
{
Name: "*",
CAFile: "ca.crt",
CertFile: "client.crt",
KeyFile: "client.key",
SNI: "my-alt-domain",
},
},
}))
assert.True(t, watchFired(ws))
// Read everything back.
ws = memdb.NewWatchSet()
idx, out, err = s.ServiceGateways(ws, "db", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
assert.Equal(t, uint64(22), idx)
assert.Len(t, out, 1)
expect = structs.CheckServiceNodes{
{
Node: &structs.Node{
ID: "",
Address: "127.0.0.2",
Node: "baz",
Partition: acl.DefaultPartitionName,
RaftIndex: structs.RaftIndex{
CreateIndex: 12,
ModifyIndex: 12,
},
},
Service: &structs.NodeService{
Service: "gateway",
Kind: structs.ServiceKindTerminatingGateway,
ID: "gateway",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
RaftIndex: structs.RaftIndex{
CreateIndex: 20,
ModifyIndex: 20,
},
},
},
}
assert.Equal(t, expect, out)
// Add a service covered by wildcard
assert.Nil(t, s.EnsureService(23, "bar", &structs.NodeService{ID: "redis", Service: "redis", Tags: nil, Address: "", Port: 6379}))
ws = memdb.NewWatchSet()
idx, out, err = s.ServiceGateways(ws, "redis", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
assert.Equal(t, uint64(23), idx)
assert.Len(t, out, 1)
expect = structs.CheckServiceNodes{
{
Node: &structs.Node{
ID: "",
Address: "127.0.0.2",
Node: "baz",
Partition: acl.DefaultPartitionName,
RaftIndex: structs.RaftIndex{
CreateIndex: 12,
ModifyIndex: 12,
},
},
Service: &structs.NodeService{
Service: "gateway",
Kind: structs.ServiceKindTerminatingGateway,
ID: "gateway",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
RaftIndex: structs.RaftIndex{
CreateIndex: 20,
ModifyIndex: 20,
},
},
},
}
assert.Equal(t, expect, out)
// Delete a service covered by wildcard
assert.Nil(t, s.DeleteService(24, "bar", "redis", structs.DefaultEnterpriseMetaInDefaultPartition(), ""))
assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
idx, out, err = s.ServiceGateways(ws, "redis", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
// TODO: wildcards don't keep the same extinction index
assert.Equal(t, uint64(0), idx)
assert.Len(t, out, 0)
// Update the entry that only leaves one service
assert.Nil(t, s.EnsureConfigEntry(25, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway",
Name: "gateway",
Services: []structs.LinkedService{
{
Name: "db",
},
},
}))
assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
idx, out, err = s.ServiceGateways(ws, "db", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
assert.Equal(t, uint64(25), idx)
assert.Len(t, out, 1)
// previously associated services should not be present
expect = structs.CheckServiceNodes{
{
Node: &structs.Node{
ID: "",
Address: "127.0.0.2",
Node: "baz",
Partition: acl.DefaultPartitionName,
RaftIndex: structs.RaftIndex{
CreateIndex: 12,
ModifyIndex: 12,
},
},
Service: &structs.NodeService{
Service: "gateway",
Kind: structs.ServiceKindTerminatingGateway,
ID: "gateway",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
RaftIndex: structs.RaftIndex{
CreateIndex: 20,
ModifyIndex: 20,
},
},
},
}
assert.Equal(t, expect, out)
// Attempt to associate a different gateway with services that include db
assert.Nil(t, s.EnsureConfigEntry(26, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway",
Name: "gateway2",
Services: []structs.LinkedService{
{
Name: "*",
},
},
}))
// check that watchset fired for new terminating gateway node service
assert.Nil(t, s.EnsureService(20, "baz", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway2", Service: "gateway2", Port: 443}))
assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
idx, out, err = s.ServiceGateways(ws, "db", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
assert.Equal(t, uint64(26), idx)
assert.Len(t, out, 2)
expect = structs.CheckServiceNodes{
{
Node: &structs.Node{
ID: "",
Address: "127.0.0.2",
Node: "baz",
Partition: acl.DefaultPartitionName,
RaftIndex: structs.RaftIndex{
CreateIndex: 12,
ModifyIndex: 12,
},
},
Service: &structs.NodeService{
Service: "gateway",
Kind: structs.ServiceKindTerminatingGateway,
ID: "gateway",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
RaftIndex: structs.RaftIndex{
CreateIndex: 20,
ModifyIndex: 20,
},
},
},
{
Node: &structs.Node{
ID: "",
Address: "127.0.0.2",
Node: "baz",
Partition: acl.DefaultPartitionName,
RaftIndex: structs.RaftIndex{
CreateIndex: 12,
ModifyIndex: 12,
},
},
Service: &structs.NodeService{
Service: "gateway2",
Kind: structs.ServiceKindTerminatingGateway,
ID: "gateway2",
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
Weights: &structs.Weights{Passing: 1, Warning: 1},
Port: 443,
RaftIndex: structs.RaftIndex{
CreateIndex: 20,
ModifyIndex: 20,
},
},
},
}
assert.Equal(t, expect, out)
// Deleting the all gateway's node services should trigger the watch and keep the raft index stable
assert.Nil(t, s.DeleteService(27, "baz", "gateway", structs.DefaultEnterpriseMetaInDefaultPartition(), structs.DefaultPeerKeyword))
assert.True(t, watchFired(ws))
assert.Nil(t, s.DeleteService(28, "baz", "gateway2", structs.DefaultEnterpriseMetaInDefaultPartition(), structs.DefaultPeerKeyword))
ws = memdb.NewWatchSet()
idx, out, err = s.ServiceGateways(ws, "db", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
assert.Equal(t, uint64(28), idx)
assert.Len(t, out, 0)
// Deleting the config entry even with a node service should remove existing mappings
assert.Nil(t, s.EnsureService(29, "baz", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443}))
assert.Nil(t, s.DeleteConfigEntry(30, "terminating-gateway", "gateway", nil))
assert.True(t, watchFired(ws))
idx, out, err = s.ServiceGateways(ws, "api", structs.ServiceKindTerminatingGateway, *structs.DefaultEnterpriseMetaInDefaultPartition())
assert.Nil(t, err)
// TODO: similar to ingress, the index can backslide if the config is deleted.
assert.Equal(t, uint64(28), idx)
assert.Len(t, out, 0)
}
func TestStateStore_GatewayServices_ServiceDeletion(t *testing.T) { func TestStateStore_GatewayServices_ServiceDeletion(t *testing.T) {
s := testStateStore(t) s := testStateStore(t)

View File

@ -43,6 +43,12 @@ func PBToStreamSubscribeRequest(req *pbsubscribe.SubscribeRequest, entMeta acl.E
Name: named.Key, Name: named.Key,
EnterpriseMeta: &entMeta, EnterpriseMeta: &entMeta,
} }
case EventTopicServiceList:
// Events on this topic are published to SubjectNone, but rather than
// exposing this in (and further complicating) the streaming API we rely
// on consumers passing WildcardSubject instead, which is functionally the
// same for this purpose.
return nil, fmt.Errorf("topic %s can only be consumed using WildcardSubject", EventTopicServiceList)
default: default:
return nil, fmt.Errorf("cannot construct subject for topic %s", req.Topic) return nil, fmt.Errorf("cannot construct subject for topic %s", req.Topic)
} }

View File

@ -184,6 +184,7 @@ var (
EventTopicServiceResolver = pbsubscribe.Topic_ServiceResolver EventTopicServiceResolver = pbsubscribe.Topic_ServiceResolver
EventTopicIngressGateway = pbsubscribe.Topic_IngressGateway EventTopicIngressGateway = pbsubscribe.Topic_IngressGateway
EventTopicServiceIntentions = pbsubscribe.Topic_ServiceIntentions EventTopicServiceIntentions = pbsubscribe.Topic_ServiceIntentions
EventTopicServiceList = pbsubscribe.Topic_ServiceList
) )
func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) { func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) {
@ -192,6 +193,7 @@ func processDBChanges(tx ReadTxn, changes Changes) ([]stream.Event, error) {
aclChangeUnsubscribeEvent, aclChangeUnsubscribeEvent,
caRootsChangeEvents, caRootsChangeEvents,
ServiceHealthEventsFromChanges, ServiceHealthEventsFromChanges,
ServiceListUpdateEventsFromChanges,
ConfigEntryEventsFromChanges, ConfigEntryEventsFromChanges,
// TODO: add other table handlers here. // TODO: add other table handlers here.
} }

View File

@ -213,6 +213,13 @@ func (s *Store) PeeringWrite(idx uint64, p *pbpeering.Peering) error {
return fmt.Errorf("cannot write to peering that is marked for deletion") return fmt.Errorf("cannot write to peering that is marked for deletion")
} }
if p.State == pbpeering.PeeringState_UNDEFINED {
p.State = existing.State
}
// TODO(peering): Confirm behavior when /peering/token is called more than once.
// We may need to avoid clobbering existing values.
p.ImportedServiceCount = existing.ImportedServiceCount
p.ExportedServiceCount = existing.ExportedServiceCount
p.CreateIndex = existing.CreateIndex p.CreateIndex = existing.CreateIndex
p.ModifyIndex = idx p.ModifyIndex = idx
} else { } else {
@ -346,7 +353,9 @@ func (s *Store) ExportedServicesForAllPeersByName(ws memdb.WatchSet, entMeta acl
} }
m := list.ListAllDiscoveryChains() m := list.ListAllDiscoveryChains()
if len(m) > 0 { if len(m) > 0 {
out[peering.Name] = maps.SliceOfKeys(m) sns := maps.SliceOfKeys[structs.ServiceName, structs.ExportedDiscoveryChainInfo](m)
sort.Sort(structs.ServiceList(sns))
out[peering.Name] = sns
} }
} }

View File

@ -6075,7 +6075,7 @@ func TestDNS_PreparedQuery_Failover(t *testing.T) {
Name: "my-query", Name: "my-query",
Service: structs.ServiceQuery{ Service: structs.ServiceQuery{
Service: "db", Service: "db",
Failover: structs.QueryDatacenterOptions{ Failover: structs.QueryFailoverOptions{
Datacenters: []string{"dc2"}, Datacenters: []string{"dc2"},
}, },
}, },

View File

@ -5,6 +5,8 @@ import (
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"time"
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/tlsutil"
@ -25,6 +27,12 @@ func NewServer(logger agentmiddleware.Logger, tls *tlsutil.Configurator) *grpc.S
// Add middlware interceptors to recover in case of panics. // Add middlware interceptors to recover in case of panics.
recovery.StreamServerInterceptor(recoveryOpts...), recovery.StreamServerInterceptor(recoveryOpts...),
), ),
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
// This must be less than the keealive.ClientParameters Time setting, otherwise
// the server will disconnect the client for sending too many keepalive pings.
// Currently the client param is set to 30s.
MinTime: 15 * time.Second,
}),
} }
if tls != nil && tls.GRPCTLSConfigured() { if tls != nil && tls.GRPCTLSConfigured() {
creds := credentials.NewTLS(tls.IncomingGRPCConfig()) creds := credentials.NewTLS(tls.IncomingGRPCConfig())

View File

@ -8,7 +8,11 @@ import (
// healthSnapshot represents a normalized view of a set of CheckServiceNodes // healthSnapshot represents a normalized view of a set of CheckServiceNodes
// meant for easy comparison to aid in differential synchronization // meant for easy comparison to aid in differential synchronization
type healthSnapshot struct { type healthSnapshot struct {
Nodes map[types.NodeID]*nodeSnapshot // Nodes is a map of a node name to a nodeSnapshot. Ideally we would be able to use
// the types.NodeID and assume they are UUIDs for the map key but Consul doesn't
// require a NodeID. Therefore we must key off of the only bit of ID material
// that is required which is the node name.
Nodes map[string]*nodeSnapshot
} }
type nodeSnapshot struct { type nodeSnapshot struct {
@ -40,20 +44,20 @@ func newHealthSnapshot(all []structs.CheckServiceNode, partition, peerName strin
} }
snap := &healthSnapshot{ snap := &healthSnapshot{
Nodes: make(map[types.NodeID]*nodeSnapshot), Nodes: make(map[string]*nodeSnapshot),
} }
for _, instance := range all { for _, instance := range all {
if instance.Node.ID == "" { if instance.Node.Node == "" {
panic("TODO(peering): data should always have a node ID") panic("TODO(peering): data should always have a node name")
} }
nodeSnap, ok := snap.Nodes[instance.Node.ID] nodeSnap, ok := snap.Nodes[instance.Node.Node]
if !ok { if !ok {
nodeSnap = &nodeSnapshot{ nodeSnap = &nodeSnapshot{
Node: instance.Node, Node: instance.Node,
Services: make(map[structs.ServiceID]*serviceSnapshot), Services: make(map[structs.ServiceID]*serviceSnapshot),
} }
snap.Nodes[instance.Node.ID] = nodeSnap snap.Nodes[instance.Node.Node] = nodeSnap
} }
if instance.Service.ID == "" { if instance.Service.ID == "" {

View File

@ -69,8 +69,8 @@ func TestHealthSnapshot(t *testing.T) {
}, },
}, },
expect: &healthSnapshot{ expect: &healthSnapshot{
Nodes: map[types.NodeID]*nodeSnapshot{ Nodes: map[string]*nodeSnapshot{
"abc-123": { "abc": {
Node: newNode("abc-123", "abc", "my-peer"), Node: newNode("abc-123", "abc", "my-peer"),
Services: map[structs.ServiceID]*serviceSnapshot{ Services: map[structs.ServiceID]*serviceSnapshot{
structs.NewServiceID("xyz-123", nil): { structs.NewServiceID("xyz-123", nil): {
@ -88,14 +88,14 @@ func TestHealthSnapshot(t *testing.T) {
name: "multiple", name: "multiple",
in: []structs.CheckServiceNode{ in: []structs.CheckServiceNode{
{ {
Node: newNode("abc-123", "abc", ""), Node: newNode("", "abc", ""),
Service: newService("xyz-123", 8080, ""), Service: newService("xyz-123", 8080, ""),
Checks: structs.HealthChecks{ Checks: structs.HealthChecks{
newCheck("abc", "xyz-123", ""), newCheck("abc", "xyz-123", ""),
}, },
}, },
{ {
Node: newNode("abc-123", "abc", ""), Node: newNode("", "abc", ""),
Service: newService("xyz-789", 8181, ""), Service: newService("xyz-789", 8181, ""),
Checks: structs.HealthChecks{ Checks: structs.HealthChecks{
newCheck("abc", "xyz-789", ""), newCheck("abc", "xyz-789", ""),
@ -110,9 +110,9 @@ func TestHealthSnapshot(t *testing.T) {
}, },
}, },
expect: &healthSnapshot{ expect: &healthSnapshot{
Nodes: map[types.NodeID]*nodeSnapshot{ Nodes: map[string]*nodeSnapshot{
"abc-123": { "abc": {
Node: newNode("abc-123", "abc", "my-peer"), Node: newNode("", "abc", "my-peer"),
Services: map[structs.ServiceID]*serviceSnapshot{ Services: map[structs.ServiceID]*serviceSnapshot{
structs.NewServiceID("xyz-123", nil): { structs.NewServiceID("xyz-123", nil): {
Service: newService("xyz-123", 8080, "my-peer"), Service: newService("xyz-123", 8080, "my-peer"),
@ -128,7 +128,7 @@ func TestHealthSnapshot(t *testing.T) {
}, },
}, },
}, },
"def-456": { "def": {
Node: newNode("def-456", "def", "my-peer"), Node: newNode("def-456", "def", "my-peer"),
Services: map[structs.ServiceID]*serviceSnapshot{ Services: map[structs.ServiceID]*serviceSnapshot{
structs.NewServiceID("xyz-456", nil): { structs.NewServiceID("xyz-456", nil): {

View File

@ -5,10 +5,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"google.golang.org/genproto/googleapis/rpc/code" "google.golang.org/genproto/googleapis/rpc/code"
newproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
@ -37,15 +36,24 @@ import (
// If there are no instances in the event, we consider that to be a de-registration. // If there are no instances in the event, we consider that to be a de-registration.
func makeServiceResponse( func makeServiceResponse(
logger hclog.Logger, logger hclog.Logger,
mst *MutableStatus,
update cache.UpdateEvent, update cache.UpdateEvent,
) (*pbpeerstream.ReplicationMessage_Response, error) { ) (*pbpeerstream.ReplicationMessage_Response, error) {
any, csn, err := marshalToProtoAny[*pbservice.IndexedCheckServiceNodes](update.Result) serviceName := strings.TrimPrefix(update.CorrelationID, subExportedService)
sn := structs.ServiceNameFromString(serviceName)
csn, ok := update.Result.(*pbservice.IndexedCheckServiceNodes)
if !ok {
return nil, fmt.Errorf("invalid type for service response: %T", update.Result)
}
export := &pbpeerstream.ExportedService{
Nodes: csn.Nodes,
}
any, err := anypb.New(export)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal: %w", err) return nil, fmt.Errorf("failed to marshal: %w", err)
} }
serviceName := strings.TrimPrefix(update.CorrelationID, subExportedService)
// If no nodes are present then it's due to one of: // If no nodes are present then it's due to one of:
// 1. The service is newly registered or exported and yielded a transient empty update. // 1. The service is newly registered or exported and yielded a transient empty update.
// 2. All instances of the service were de-registered. // 2. All instances of the service were de-registered.
@ -54,8 +62,10 @@ func makeServiceResponse(
// We don't distinguish when these three things occurred, but it's safe to send a DELETE Op in all cases, so we do that. // We don't distinguish when these three things occurred, but it's safe to send a DELETE Op in all cases, so we do that.
// Case #1 is a no-op for the importing peer. // Case #1 is a no-op for the importing peer.
if len(csn.Nodes) == 0 { if len(csn.Nodes) == 0 {
mst.RemoveExportedService(sn)
return &pbpeerstream.ReplicationMessage_Response{ return &pbpeerstream.ReplicationMessage_Response{
ResourceURL: pbpeerstream.TypeURLService, ResourceURL: pbpeerstream.TypeURLExportedService,
// TODO(peering): Nonce management // TODO(peering): Nonce management
Nonce: "", Nonce: "",
ResourceID: serviceName, ResourceID: serviceName,
@ -63,9 +73,11 @@ func makeServiceResponse(
}, nil }, nil
} }
mst.TrackExportedService(sn)
// If there are nodes in the response, we push them as an UPSERT operation. // If there are nodes in the response, we push them as an UPSERT operation.
return &pbpeerstream.ReplicationMessage_Response{ return &pbpeerstream.ReplicationMessage_Response{
ResourceURL: pbpeerstream.TypeURLService, ResourceURL: pbpeerstream.TypeURLExportedService,
// TODO(peering): Nonce management // TODO(peering): Nonce management
Nonce: "", Nonce: "",
ResourceID: serviceName, ResourceID: serviceName,
@ -84,7 +96,7 @@ func makeCARootsResponse(
} }
return &pbpeerstream.ReplicationMessage_Response{ return &pbpeerstream.ReplicationMessage_Response{
ResourceURL: pbpeerstream.TypeURLRoots, ResourceURL: pbpeerstream.TypeURLPeeringTrustBundle,
// TODO(peering): Nonce management // TODO(peering): Nonce management
Nonce: "", Nonce: "",
ResourceID: "roots", ResourceID: "roots",
@ -97,13 +109,13 @@ func makeCARootsResponse(
// the protobuf.Any type, the asserted T type, and any errors // the protobuf.Any type, the asserted T type, and any errors
// during marshalling or type assertion. // during marshalling or type assertion.
// `in` MUST be of type T or it returns an error. // `in` MUST be of type T or it returns an error.
func marshalToProtoAny[T proto.Message](in any) (*anypb.Any, T, error) { func marshalToProtoAny[T newproto.Message](in any) (*anypb.Any, T, error) {
typ, ok := in.(T) typ, ok := in.(T)
if !ok { if !ok {
var outType T var outType T
return nil, typ, fmt.Errorf("input type is not %T: %T", outType, in) return nil, typ, fmt.Errorf("input type is not %T: %T", outType, in)
} }
any, err := ptypes.MarshalAny(typ) any, err := anypb.New(typ)
if err != nil { if err != nil {
return nil, typ, err return nil, typ, err
} }
@ -113,7 +125,9 @@ func marshalToProtoAny[T proto.Message](in any) (*anypb.Any, T, error) {
func (s *Server) processResponse( func (s *Server) processResponse(
peerName string, peerName string,
partition string, partition string,
mutableStatus *MutableStatus,
resp *pbpeerstream.ReplicationMessage_Response, resp *pbpeerstream.ReplicationMessage_Response,
logger hclog.Logger,
) (*pbpeerstream.ReplicationMessage, error) { ) (*pbpeerstream.ReplicationMessage, error) {
if !pbpeerstream.KnownTypeURL(resp.ResourceURL) { if !pbpeerstream.KnownTypeURL(resp.ResourceURL) {
err := fmt.Errorf("received response for unknown resource type %q", resp.ResourceURL) err := fmt.Errorf("received response for unknown resource type %q", resp.ResourceURL)
@ -137,7 +151,7 @@ func (s *Server) processResponse(
), err ), err
} }
if err := s.handleUpsert(peerName, partition, resp.ResourceURL, resp.ResourceID, resp.Resource); err != nil { if err := s.handleUpsert(peerName, partition, mutableStatus, resp.ResourceURL, resp.ResourceID, resp.Resource, logger); err != nil {
return makeNACKReply( return makeNACKReply(
resp.ResourceURL, resp.ResourceURL,
resp.Nonce, resp.Nonce,
@ -149,7 +163,7 @@ func (s *Server) processResponse(
return makeACKReply(resp.ResourceURL, resp.Nonce), nil return makeACKReply(resp.ResourceURL, resp.Nonce), nil
case pbpeerstream.Operation_OPERATION_DELETE: case pbpeerstream.Operation_OPERATION_DELETE:
if err := s.handleDelete(peerName, partition, resp.ResourceURL, resp.ResourceID); err != nil { if err := s.handleDelete(peerName, partition, mutableStatus, resp.ResourceURL, resp.ResourceID, logger); err != nil {
return makeNACKReply( return makeNACKReply(
resp.ResourceURL, resp.ResourceURL,
resp.Nonce, resp.Nonce,
@ -178,25 +192,38 @@ func (s *Server) processResponse(
func (s *Server) handleUpsert( func (s *Server) handleUpsert(
peerName string, peerName string,
partition string, partition string,
mutableStatus *MutableStatus,
resourceURL string, resourceURL string,
resourceID string, resourceID string,
resource *anypb.Any, resource *anypb.Any,
logger hclog.Logger,
) error { ) error {
if resource.TypeUrl != resourceURL {
return fmt.Errorf("mismatched resourceURL %q and Any typeUrl %q", resourceURL, resource.TypeUrl)
}
switch resourceURL { switch resourceURL {
case pbpeerstream.TypeURLService: case pbpeerstream.TypeURLExportedService:
sn := structs.ServiceNameFromString(resourceID) sn := structs.ServiceNameFromString(resourceID)
sn.OverridePartition(partition) sn.OverridePartition(partition)
csn := &pbservice.IndexedCheckServiceNodes{} export := &pbpeerstream.ExportedService{}
if err := ptypes.UnmarshalAny(resource, csn); err != nil { if err := resource.UnmarshalTo(export); err != nil {
return fmt.Errorf("failed to unmarshal resource: %w", err) return fmt.Errorf("failed to unmarshal resource: %w", err)
} }
return s.handleUpdateService(peerName, partition, sn, csn) err := s.handleUpdateService(peerName, partition, sn, export)
if err != nil {
return fmt.Errorf("did not increment imported services count for service=%q: %w", sn.String(), err)
}
case pbpeerstream.TypeURLRoots: mutableStatus.TrackImportedService(sn)
return nil
case pbpeerstream.TypeURLPeeringTrustBundle:
roots := &pbpeering.PeeringTrustBundle{} roots := &pbpeering.PeeringTrustBundle{}
if err := ptypes.UnmarshalAny(resource, roots); err != nil { if err := resource.UnmarshalTo(roots); err != nil {
return fmt.Errorf("failed to unmarshal resource: %w", err) return fmt.Errorf("failed to unmarshal resource: %w", err)
} }
@ -219,7 +246,7 @@ func (s *Server) handleUpdateService(
peerName string, peerName string,
partition string, partition string,
sn structs.ServiceName, sn structs.ServiceName,
pbNodes *pbservice.IndexedCheckServiceNodes, export *pbpeerstream.ExportedService,
) error { ) error {
// Capture instances in the state store for reconciliation later. // Capture instances in the state store for reconciliation later.
_, storedInstances, err := s.GetStore().CheckServiceNodes(nil, sn.Name, &sn.EnterpriseMeta, peerName) _, storedInstances, err := s.GetStore().CheckServiceNodes(nil, sn.Name, &sn.EnterpriseMeta, peerName)
@ -227,7 +254,7 @@ func (s *Server) handleUpdateService(
return fmt.Errorf("failed to read imported services: %w", err) return fmt.Errorf("failed to read imported services: %w", err)
} }
structsNodes, err := pbNodes.CheckServiceNodesToStruct() structsNodes, err := export.CheckServiceNodesToStruct()
if err != nil { if err != nil {
return fmt.Errorf("failed to convert protobuf instances to structs: %w", err) return fmt.Errorf("failed to convert protobuf instances to structs: %w", err)
} }
@ -290,8 +317,8 @@ func (s *Server) handleUpdateService(
deletedNodeChecks = make(map[nodeCheckTuple]struct{}) deletedNodeChecks = make(map[nodeCheckTuple]struct{})
) )
for _, csn := range storedInstances { for _, csn := range storedInstances {
if _, ok := snap.Nodes[csn.Node.ID]; !ok { if _, ok := snap.Nodes[csn.Node.Node]; !ok {
unusedNodes[string(csn.Node.ID)] = struct{}{} unusedNodes[csn.Node.Node] = struct{}{}
// Since the node is not in the snapshot we can know the associated service // Since the node is not in the snapshot we can know the associated service
// instance is not in the snapshot either, since a service instance can't // instance is not in the snapshot either, since a service instance can't
@ -316,7 +343,7 @@ func (s *Server) handleUpdateService(
// Delete the service instance if not in the snapshot. // Delete the service instance if not in the snapshot.
sid := csn.Service.CompoundServiceID() sid := csn.Service.CompoundServiceID()
if _, ok := snap.Nodes[csn.Node.ID].Services[sid]; !ok { if _, ok := snap.Nodes[csn.Node.Node].Services[sid]; !ok {
err := s.Backend.CatalogDeregister(&structs.DeregisterRequest{ err := s.Backend.CatalogDeregister(&structs.DeregisterRequest{
Node: csn.Node.Node, Node: csn.Node.Node,
ServiceID: csn.Service.ID, ServiceID: csn.Service.ID,
@ -335,7 +362,7 @@ func (s *Server) handleUpdateService(
// Reconcile checks. // Reconcile checks.
for _, chk := range csn.Checks { for _, chk := range csn.Checks {
if _, ok := snap.Nodes[csn.Node.ID].Services[sid].Checks[chk.CheckID]; !ok { if _, ok := snap.Nodes[csn.Node.Node].Services[sid].Checks[chk.CheckID]; !ok {
// Checks without a ServiceID are node checks. // Checks without a ServiceID are node checks.
// If the node exists but the check does not then the check was deleted. // If the node exists but the check does not then the check was deleted.
if chk.ServiceID == "" { if chk.ServiceID == "" {
@ -425,14 +452,24 @@ func (s *Server) handleUpsertRoots(
func (s *Server) handleDelete( func (s *Server) handleDelete(
peerName string, peerName string,
partition string, partition string,
mutableStatus *MutableStatus,
resourceURL string, resourceURL string,
resourceID string, resourceID string,
logger hclog.Logger,
) error { ) error {
switch resourceURL { switch resourceURL {
case pbpeerstream.TypeURLService: case pbpeerstream.TypeURLExportedService:
sn := structs.ServiceNameFromString(resourceID) sn := structs.ServiceNameFromString(resourceID)
sn.OverridePartition(partition) sn.OverridePartition(partition)
return s.handleUpdateService(peerName, partition, sn, nil)
err := s.handleUpdateService(peerName, partition, sn, nil)
if err != nil {
return err
}
mutableStatus.RemoveImportedService(sn)
return nil
default: default:
return fmt.Errorf("unexpected resourceURL: %s", resourceURL) return fmt.Errorf("unexpected resourceURL: %s", resourceURL)

View File

@ -1,6 +1,8 @@
package peerstream package peerstream
import ( import (
"time"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -17,6 +19,11 @@ import (
// TODO(peering): fix up these interfaces to be more testable now that they are // TODO(peering): fix up these interfaces to be more testable now that they are
// extracted from private peering // extracted from private peering
const (
defaultOutgoingHeartbeatInterval = 15 * time.Second
defaultIncomingHeartbeatTimeout = 2 * time.Minute
)
type Server struct { type Server struct {
Config Config
} }
@ -30,6 +37,12 @@ type Config struct {
// Datacenter of the Consul server this gRPC server is hosted on // Datacenter of the Consul server this gRPC server is hosted on
Datacenter string Datacenter string
ConnectEnabled bool ConnectEnabled bool
// outgoingHeartbeatInterval is how often we send a heartbeat.
outgoingHeartbeatInterval time.Duration
// incomingHeartbeatTimeout is how long we'll wait between receiving heartbeats before we close the connection.
incomingHeartbeatTimeout time.Duration
} }
//go:generate mockery --name ACLResolver --inpackage //go:generate mockery --name ACLResolver --inpackage
@ -46,6 +59,12 @@ func NewServer(cfg Config) *Server {
if cfg.Datacenter == "" { if cfg.Datacenter == "" {
panic("Datacenter is required") panic("Datacenter is required")
} }
if cfg.outgoingHeartbeatInterval == 0 {
cfg.outgoingHeartbeatInterval = defaultOutgoingHeartbeatInterval
}
if cfg.incomingHeartbeatTimeout == 0 {
cfg.incomingHeartbeatTimeout = defaultIncomingHeartbeatTimeout
}
return &Server{ return &Server{
Config: cfg, Config: cfg,
} }

View File

@ -5,11 +5,12 @@ import (
"fmt" "fmt"
"io" "io"
"strings" "strings"
"sync"
"time"
"github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"google.golang.org/genproto/googleapis/rpc/code"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status" grpcstatus "google.golang.org/grpc/status"
@ -103,6 +104,7 @@ func (s *Server) StreamResources(stream pbpeerstream.PeerStreamService_StreamRes
RemoteID: "", RemoteID: "",
PeerName: p.Name, PeerName: p.Name,
Partition: p.Partition, Partition: p.Partition,
InitialResourceURL: req.ResourceURL,
Stream: stream, Stream: stream,
} }
err = s.HandleStream(streamReq) err = s.HandleStream(streamReq)
@ -129,6 +131,9 @@ type HandleStreamRequest struct {
// Partition is the local partition associated with the peer. // Partition is the local partition associated with the peer.
Partition string Partition string
// InitialResourceURL is the ResourceURL from the initial Request.
InitialResourceURL string
// Stream is the open stream to the peer cluster. // Stream is the open stream to the peer cluster.
Stream BidirectionalStream Stream BidirectionalStream
} }
@ -155,9 +160,19 @@ func (s *Server) DrainStream(req HandleStreamRequest) {
} }
} }
func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
if err := s.realHandleStream(streamReq); err != nil {
s.Tracker.DisconnectedDueToError(streamReq.LocalID, err.Error())
return err
}
// TODO(peering) Also need to clear subscriptions associated with the peer
s.Tracker.DisconnectedGracefully(streamReq.LocalID)
return nil
}
// The localID provided is the locally-generated identifier for the peering. // The localID provided is the locally-generated identifier for the peering.
// The remoteID is an identifier that the remote peer recognizes for the peering. // The remoteID is an identifier that the remote peer recognizes for the peering.
func (s *Server) HandleStream(streamReq HandleStreamRequest) error { func (s *Server) realHandleStream(streamReq HandleStreamRequest) error {
// TODO: pass logger down from caller? // TODO: pass logger down from caller?
logger := s.Logger.Named("stream"). logger := s.Logger.Named("stream").
With("peer_name", streamReq.PeerName). With("peer_name", streamReq.PeerName).
@ -170,9 +185,6 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
return fmt.Errorf("failed to register stream: %v", err) return fmt.Errorf("failed to register stream: %v", err)
} }
// TODO(peering) Also need to clear subscriptions associated with the peer
defer s.Tracker.Disconnected(streamReq.LocalID)
var trustDomain string var trustDomain string
if s.ConnectEnabled { if s.ConnectEnabled {
// Read the TrustDomain up front - we do not allow users to change the ClusterID // Read the TrustDomain up front - we do not allow users to change the ClusterID
@ -183,6 +195,13 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
} }
} }
remoteSubTracker := newResourceSubscriptionTracker()
if streamReq.InitialResourceURL != "" {
if remoteSubTracker.Subscribe(streamReq.InitialResourceURL) {
logger.Info("subscribing to resource type", "resourceURL", streamReq.InitialResourceURL)
}
}
mgr := newSubscriptionManager( mgr := newSubscriptionManager(
streamReq.Stream.Context(), streamReq.Stream.Context(),
logger, logger,
@ -190,24 +209,46 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
trustDomain, trustDomain,
s.Backend, s.Backend,
s.GetStore, s.GetStore,
remoteSubTracker,
) )
subCh := mgr.subscribe(streamReq.Stream.Context(), streamReq.LocalID, streamReq.PeerName, streamReq.Partition) subCh := mgr.subscribe(streamReq.Stream.Context(), streamReq.LocalID, streamReq.PeerName, streamReq.Partition)
// We need a mutex to protect against simultaneous sends to the client.
var sendMutex sync.Mutex
// streamSend is a helper function that sends msg over the stream
// respecting the send mutex. It also logs the send and calls status.TrackSendError
// on error.
streamSend := func(msg *pbpeerstream.ReplicationMessage) error {
logTraceSend(logger, msg)
sendMutex.Lock()
err := streamReq.Stream.Send(msg)
sendMutex.Unlock()
if err != nil {
status.TrackSendError(err.Error())
}
return err
}
// Subscribe to all relevant resource types.
for _, resourceURL := range []string{
pbpeerstream.TypeURLExportedService,
pbpeerstream.TypeURLPeeringTrustBundle,
} {
sub := makeReplicationRequest(&pbpeerstream.ReplicationMessage_Request{ sub := makeReplicationRequest(&pbpeerstream.ReplicationMessage_Request{
ResourceURL: pbpeerstream.TypeURLService, ResourceURL: resourceURL,
PeerID: streamReq.RemoteID, PeerID: streamReq.RemoteID,
}) })
logTraceSend(logger, sub) if err := streamSend(sub); err != nil {
if err := streamReq.Stream.Send(sub); err != nil {
if err == io.EOF { if err == io.EOF {
logger.Info("stream ended by peer") logger.Info("stream ended by peer")
status.TrackReceiveError(err.Error())
return nil return nil
} }
// TODO(peering) Test error handling in calls to Send/Recv // TODO(peering) Test error handling in calls to Send/Recv
status.TrackSendError(err.Error()) return fmt.Errorf("failed to send subscription for %q to stream: %w", resourceURL, err)
return fmt.Errorf("failed to send to stream: %v", err) }
} }
// TODO(peering): Should this be buffered? // TODO(peering): Should this be buffered?
@ -224,15 +265,49 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
if err == io.EOF { if err == io.EOF {
logger.Info("stream ended by peer") logger.Info("stream ended by peer")
status.TrackReceiveError(err.Error()) status.TrackRecvError(err.Error())
return return
} }
logger.Error("failed to receive from stream", "error", err) logger.Error("failed to receive from stream", "error", err)
status.TrackReceiveError(err.Error()) status.TrackRecvError(err.Error())
return return
} }
}() }()
// Heartbeat sender.
go func() {
tick := time.NewTicker(s.outgoingHeartbeatInterval)
defer tick.Stop()
for {
select {
case <-streamReq.Stream.Context().Done():
return
case <-tick.C:
}
heartbeat := &pbpeerstream.ReplicationMessage{
Payload: &pbpeerstream.ReplicationMessage_Heartbeat_{
Heartbeat: &pbpeerstream.ReplicationMessage_Heartbeat{},
},
}
if err := streamSend(heartbeat); err != nil {
logger.Warn("error sending heartbeat", "err", err)
}
}
}()
// incomingHeartbeatCtx will complete if incoming heartbeats time out.
incomingHeartbeatCtx, incomingHeartbeatCtxCancel :=
context.WithTimeout(context.Background(), s.incomingHeartbeatTimeout)
// NOTE: It's important that we wrap the call to cancel in a wrapper func because during the loop we're
// re-assigning the value of incomingHeartbeatCtxCancel and we want the defer to run on the last assigned
// value, not the current value.
defer func() {
incomingHeartbeatCtxCancel()
}()
for { for {
select { select {
// When the doneCh is closed that means that the peering was deleted locally. // When the doneCh is closed that means that the peering was deleted locally.
@ -244,10 +319,10 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
Terminated: &pbpeerstream.ReplicationMessage_Terminated{}, Terminated: &pbpeerstream.ReplicationMessage_Terminated{},
}, },
} }
logTraceSend(logger, term) if err := streamSend(term); err != nil {
// Nolint directive needed due to bug in govet that doesn't see that the cancel
if err := streamReq.Stream.Send(term); err != nil { // func of the incomingHeartbeatTimer _does_ get called.
status.TrackSendError(err.Error()) //nolint:govet
return fmt.Errorf("failed to send to stream: %v", err) return fmt.Errorf("failed to send to stream: %v", err)
} }
@ -256,10 +331,20 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
return nil return nil
// We haven't received a heartbeat within the expected interval. Kill the stream.
case <-incomingHeartbeatCtx.Done():
logger.Error("ending stream due to heartbeat timeout")
return fmt.Errorf("heartbeat timeout")
case msg, open := <-recvChan: case msg, open := <-recvChan:
if !open { if !open {
logger.Trace("no longer receiving data on the stream") // The only time we expect the stream to end is when we've received a "Terminated" message.
return nil // We handle the case of receiving the Terminated message below and then this function exits.
// So if the channel is closed while this function is still running then we haven't received a Terminated
// message which means we want to try and reestablish the stream.
// It's the responsibility of the caller of this function to reestablish the stream on error and so that's
// why we return an error here.
return fmt.Errorf("stream ended unexpectedly")
} }
// NOTE: this code should have similar error handling to the // NOTE: this code should have similar error handling to the
@ -284,17 +369,86 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
if !pbpeerstream.KnownTypeURL(req.ResourceURL) { if !pbpeerstream.KnownTypeURL(req.ResourceURL) {
return grpcstatus.Errorf(codes.InvalidArgument, "subscription request to unknown resource URL: %s", req.ResourceURL) return grpcstatus.Errorf(codes.InvalidArgument, "subscription request to unknown resource URL: %s", req.ResourceURL)
} }
switch {
case req.ResponseNonce == "":
// TODO(peering): This can happen on a client peer since they don't try to receive subscriptions before entering HandleStream.
// Should change that behavior or only allow it that one time.
case req.Error != nil && (req.Error.Code != int32(code.Code_OK) || req.Error.Message != ""): // There are different formats of requests depending upon where in the stream lifecycle we are.
//
// 1. Initial Request: This is the first request being received
// FROM the establishing peer. This is handled specially in
// (*Server).StreamResources BEFORE calling
// (*Server).HandleStream. This takes care of determining what
// the PeerID is for the stream. This is ALSO treated as (2) below.
//
// 2. Subscription Request: This is the first request for a
// given ResourceURL within a stream. The Initial Request (1)
// is always one of these as well.
//
// These must contain a valid ResourceURL with no Error or
// ResponseNonce set.
//
// It is valid to subscribe to the same ResourceURL twice
// within the lifetime of a stream, but all duplicate
// subscriptions are treated as no-ops upon receipt.
//
// 3. ACK Request: This is the message sent in reaction to an
// earlier Response to indicate that the response was processed
// by the other side successfully.
//
// These must contain a ResponseNonce and no Error.
//
// 4. NACK Request: This is the message sent in reaction to an
// earlier Response to indicate that the response was NOT
// processed by the other side successfully.
//
// These must contain a ResponseNonce and an Error.
//
if !remoteSubTracker.IsSubscribed(req.ResourceURL) {
// This must be a new subscription request to add a new
// resource type, vet it like a new request.
if !streamReq.WasDialed() {
if req.PeerID != "" && req.PeerID != streamReq.RemoteID {
// Not necessary after the first request from the dialer,
// but if provided must match.
return grpcstatus.Errorf(codes.InvalidArgument,
"initial subscription requests for a resource type must have consistent PeerID values: got=%q expected=%q",
req.PeerID,
streamReq.RemoteID,
)
}
}
if req.ResponseNonce != "" {
return grpcstatus.Error(codes.InvalidArgument, "initial subscription requests for a resource type must not contain a nonce")
}
if req.Error != nil {
return grpcstatus.Error(codes.InvalidArgument, "initial subscription request for a resource type must not contain an error")
}
if remoteSubTracker.Subscribe(req.ResourceURL) {
logger.Info("subscribing to resource type", "resourceURL", req.ResourceURL)
}
status.TrackAck()
continue
}
// At this point we have a valid ResourceURL and we are subscribed to it.
switch {
case req.ResponseNonce == "" && req.Error != nil:
return grpcstatus.Error(codes.InvalidArgument, "initial subscription request for a resource type must not contain an error")
case req.ResponseNonce != "" && req.Error == nil: // ACK
// TODO(peering): handle ACK fully
status.TrackAck()
case req.ResponseNonce != "" && req.Error != nil: // NACK
// TODO(peering): handle NACK fully
logger.Warn("client peer was unable to apply resource", "code", req.Error.Code, "error", req.Error.Message) logger.Warn("client peer was unable to apply resource", "code", req.Error.Code, "error", req.Error.Message)
status.TrackNack(fmt.Sprintf("client peer was unable to apply resource: %s", req.Error.Message)) status.TrackNack(fmt.Sprintf("client peer was unable to apply resource: %s", req.Error.Message))
default: default:
status.TrackAck() // This branch might be dead code, but it could also happen
// during a stray 're-subscribe' so just ignore the
// message.
} }
continue continue
@ -302,17 +456,15 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
if resp := msg.GetResponse(); resp != nil { if resp := msg.GetResponse(); resp != nil {
// TODO(peering): Ensure there's a nonce // TODO(peering): Ensure there's a nonce
reply, err := s.processResponse(streamReq.PeerName, streamReq.Partition, resp) reply, err := s.processResponse(streamReq.PeerName, streamReq.Partition, status, resp, logger)
if err != nil { if err != nil {
logger.Error("failed to persist resource", "resourceURL", resp.ResourceURL, "resourceID", resp.ResourceID) logger.Error("failed to persist resource", "resourceURL", resp.ResourceURL, "resourceID", resp.ResourceID)
status.TrackReceiveError(err.Error()) status.TrackRecvError(err.Error())
} else { } else {
status.TrackReceiveSuccess() status.TrackRecvResourceSuccess()
} }
logTraceSend(logger, reply) if err := streamSend(reply); err != nil {
if err := streamReq.Stream.Send(reply); err != nil {
status.TrackSendError(err.Error())
return fmt.Errorf("failed to send to stream: %v", err) return fmt.Errorf("failed to send to stream: %v", err)
} }
@ -329,11 +481,27 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
return nil return nil
} }
if msg.GetHeartbeat() != nil {
status.TrackRecvHeartbeat()
// Reset the heartbeat timeout by creating a new context.
// We first must cancel the old context so there's no leaks. This is safe to do because we're only
// reading that context within this for{} loop, and so we won't accidentally trigger the heartbeat
// timeout.
incomingHeartbeatCtxCancel()
// NOTE: IDEs and govet think that the reassigned cancel below never gets
// called, but it does by the defer when the heartbeat ctx is first created.
// They just can't trace the execution properly for some reason (possibly golang/go#29587).
//nolint:govet
incomingHeartbeatCtx, incomingHeartbeatCtxCancel =
context.WithTimeout(context.Background(), s.incomingHeartbeatTimeout)
}
case update := <-subCh: case update := <-subCh:
var resp *pbpeerstream.ReplicationMessage_Response var resp *pbpeerstream.ReplicationMessage_Response
switch { switch {
case strings.HasPrefix(update.CorrelationID, subExportedService): case strings.HasPrefix(update.CorrelationID, subExportedService):
resp, err = makeServiceResponse(logger, update) resp, err = makeServiceResponse(logger, status, update)
if err != nil { if err != nil {
// Log the error and skip this response to avoid locking up peering due to a bad update event. // Log the error and skip this response to avoid locking up peering due to a bad update event.
logger.Error("failed to create service response", "error", err) logger.Error("failed to create service response", "error", err)
@ -360,10 +528,7 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error {
} }
replResp := makeReplicationResponse(resp) replResp := makeReplicationResponse(resp)
if err := streamSend(replResp); err != nil {
logTraceSend(logger, replResp)
if err := streamReq.Stream.Send(replResp); err != nil {
status.TrackSendError(err.Error())
return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err) return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err)
} }
} }
@ -383,8 +548,8 @@ func getTrustDomain(store StateStore, logger hclog.Logger) (string, error) {
return connect.SpiffeIDSigningForCluster(cfg.ClusterID).Host(), nil return connect.SpiffeIDSigningForCluster(cfg.ClusterID).Host(), nil
} }
func (s *Server) StreamStatus(peer string) (resp Status, found bool) { func (s *Server) StreamStatus(peerID string) (resp Status, found bool) {
return s.Tracker.StreamStatus(peer) return s.Tracker.StreamStatus(peerID)
} }
// ConnectedStreams returns a map of connected stream IDs to the corresponding channel for tearing them down. // ConnectedStreams returns a map of connected stream IDs to the corresponding channel for tearing them down.
@ -420,3 +585,63 @@ func logTraceProto(logger hclog.Logger, pb proto.Message, received bool) {
logger.Trace("replication message", "direction", dir, "protobuf", out) logger.Trace("replication message", "direction", dir, "protobuf", out)
} }
// resourceSubscriptionTracker is used to keep track of the ResourceURLs that a
// stream has subscribed to and can notify you when a subscription comes in by
// closing the channels returned by SubscribedChan.
type resourceSubscriptionTracker struct {
// notifierMap keeps track of a notification channel for each resourceURL.
// Keys may exist in here even when they do not exist in 'subscribed' as
// calling SubscribedChan has to possibly create and and hand out a
// notification channel in advance of any notification.
notifierMap map[string]chan struct{}
// subscribed is a set that keeps track of resourceURLs that are currently
// subscribed to. Keys are never deleted. If a key is present in this map
// it is also present in 'notifierMap'.
subscribed map[string]struct{}
}
func newResourceSubscriptionTracker() *resourceSubscriptionTracker {
return &resourceSubscriptionTracker{
subscribed: make(map[string]struct{}),
notifierMap: make(map[string]chan struct{}),
}
}
// IsSubscribed returns true if the given ResourceURL has an active subscription.
func (t *resourceSubscriptionTracker) IsSubscribed(resourceURL string) bool {
_, ok := t.subscribed[resourceURL]
return ok
}
// Subscribe subscribes to the given ResourceURL. It will return true if this
// was the FIRST time a subscription occurred. It will also close the
// notification channel associated with this ResourceURL.
func (t *resourceSubscriptionTracker) Subscribe(resourceURL string) bool {
if _, ok := t.subscribed[resourceURL]; ok {
return false
}
t.subscribed[resourceURL] = struct{}{}
// and notify
ch := t.ensureNotifierChan(resourceURL)
close(ch)
return true
}
// SubscribedChan returns a channel that will be closed when the ResourceURL is
// subscribed using the Subscribe method.
func (t *resourceSubscriptionTracker) SubscribedChan(resourceURL string) <-chan struct{} {
return t.ensureNotifierChan(resourceURL)
}
func (t *resourceSubscriptionTracker) ensureNotifierChan(resourceURL string) chan struct{} {
if ch, ok := t.notifierMap[resourceURL]; ok {
return ch
}
ch := make(chan struct{})
t.notifierMap[resourceURL] = ch
return ch
}

File diff suppressed because it is too large Load Diff

View File

@ -4,9 +4,11 @@ import (
"fmt" "fmt"
"sync" "sync"
"time" "time"
"github.com/hashicorp/consul/agent/structs"
) )
// Tracker contains a map of (PeerID -> Status). // Tracker contains a map of (PeerID -> MutableStatus).
// As streams are opened and closed we track details about their status. // As streams are opened and closed we track details about their status.
type Tracker struct { type Tracker struct {
mu sync.RWMutex mu sync.RWMutex
@ -31,16 +33,37 @@ func (t *Tracker) SetClock(clock func() time.Time) {
} }
} }
// Register a stream for a given peer but do not mark it as connected.
func (t *Tracker) Register(id string) (*MutableStatus, error) {
t.mu.Lock()
defer t.mu.Unlock()
status, _, err := t.registerLocked(id, false)
return status, err
}
func (t *Tracker) registerLocked(id string, initAsConnected bool) (*MutableStatus, bool, error) {
status, ok := t.streams[id]
if !ok {
status = newMutableStatus(t.timeNow, initAsConnected)
t.streams[id] = status
return status, true, nil
}
return status, false, nil
}
// Connected registers a stream for a given peer, and marks it as connected. // Connected registers a stream for a given peer, and marks it as connected.
// It also enforces that there is only one active stream for a peer. // It also enforces that there is only one active stream for a peer.
func (t *Tracker) Connected(id string) (*MutableStatus, error) { func (t *Tracker) Connected(id string) (*MutableStatus, error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
return t.connectedLocked(id)
}
status, ok := t.streams[id] func (t *Tracker) connectedLocked(id string) (*MutableStatus, error) {
if !ok { status, newlyRegistered, err := t.registerLocked(id, true)
status = newMutableStatus(t.timeNow) if err != nil {
t.streams[id] = status return nil, err
} else if newlyRegistered {
return status, nil return status, nil
} }
@ -52,13 +75,23 @@ func (t *Tracker) Connected(id string) (*MutableStatus, error) {
return status, nil return status, nil
} }
// Disconnected ensures that if a peer id's stream status is tracked, it is marked as disconnected. // DisconnectedGracefully marks the peer id's stream status as disconnected gracefully.
func (t *Tracker) Disconnected(id string) { func (t *Tracker) DisconnectedGracefully(id string) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
if status, ok := t.streams[id]; ok { if status, ok := t.streams[id]; ok {
status.TrackDisconnected() status.TrackDisconnectedGracefully()
}
}
// DisconnectedDueToError marks the peer id's stream status as disconnected due to an error.
func (t *Tracker) DisconnectedDueToError(id string, error string) {
t.mu.Lock()
defer t.mu.Unlock()
if status, ok := t.streams[id]; ok {
status.TrackDisconnectedDueToError(error)
} }
} }
@ -112,6 +145,10 @@ type Status struct {
// Connected is true when there is an open stream for the peer. // Connected is true when there is an open stream for the peer.
Connected bool Connected bool
// DisconnectErrorMessage tracks the error that caused the stream to disconnect non-gracefully.
// If the stream is connected or it disconnected gracefully it will be empty.
DisconnectErrorMessage string
// If the status is not connected, DisconnectTime tracks when the stream was closed. Else it's zero. // If the status is not connected, DisconnectTime tracks when the stream was closed. Else it's zero.
DisconnectTime time.Time DisconnectTime time.Time
@ -130,24 +167,39 @@ type Status struct {
// LastSendErrorMessage tracks the last error message when sending into the stream. // LastSendErrorMessage tracks the last error message when sending into the stream.
LastSendErrorMessage string LastSendErrorMessage string
// LastReceiveSuccess tracks the time we last successfully stored a resource replicated FROM the peer. // LastRecvHeartbeat tracks when we last received a heartbeat from our peer.
LastReceiveSuccess time.Time LastRecvHeartbeat time.Time
// LastReceiveError tracks either: // LastRecvResourceSuccess tracks the time we last successfully stored a resource replicated FROM the peer.
LastRecvResourceSuccess time.Time
// LastRecvError tracks either:
// - The time we failed to store a resource replicated FROM the peer. // - The time we failed to store a resource replicated FROM the peer.
// - The time of the last error when receiving from the stream. // - The time of the last error when receiving from the stream.
LastReceiveError time.Time LastRecvError time.Time
// LastReceiveError tracks either: // LastRecvErrorMessage tracks the last error message when receiving from the stream.
// - The error message when we failed to store a resource replicated FROM the peer. LastRecvErrorMessage string
// - The last error message when receiving from the stream.
LastReceiveErrorMessage string // TODO(peering): consider keeping track of imported and exported services thru raft
// ImportedServices keeps track of which service names are imported for the peer
ImportedServices map[string]struct{}
// ExportedServices keeps track of which service names a peer asks to export
ExportedServices map[string]struct{}
} }
func newMutableStatus(now func() time.Time) *MutableStatus { func (s *Status) GetImportedServicesCount() uint64 {
return uint64(len(s.ImportedServices))
}
func (s *Status) GetExportedServicesCount() uint64 {
return uint64(len(s.ExportedServices))
}
func newMutableStatus(now func() time.Time, connected bool) *MutableStatus {
return &MutableStatus{ return &MutableStatus{
Status: Status{ Status: Status{
Connected: true, Connected: connected,
}, },
timeNow: now, timeNow: now,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
@ -171,16 +223,24 @@ func (s *MutableStatus) TrackSendError(error string) {
s.mu.Unlock() s.mu.Unlock()
} }
func (s *MutableStatus) TrackReceiveSuccess() { // TrackRecvResourceSuccess tracks receiving a replicated resource.
func (s *MutableStatus) TrackRecvResourceSuccess() {
s.mu.Lock() s.mu.Lock()
s.LastReceiveSuccess = s.timeNow().UTC() s.LastRecvResourceSuccess = s.timeNow().UTC()
s.mu.Unlock() s.mu.Unlock()
} }
func (s *MutableStatus) TrackReceiveError(error string) { // TrackRecvHeartbeat tracks receiving a heartbeat from our peer.
func (s *MutableStatus) TrackRecvHeartbeat() {
s.mu.Lock() s.mu.Lock()
s.LastReceiveError = s.timeNow().UTC() s.LastRecvHeartbeat = s.timeNow().UTC()
s.LastReceiveErrorMessage = error s.mu.Unlock()
}
func (s *MutableStatus) TrackRecvError(error string) {
s.mu.Lock()
s.LastRecvError = s.timeNow().UTC()
s.LastRecvErrorMessage = error
s.mu.Unlock() s.mu.Unlock()
} }
@ -195,13 +255,27 @@ func (s *MutableStatus) TrackConnected() {
s.mu.Lock() s.mu.Lock()
s.Connected = true s.Connected = true
s.DisconnectTime = time.Time{} s.DisconnectTime = time.Time{}
s.DisconnectErrorMessage = ""
s.mu.Unlock() s.mu.Unlock()
} }
func (s *MutableStatus) TrackDisconnected() { // TrackDisconnectedGracefully tracks when the stream was disconnected in a way we expected.
// For example, we got a terminated message, or we terminated the stream ourselves.
func (s *MutableStatus) TrackDisconnectedGracefully() {
s.mu.Lock() s.mu.Lock()
s.Connected = false s.Connected = false
s.DisconnectTime = s.timeNow().UTC() s.DisconnectTime = s.timeNow().UTC()
s.DisconnectErrorMessage = ""
s.mu.Unlock()
}
// TrackDisconnectedDueToError tracks when the stream was disconnected due to an error.
// For example the heartbeat timed out, or we couldn't send into the stream.
func (s *MutableStatus) TrackDisconnectedDueToError(error string) {
s.mu.Lock()
s.Connected = false
s.DisconnectTime = s.timeNow().UTC()
s.DisconnectErrorMessage = error
s.mu.Unlock() s.mu.Unlock()
} }
@ -222,3 +296,53 @@ func (s *MutableStatus) GetStatus() Status {
return copy return copy
} }
func (s *MutableStatus) RemoveImportedService(sn structs.ServiceName) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.ImportedServices, sn.String())
}
func (s *MutableStatus) TrackImportedService(sn structs.ServiceName) {
s.mu.Lock()
defer s.mu.Unlock()
if s.ImportedServices == nil {
s.ImportedServices = make(map[string]struct{})
}
s.ImportedServices[sn.String()] = struct{}{}
}
func (s *MutableStatus) GetImportedServicesCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.ImportedServices)
}
func (s *MutableStatus) RemoveExportedService(sn structs.ServiceName) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.ExportedServices, sn.String())
}
func (s *MutableStatus) TrackExportedService(sn structs.ServiceName) {
s.mu.Lock()
defer s.mu.Unlock()
if s.ExportedServices == nil {
s.ExportedServices = make(map[string]struct{})
}
s.ExportedServices[sn.String()] = struct{}{}
}
func (s *MutableStatus) GetExportedServicesCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.ExportedServices)
}

View File

@ -62,7 +62,7 @@ func TestTracker_EnsureConnectedDisconnected(t *testing.T) {
}) })
testutil.RunStep(t, "disconnect", func(t *testing.T) { testutil.RunStep(t, "disconnect", func(t *testing.T) {
tracker.Disconnected(peerID) tracker.DisconnectedGracefully(peerID)
sequence++ sequence++
expect := Status{ expect := Status{
@ -147,7 +147,7 @@ func TestTracker_connectedStreams(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Mark foo as disconnected to avoid showing it as an active stream // Mark foo as disconnected to avoid showing it as an active stream
status.TrackDisconnected() status.TrackDisconnectedGracefully()
_, err = s.Connected("bar") _, err = s.Connected("bar")
require.NoError(t, err) require.NoError(t, err)
@ -162,3 +162,61 @@ func TestTracker_connectedStreams(t *testing.T) {
}) })
} }
} }
func TestMutableStatus_TrackConnected(t *testing.T) {
s := MutableStatus{
Status: Status{
Connected: false,
DisconnectTime: time.Now(),
DisconnectErrorMessage: "disconnected",
},
}
s.TrackConnected()
require.True(t, s.IsConnected())
require.True(t, s.Connected)
require.Equal(t, time.Time{}, s.DisconnectTime)
require.Empty(t, s.DisconnectErrorMessage)
}
func TestMutableStatus_TrackDisconnectedGracefully(t *testing.T) {
it := incrementalTime{
base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
}
disconnectTime := it.FutureNow(1)
s := MutableStatus{
timeNow: it.Now,
Status: Status{
Connected: true,
},
}
s.TrackDisconnectedGracefully()
require.False(t, s.IsConnected())
require.False(t, s.Connected)
require.Equal(t, disconnectTime, s.DisconnectTime)
require.Empty(t, s.DisconnectErrorMessage)
}
func TestMutableStatus_TrackDisconnectedDueToError(t *testing.T) {
it := incrementalTime{
base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
}
disconnectTime := it.FutureNow(1)
s := MutableStatus{
timeNow: it.Now,
Status: Status{
Connected: true,
},
}
s.TrackDisconnectedDueToError("disconnect err")
require.False(t, s.IsConnected())
require.False(t, s.Connected)
require.Equal(t, disconnectTime, s.DisconnectTime)
require.Equal(t, "disconnect err", s.DisconnectErrorMessage)
}

View File

@ -19,6 +19,13 @@ import (
// streaming machinery instead to be cheaper. // streaming machinery instead to be cheaper.
func (m *subscriptionManager) notifyExportedServicesForPeerID(ctx context.Context, state *subscriptionState, peerID string) { func (m *subscriptionManager) notifyExportedServicesForPeerID(ctx context.Context, state *subscriptionState, peerID string) {
// Wait until this is subscribed-to.
select {
case <-m.serviceSubReady:
case <-ctx.Done():
return
}
// syncSubscriptionsAndBlock ensures that the subscriptions to the subscription backend // syncSubscriptionsAndBlock ensures that the subscriptions to the subscription backend
// match the list of services exported to the peer. // match the list of services exported to the peer.
m.syncViaBlockingQuery(ctx, "exported-services", func(ctx context.Context, store StateStore, ws memdb.WatchSet) (interface{}, error) { m.syncViaBlockingQuery(ctx, "exported-services", func(ctx context.Context, store StateStore, ws memdb.WatchSet) (interface{}, error) {
@ -34,6 +41,13 @@ func (m *subscriptionManager) notifyExportedServicesForPeerID(ctx context.Contex
// TODO: add a new streaming subscription type to list-by-kind-and-partition since we're getting evictions // TODO: add a new streaming subscription type to list-by-kind-and-partition since we're getting evictions
func (m *subscriptionManager) notifyMeshGatewaysForPartition(ctx context.Context, state *subscriptionState, partition string) { func (m *subscriptionManager) notifyMeshGatewaysForPartition(ctx context.Context, state *subscriptionState, partition string) {
// Wait until this is subscribed-to.
select {
case <-m.serviceSubReady:
case <-ctx.Done():
return
}
m.syncViaBlockingQuery(ctx, "mesh-gateways", func(ctx context.Context, store StateStore, ws memdb.WatchSet) (interface{}, error) { m.syncViaBlockingQuery(ctx, "mesh-gateways", func(ctx context.Context, store StateStore, ws memdb.WatchSet) (interface{}, error) {
// Fetch our current list of all mesh gateways. // Fetch our current list of all mesh gateways.
entMeta := structs.DefaultEnterpriseMetaInPartition(partition) entMeta := structs.DefaultEnterpriseMetaInPartition(partition)

View File

@ -19,6 +19,7 @@ import (
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/pbcommon" "github.com/hashicorp/consul/proto/pbcommon"
"github.com/hashicorp/consul/proto/pbpeering" "github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/proto/pbpeerstream"
"github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/consul/proto/pbservice"
) )
@ -39,6 +40,8 @@ type subscriptionManager struct {
viewStore MaterializedViewStore viewStore MaterializedViewStore
backend SubscriptionBackend backend SubscriptionBackend
getStore func() StateStore getStore func() StateStore
serviceSubReady <-chan struct{}
trustBundlesSubReady <-chan struct{}
} }
// TODO(peering): Maybe centralize so that there is a single manager per datacenter, rather than per peering. // TODO(peering): Maybe centralize so that there is a single manager per datacenter, rather than per peering.
@ -49,6 +52,7 @@ func newSubscriptionManager(
trustDomain string, trustDomain string,
backend SubscriptionBackend, backend SubscriptionBackend,
getStore func() StateStore, getStore func() StateStore,
remoteSubTracker *resourceSubscriptionTracker,
) *subscriptionManager { ) *subscriptionManager {
logger = logger.Named("subscriptions") logger = logger.Named("subscriptions")
store := submatview.NewStore(logger.Named("viewstore")) store := submatview.NewStore(logger.Named("viewstore"))
@ -61,6 +65,8 @@ func newSubscriptionManager(
viewStore: store, viewStore: store,
backend: backend, backend: backend,
getStore: getStore, getStore: getStore,
serviceSubReady: remoteSubTracker.SubscribedChan(pbpeerstream.TypeURLExportedService),
trustBundlesSubReady: remoteSubTracker.SubscribedChan(pbpeerstream.TypeURLPeeringTrustBundle),
} }
} }
@ -297,6 +303,13 @@ func (m *subscriptionManager) notifyRootCAUpdatesForPartition(
updateCh chan<- cache.UpdateEvent, updateCh chan<- cache.UpdateEvent,
partition string, partition string,
) { ) {
// Wait until this is subscribed-to.
select {
case <-m.trustBundlesSubReady:
case <-ctx.Done():
return
}
var idx uint64 var idx uint64
// TODO(peering): retry logic; fail past a threshold // TODO(peering): retry logic; fail past a threshold
for { for {

View File

@ -16,6 +16,7 @@ import (
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbcommon" "github.com/hashicorp/consul/proto/pbcommon"
"github.com/hashicorp/consul/proto/pbpeering" "github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/proto/pbpeerstream"
"github.com/hashicorp/consul/proto/pbservice" "github.com/hashicorp/consul/proto/pbservice"
"github.com/hashicorp/consul/proto/prototest" "github.com/hashicorp/consul/proto/prototest"
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
@ -32,12 +33,16 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
_, id := backend.ensurePeering(t, "my-peering") _, id := backend.ensurePeering(t, "my-peering")
partition := acl.DefaultEnterpriseMeta().PartitionOrEmpty() partition := acl.DefaultEnterpriseMeta().PartitionOrEmpty()
// Only configure a tracker for catalog events.
tracker := newResourceSubscriptionTracker()
tracker.Subscribe(pbpeerstream.TypeURLExportedService)
mgr := newSubscriptionManager(ctx, testutil.Logger(t), Config{ mgr := newSubscriptionManager(ctx, testutil.Logger(t), Config{
Datacenter: "dc1", Datacenter: "dc1",
ConnectEnabled: true, ConnectEnabled: true,
}, connect.TestTrustDomain, backend, func() StateStore { }, connect.TestTrustDomain, backend, func() StateStore {
return backend.store return backend.store
}) }, tracker)
subCh := mgr.subscribe(ctx, id, "my-peering", partition) subCh := mgr.subscribe(ctx, id, "my-peering", partition)
var ( var (
@ -442,12 +447,16 @@ func TestSubscriptionManager_InitialSnapshot(t *testing.T) {
_, id := backend.ensurePeering(t, "my-peering") _, id := backend.ensurePeering(t, "my-peering")
partition := acl.DefaultEnterpriseMeta().PartitionOrEmpty() partition := acl.DefaultEnterpriseMeta().PartitionOrEmpty()
// Only configure a tracker for catalog events.
tracker := newResourceSubscriptionTracker()
tracker.Subscribe(pbpeerstream.TypeURLExportedService)
mgr := newSubscriptionManager(ctx, testutil.Logger(t), Config{ mgr := newSubscriptionManager(ctx, testutil.Logger(t), Config{
Datacenter: "dc1", Datacenter: "dc1",
ConnectEnabled: true, ConnectEnabled: true,
}, connect.TestTrustDomain, backend, func() StateStore { }, connect.TestTrustDomain, backend, func() StateStore {
return backend.store return backend.store
}) }, tracker)
subCh := mgr.subscribe(ctx, id, "my-peering", partition) subCh := mgr.subscribe(ctx, id, "my-peering", partition)
// Register two services that are not yet exported // Register two services that are not yet exported
@ -571,21 +580,21 @@ func TestSubscriptionManager_CARoots(t *testing.T) {
_, id := backend.ensurePeering(t, "my-peering") _, id := backend.ensurePeering(t, "my-peering")
partition := acl.DefaultEnterpriseMeta().PartitionOrEmpty() partition := acl.DefaultEnterpriseMeta().PartitionOrEmpty()
// Only configure a tracker for CA roots events.
tracker := newResourceSubscriptionTracker()
tracker.Subscribe(pbpeerstream.TypeURLPeeringTrustBundle)
mgr := newSubscriptionManager(ctx, testutil.Logger(t), Config{ mgr := newSubscriptionManager(ctx, testutil.Logger(t), Config{
Datacenter: "dc1", Datacenter: "dc1",
ConnectEnabled: true, ConnectEnabled: true,
}, connect.TestTrustDomain, backend, func() StateStore { }, connect.TestTrustDomain, backend, func() StateStore {
return backend.store return backend.store
}) }, tracker)
subCh := mgr.subscribe(ctx, id, "my-peering", partition) subCh := mgr.subscribe(ctx, id, "my-peering", partition)
testutil.RunStep(t, "initial events contain trust bundle", func(t *testing.T) { testutil.RunStep(t, "initial events contain trust bundle", func(t *testing.T) {
// events are ordered so we can expect a deterministic list // events are ordered so we can expect a deterministic list
expectEvents(t, subCh, expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
// mesh-gateway assertions are done in other tests
require.Equal(t, subMeshGateway+partition, got.CorrelationID)
},
func(t *testing.T, got cache.UpdateEvent) { func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subCARoot, got.CorrelationID) require.Equal(t, subCARoot, got.CorrelationID)
roots, ok := got.Result.(*pbpeering.PeeringTrustBundle) roots, ok := got.Result.(*pbpeering.PeeringTrustBundle)

View File

@ -2,6 +2,7 @@ package peerstream
import ( import (
"context" "context"
"fmt"
"io" "io"
"sync" "sync"
"time" "time"
@ -24,14 +25,7 @@ func (c *MockClient) Send(r *pbpeerstream.ReplicationMessage) error {
} }
func (c *MockClient) Recv() (*pbpeerstream.ReplicationMessage, error) { func (c *MockClient) Recv() (*pbpeerstream.ReplicationMessage, error) {
select { return c.RecvWithTimeout(10 * time.Millisecond)
case err := <-c.ErrCh:
return nil, err
case r := <-c.ReplicationStream.sendCh:
return r, nil
case <-time.After(10 * time.Millisecond):
return nil, io.EOF
}
} }
func (c *MockClient) RecvWithTimeout(dur time.Duration) (*pbpeerstream.ReplicationMessage, error) { func (c *MockClient) RecvWithTimeout(dur time.Duration) (*pbpeerstream.ReplicationMessage, error) {
@ -61,7 +55,6 @@ type MockStream struct {
recvCh chan *pbpeerstream.ReplicationMessage recvCh chan *pbpeerstream.ReplicationMessage
ctx context.Context ctx context.Context
mu sync.Mutex
} }
var _ pbpeerstream.PeerStreamService_StreamResourcesServer = (*MockStream)(nil) var _ pbpeerstream.PeerStreamService_StreamResourcesServer = (*MockStream)(nil)
@ -117,12 +110,37 @@ func (s *MockStream) SendHeader(metadata.MD) error {
// SetTrailer implements grpc.ServerStream // SetTrailer implements grpc.ServerStream
func (s *MockStream) SetTrailer(metadata.MD) {} func (s *MockStream) SetTrailer(metadata.MD) {}
// incrementalTime is an artificial clock used during testing. For those
// scenarios you would pass around the method pointer for `Now` in places where
// you would be using `time.Now`.
type incrementalTime struct { type incrementalTime struct {
base time.Time base time.Time
next uint64 next uint64
mu sync.Mutex
} }
// Now advances the internal clock by 1 second and returns that value.
func (t *incrementalTime) Now() time.Time { func (t *incrementalTime) Now() time.Time {
t.mu.Lock()
defer t.mu.Unlock()
t.next++ t.next++
return t.base.Add(time.Duration(t.next) * time.Second)
dur := time.Duration(t.next) * time.Second
return t.base.Add(dur)
}
// FutureNow will return a given future value of the Now() function.
// The numerical argument indicates which future Now value you wanted. The
// value must be > 0.
func (t *incrementalTime) FutureNow(n int) time.Time {
if n < 1 {
panic(fmt.Sprintf("argument must be > 1 but was %d", n))
}
t.mu.Lock()
defer t.mu.Unlock()
dur := time.Duration(t.next+uint64(n)) * time.Second
return t.base.Add(dur)
} }

View File

@ -256,15 +256,6 @@ func (l *State) aclTokenForServiceSync(id structs.ServiceID, fallback func() str
return fallback() return fallback()
} }
// AddService is used to add a service entry to the local state.
// This entry is persistent and the agent will make a best effort to
// ensure it is registered
func (l *State) AddService(service *structs.NodeService, token string) error {
l.Lock()
defer l.Unlock()
return l.addServiceLocked(service, token)
}
func (l *State) addServiceLocked(service *structs.NodeService, token string) error { func (l *State) addServiceLocked(service *structs.NodeService, token string) error {
if service == nil { if service == nil {
return fmt.Errorf("no service") return fmt.Errorf("no service")
@ -293,7 +284,9 @@ func (l *State) addServiceLocked(service *structs.NodeService, token string) err
return nil return nil
} }
// AddServiceWithChecks adds a service and its check tp the local state atomically // AddServiceWithChecks adds a service entry and its checks to the local state atomically
// This entry is persistent and the agent will make a best effort to
// ensure it is registered
func (l *State) AddServiceWithChecks(service *structs.NodeService, checks []*structs.HealthCheck, token string) error { func (l *State) AddServiceWithChecks(service *structs.NodeService, checks []*structs.HealthCheck, token string) error {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()

View File

@ -64,7 +64,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) {
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
assert.False(t, a.State.ServiceExists(structs.ServiceID{ID: srv1.ID})) assert.False(t, a.State.ServiceExists(structs.ServiceID{ID: srv1.ID}))
a.State.AddService(srv1, "") a.State.AddServiceWithChecks(srv1, nil, "")
assert.True(t, a.State.ServiceExists(structs.ServiceID{ID: srv1.ID})) assert.True(t, a.State.ServiceExists(structs.ServiceID{ID: srv1.ID}))
args.Service = srv1 args.Service = srv1
if err := a.RPC("Catalog.Register", args, &out); err != nil { if err := a.RPC("Catalog.Register", args, &out); err != nil {
@ -83,7 +83,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv2, "") a.State.AddServiceWithChecks(srv2, nil, "")
srv2_mod := new(structs.NodeService) srv2_mod := new(structs.NodeService)
*srv2_mod = *srv2 *srv2_mod = *srv2
@ -105,7 +105,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv3, "") a.State.AddServiceWithChecks(srv3, nil, "")
// Exists remote (delete) // Exists remote (delete)
srv4 := &structs.NodeService{ srv4 := &structs.NodeService{
@ -137,7 +137,7 @@ func TestAgentAntiEntropy_Services(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv5, "") a.State.AddServiceWithChecks(srv5, nil, "")
srv5_mod := new(structs.NodeService) srv5_mod := new(structs.NodeService)
*srv5_mod = *srv5 *srv5_mod = *srv5
@ -290,7 +290,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv1, "") a.State.AddServiceWithChecks(srv1, nil, "")
require.NoError(t, a.RPC("Catalog.Register", &structs.RegisterRequest{ require.NoError(t, a.RPC("Catalog.Register", &structs.RegisterRequest{
Datacenter: "dc1", Datacenter: "dc1",
Node: a.Config.NodeName, Node: a.Config.NodeName,
@ -311,7 +311,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv2, "") a.State.AddServiceWithChecks(srv2, nil, "")
srv2_mod := clone(srv2) srv2_mod := clone(srv2)
srv2_mod.Port = 9000 srv2_mod.Port = 9000
@ -335,7 +335,7 @@ func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv3, "") a.State.AddServiceWithChecks(srv3, nil, "")
// Exists remote (delete) // Exists remote (delete)
srv4 := &structs.NodeService{ srv4 := &structs.NodeService{
@ -496,7 +496,7 @@ func TestAgent_ServiceWatchCh(t *testing.T) {
Tags: []string{"tag1"}, Tags: []string{"tag1"},
Port: 6100, Port: 6100,
} }
require.NoError(t, a.State.AddService(srv1, "")) require.NoError(t, a.State.AddServiceWithChecks(srv1, nil, ""))
verifyState := func(ss *local.ServiceState) { verifyState := func(ss *local.ServiceState) {
require.NotNil(t, ss) require.NotNil(t, ss)
@ -518,7 +518,7 @@ func TestAgent_ServiceWatchCh(t *testing.T) {
go func() { go func() {
srv2 := srv1 srv2 := srv1
srv2.Port = 6200 srv2.Port = 6200
require.NoError(t, a.State.AddService(srv2, "")) require.NoError(t, a.State.AddServiceWithChecks(srv2, nil, ""))
}() }()
// We should observe WatchCh close // We should observe WatchCh close
@ -595,7 +595,7 @@ func TestAgentAntiEntropy_EnableTagOverride(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv1, "") a.State.AddServiceWithChecks(srv1, nil, "")
// register a local service with tag override disabled // register a local service with tag override disabled
srv2 := &structs.NodeService{ srv2 := &structs.NodeService{
@ -610,7 +610,7 @@ func TestAgentAntiEntropy_EnableTagOverride(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv2, "") a.State.AddServiceWithChecks(srv2, nil, "")
// make sure they are both in the catalog // make sure they are both in the catalog
if err := a.State.SyncChanges(); err != nil { if err := a.State.SyncChanges(); err != nil {
@ -722,7 +722,7 @@ func TestAgentAntiEntropy_Services_WithChecks(t *testing.T) {
Tags: []string{"primary"}, Tags: []string{"primary"},
Port: 5000, Port: 5000,
} }
a.State.AddService(srv, "") a.State.AddServiceWithChecks(srv, nil, "")
chk := &structs.HealthCheck{ chk := &structs.HealthCheck{
Node: a.Config.NodeName, Node: a.Config.NodeName,
@ -772,7 +772,7 @@ func TestAgentAntiEntropy_Services_WithChecks(t *testing.T) {
Tags: []string{"primary"}, Tags: []string{"primary"},
Port: 5000, Port: 5000,
} }
a.State.AddService(srv, "") a.State.AddServiceWithChecks(srv, nil, "")
chk1 := &structs.HealthCheck{ chk1 := &structs.HealthCheck{
Node: a.Config.NodeName, Node: a.Config.NodeName,
@ -873,7 +873,7 @@ func TestAgentAntiEntropy_Services_ACLDeny(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv1, token) a.State.AddServiceWithChecks(srv1, nil, token)
// Create service (allowed) // Create service (allowed)
srv2 := &structs.NodeService{ srv2 := &structs.NodeService{
@ -887,7 +887,7 @@ func TestAgentAntiEntropy_Services_ACLDeny(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv2, token) a.State.AddServiceWithChecks(srv2, nil, token)
if err := a.State.SyncFull(); err != nil { if err := a.State.SyncFull(); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -1332,7 +1332,7 @@ func TestAgentAntiEntropy_Checks_ACLDeny(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv1, "root") a.State.AddServiceWithChecks(srv1, nil, "root")
srv2 := &structs.NodeService{ srv2 := &structs.NodeService{
ID: "api", ID: "api",
Service: "api", Service: "api",
@ -1344,7 +1344,7 @@ func TestAgentAntiEntropy_Checks_ACLDeny(t *testing.T) {
}, },
EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(), EnterpriseMeta: *structs.DefaultEnterpriseMetaInDefaultPartition(),
} }
a.State.AddService(srv2, "root") a.State.AddServiceWithChecks(srv2, nil, "root")
if err := a.State.SyncFull(); err != nil { if err := a.State.SyncFull(); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -1861,14 +1861,14 @@ func TestState_ServiceTokens(t *testing.T) {
}) })
t.Run("empty string when there is no token", func(t *testing.T) { t.Run("empty string when there is no token", func(t *testing.T) {
err := l.AddService(&structs.NodeService{ID: "redis"}, "") err := l.AddServiceWithChecks(&structs.NodeService{ID: "redis"}, nil, "")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "", l.ServiceToken(id)) require.Equal(t, "", l.ServiceToken(id))
}) })
t.Run("returns configured token", func(t *testing.T) { t.Run("returns configured token", func(t *testing.T) {
err := l.AddService(&structs.NodeService{ID: "redis"}, "abc123") err := l.AddServiceWithChecks(&structs.NodeService{ID: "redis"}, nil, "abc123")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "abc123", l.ServiceToken(id)) require.Equal(t, "abc123", l.ServiceToken(id))
@ -1931,7 +1931,7 @@ func TestAgent_CheckCriticalTime(t *testing.T) {
l.TriggerSyncChanges = func() {} l.TriggerSyncChanges = func() {}
svc := &structs.NodeService{ID: "redis", Service: "redis", Port: 8000} svc := &structs.NodeService{ID: "redis", Service: "redis", Port: 8000}
l.AddService(svc, "") l.AddServiceWithChecks(svc, nil, "")
// Add a passing check and make sure it's not critical. // Add a passing check and make sure it's not critical.
checkID := types.CheckID("redis:1") checkID := types.CheckID("redis:1")
@ -2017,8 +2017,8 @@ func TestAgent_AliasCheck(t *testing.T) {
l.TriggerSyncChanges = func() {} l.TriggerSyncChanges = func() {}
// Add checks // Add checks
require.NoError(t, l.AddService(&structs.NodeService{Service: "s1"}, "")) require.NoError(t, l.AddServiceWithChecks(&structs.NodeService{Service: "s1"}, nil, ""))
require.NoError(t, l.AddService(&structs.NodeService{Service: "s2"}, "")) require.NoError(t, l.AddServiceWithChecks(&structs.NodeService{Service: "s2"}, nil, ""))
require.NoError(t, l.AddCheck(&structs.HealthCheck{CheckID: types.CheckID("c1"), ServiceID: "s1"}, "")) require.NoError(t, l.AddCheck(&structs.HealthCheck{CheckID: types.CheckID("c1"), ServiceID: "s1"}, ""))
require.NoError(t, l.AddCheck(&structs.HealthCheck{CheckID: types.CheckID("c2"), ServiceID: "s2"}, "")) require.NoError(t, l.AddCheck(&structs.HealthCheck{CheckID: types.CheckID("c2"), ServiceID: "s2"}, ""))
@ -2071,7 +2071,7 @@ func TestAgent_AliasCheck_ServiceNotification(t *testing.T) {
require.NoError(t, l.AddAliasCheck(structs.NewCheckID(types.CheckID("a1"), nil), structs.NewServiceID("s1", nil), notifyCh)) require.NoError(t, l.AddAliasCheck(structs.NewCheckID(types.CheckID("a1"), nil), structs.NewServiceID("s1", nil), notifyCh))
// Add aliased service, s1, and verify we get notified // Add aliased service, s1, and verify we get notified
require.NoError(t, l.AddService(&structs.NodeService{Service: "s1"}, "")) require.NoError(t, l.AddServiceWithChecks(&structs.NodeService{Service: "s1"}, nil, ""))
select { select {
case <-notifyCh: case <-notifyCh:
default: default:
@ -2079,7 +2079,7 @@ func TestAgent_AliasCheck_ServiceNotification(t *testing.T) {
} }
// Re-adding same service should not lead to a notification // Re-adding same service should not lead to a notification
require.NoError(t, l.AddService(&structs.NodeService{Service: "s1"}, "")) require.NoError(t, l.AddServiceWithChecks(&structs.NodeService{Service: "s1"}, nil, ""))
select { select {
case <-notifyCh: case <-notifyCh:
t.Fatal("notify received") t.Fatal("notify received")
@ -2087,7 +2087,7 @@ func TestAgent_AliasCheck_ServiceNotification(t *testing.T) {
} }
// Add different service and verify we do not get notified // Add different service and verify we do not get notified
require.NoError(t, l.AddService(&structs.NodeService{Service: "s2"}, "")) require.NoError(t, l.AddServiceWithChecks(&structs.NodeService{Service: "s2"}, nil, ""))
select { select {
case <-notifyCh: case <-notifyCh:
t.Fatal("notify received") t.Fatal("notify received")
@ -2189,10 +2189,10 @@ func TestState_RemoveServiceErrorMessages(t *testing.T) {
state.TriggerSyncChanges = func() {} state.TriggerSyncChanges = func() {}
// Add 1 service // Add 1 service
err := state.AddService(&structs.NodeService{ err := state.AddServiceWithChecks(&structs.NodeService{
ID: "web-id", ID: "web-id",
Service: "web-name", Service: "web-name",
}, "") }, nil, "")
require.NoError(t, err) require.NoError(t, err)
// Attempt to remove service that doesn't exist // Attempt to remove service that doesn't exist
@ -2230,9 +2230,9 @@ func TestState_Notify(t *testing.T) {
drainCh(notifyCh) drainCh(notifyCh)
// Add a service // Add a service
err := state.AddService(&structs.NodeService{ err := state.AddServiceWithChecks(&structs.NodeService{
Service: "web", Service: "web",
}, "fake-token-web") }, nil, "fake-token-web")
require.NoError(t, err) require.NoError(t, err)
// Should have a notification // Should have a notification
@ -2240,10 +2240,10 @@ func TestState_Notify(t *testing.T) {
drainCh(notifyCh) drainCh(notifyCh)
// Re-Add same service // Re-Add same service
err = state.AddService(&structs.NodeService{ err = state.AddServiceWithChecks(&structs.NodeService{
Service: "web", Service: "web",
Port: 4444, Port: 4444,
}, "fake-token-web") }, nil, "fake-token-web")
require.NoError(t, err) require.NoError(t, err)
// Should have a notification // Should have a notification
@ -2261,9 +2261,9 @@ func TestState_Notify(t *testing.T) {
state.StopNotify(notifyCh) state.StopNotify(notifyCh)
// Add a service // Add a service
err = state.AddService(&structs.NodeService{ err = state.AddServiceWithChecks(&structs.NodeService{
Service: "web", Service: "web",
}, "fake-token-web") }, nil, "fake-token-web")
require.NoError(t, err) require.NoError(t, err)
// Should NOT have a notification // Should NOT have a notification
@ -2293,7 +2293,7 @@ func TestAliasNotifications_local(t *testing.T) {
Address: "127.0.0.10", Address: "127.0.0.10",
Port: 8080, Port: 8080,
} }
a.State.AddService(srv, "") a.State.AddServiceWithChecks(srv, nil, "")
scID := "socat-sidecar-proxy" scID := "socat-sidecar-proxy"
sc := &structs.NodeService{ sc := &structs.NodeService{
@ -2303,7 +2303,7 @@ func TestAliasNotifications_local(t *testing.T) {
Address: "127.0.0.10", Address: "127.0.0.10",
Port: 9090, Port: 9090,
} }
a.State.AddService(sc, "") a.State.AddServiceWithChecks(sc, nil, "")
tcpID := types.CheckID("service:socat-tcp") tcpID := types.CheckID("service:socat-tcp")
chk0 := &structs.HealthCheck{ chk0 := &structs.HealthCheck{

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
external "github.com/hashicorp/consul/agent/grpc-external"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/proto/pbpeering" "github.com/hashicorp/consul/proto/pbpeering"
@ -32,17 +33,20 @@ func (s *HTTPHandlers) PeeringEndpoint(resp http.ResponseWriter, req *http.Reque
// peeringRead fetches a peering that matches the name and partition. // peeringRead fetches a peering that matches the name and partition.
// This assumes that the name and partition parameters are valid // This assumes that the name and partition parameters are valid
func (s *HTTPHandlers) peeringRead(resp http.ResponseWriter, req *http.Request, name string) (interface{}, error) { func (s *HTTPHandlers) peeringRead(resp http.ResponseWriter, req *http.Request, name string) (interface{}, error) {
args := pbpeering.PeeringReadRequest{
Name: name,
Datacenter: s.agent.config.Datacenter,
}
var entMeta acl.EnterpriseMeta var entMeta acl.EnterpriseMeta
if err := s.parseEntMetaPartition(req, &entMeta); err != nil { if err := s.parseEntMetaPartition(req, &entMeta); err != nil {
return nil, err return nil, err
} }
args.Partition = entMeta.PartitionOrEmpty() args := pbpeering.PeeringReadRequest{
Name: name,
Partition: entMeta.PartitionOrEmpty(),
}
result, err := s.agent.rpcClientPeering.PeeringRead(req.Context(), &args) var token string
s.parseToken(req, &token)
ctx := external.ContextWithToken(req.Context(), token)
result, err := s.agent.rpcClientPeering.PeeringRead(ctx, &args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -55,16 +59,19 @@ func (s *HTTPHandlers) peeringRead(resp http.ResponseWriter, req *http.Request,
// PeeringList fetches all peerings in the datacenter in OSS or in a given partition in Consul Enterprise. // PeeringList fetches all peerings in the datacenter in OSS or in a given partition in Consul Enterprise.
func (s *HTTPHandlers) PeeringList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) PeeringList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
args := pbpeering.PeeringListRequest{
Datacenter: s.agent.config.Datacenter,
}
var entMeta acl.EnterpriseMeta var entMeta acl.EnterpriseMeta
if err := s.parseEntMetaPartition(req, &entMeta); err != nil { if err := s.parseEntMetaPartition(req, &entMeta); err != nil {
return nil, err return nil, err
} }
args.Partition = entMeta.PartitionOrEmpty() args := pbpeering.PeeringListRequest{
Partition: entMeta.PartitionOrEmpty(),
}
pbresp, err := s.agent.rpcClientPeering.PeeringList(req.Context(), &args) var token string
s.parseToken(req, &token)
ctx := external.ContextWithToken(req.Context(), token)
pbresp, err := s.agent.rpcClientPeering.PeeringList(ctx, &args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -79,14 +86,12 @@ func (s *HTTPHandlers) PeeringGenerateToken(resp http.ResponseWriter, req *http.
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "The peering arguments must be provided in the body"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "The peering arguments must be provided in the body"}
} }
apiRequest := &api.PeeringGenerateTokenRequest{ var apiRequest api.PeeringGenerateTokenRequest
Datacenter: s.agent.config.Datacenter, if err := lib.DecodeJSON(req.Body, &apiRequest); err != nil {
}
if err := lib.DecodeJSON(req.Body, apiRequest); err != nil {
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Body decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Body decoding failed: %v", err)}
} }
args := pbpeering.NewGenerateTokenRequestFromAPI(apiRequest)
args := pbpeering.NewGenerateTokenRequestFromAPI(&apiRequest)
if args.PeerName == "" { if args.PeerName == "" {
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "PeerName is required in the payload when generating a new peering token."} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "PeerName is required in the payload when generating a new peering token."}
} }
@ -99,7 +104,11 @@ func (s *HTTPHandlers) PeeringGenerateToken(resp http.ResponseWriter, req *http.
args.Partition = entMeta.PartitionOrEmpty() args.Partition = entMeta.PartitionOrEmpty()
} }
out, err := s.agent.rpcClientPeering.GenerateToken(req.Context(), args) var token string
s.parseToken(req, &token)
ctx := external.ContextWithToken(req.Context(), token)
out, err := s.agent.rpcClientPeering.GenerateToken(ctx, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -114,23 +123,32 @@ func (s *HTTPHandlers) PeeringEstablish(resp http.ResponseWriter, req *http.Requ
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "The peering arguments must be provided in the body"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "The peering arguments must be provided in the body"}
} }
apiRequest := &api.PeeringEstablishRequest{ var apiRequest api.PeeringEstablishRequest
Datacenter: s.agent.config.Datacenter, if err := lib.DecodeJSON(req.Body, &apiRequest); err != nil {
}
if err := lib.DecodeJSON(req.Body, apiRequest); err != nil {
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Body decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Body decoding failed: %v", err)}
} }
args := pbpeering.NewEstablishRequestFromAPI(apiRequest)
args := pbpeering.NewEstablishRequestFromAPI(&apiRequest)
if args.PeerName == "" { if args.PeerName == "" {
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "PeerName is required in the payload when establishing a peering."} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "PeerName is required in the payload when establishing a peering."}
} }
if args.PeeringToken == "" { if args.PeeringToken == "" {
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "PeeringToken is required in the payload when establishing a peering."} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "PeeringToken is required in the payload when establishing a peering."}
} }
out, err := s.agent.rpcClientPeering.Establish(req.Context(), args) var entMeta acl.EnterpriseMeta
if err := s.parseEntMetaPartition(req, &entMeta); err != nil {
return nil, err
}
if args.Partition == "" {
args.Partition = entMeta.PartitionOrEmpty()
}
var token string
s.parseToken(req, &token)
ctx := external.ContextWithToken(req.Context(), token)
out, err := s.agent.rpcClientPeering.Establish(ctx, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,17 +159,20 @@ func (s *HTTPHandlers) PeeringEstablish(resp http.ResponseWriter, req *http.Requ
// peeringDelete initiates a deletion for a peering that matches the name and partition. // peeringDelete initiates a deletion for a peering that matches the name and partition.
// This assumes that the name and partition parameters are valid. // This assumes that the name and partition parameters are valid.
func (s *HTTPHandlers) peeringDelete(resp http.ResponseWriter, req *http.Request, name string) (interface{}, error) { func (s *HTTPHandlers) peeringDelete(resp http.ResponseWriter, req *http.Request, name string) (interface{}, error) {
args := pbpeering.PeeringDeleteRequest{
Name: name,
Datacenter: s.agent.config.Datacenter,
}
var entMeta acl.EnterpriseMeta var entMeta acl.EnterpriseMeta
if err := s.parseEntMetaPartition(req, &entMeta); err != nil { if err := s.parseEntMetaPartition(req, &entMeta); err != nil {
return nil, err return nil, err
} }
args.Partition = entMeta.PartitionOrEmpty() args := pbpeering.PeeringDeleteRequest{
Name: name,
Partition: entMeta.PartitionOrEmpty(),
}
_, err := s.agent.rpcClientPeering.PeeringDelete(req.Context(), &args) var token string
s.parseToken(req, &token)
ctx := external.ContextWithToken(req.Context(), token)
_, err := s.agent.rpcClientPeering.PeeringDelete(ctx, &args)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -12,6 +12,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
@ -113,6 +114,104 @@ func TestHTTP_Peering_GenerateToken(t *testing.T) {
// The PeerID in the token is randomly generated so we don't assert on its value. // The PeerID in the token is randomly generated so we don't assert on its value.
require.NotEmpty(t, token.PeerID) require.NotEmpty(t, token.PeerID)
}) })
t.Run("Success with external address", func(t *testing.T) {
externalAddress := "32.1.2.3"
body := &pbpeering.GenerateTokenRequest{
PeerName: "peering-a",
ServerExternalAddresses: []string{externalAddress},
}
bodyBytes, err := json.Marshal(body)
require.NoError(t, err)
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
var r pbpeering.GenerateTokenResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&r))
tokenJSON, err := base64.StdEncoding.DecodeString(r.PeeringToken)
require.NoError(t, err)
var token structs.PeeringToken
require.NoError(t, json.Unmarshal(tokenJSON, &token))
require.Nil(t, token.CA)
require.Equal(t, []string{externalAddress}, token.ServerAddresses)
require.Equal(t, "server.dc1.consul", token.ServerName)
// The PeerID in the token is randomly generated so we don't assert on its value.
require.NotEmpty(t, token.PeerID)
})
}
// Test for GenerateToken calls at various points in a peer's lifecycle
func TestHTTP_Peering_GenerateToken_EdgeCases(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := NewTestAgent(t, "")
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
body := &pbpeering.GenerateTokenRequest{
PeerName: "peering-a",
}
bodyBytes, err := json.Marshal(body)
require.NoError(t, err)
getPeering := func(t *testing.T) *api.Peering {
t.Helper()
// Check state of peering
req, err := http.NewRequest("GET", "/v1/peering/peering-a", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
var p *api.Peering
require.NoError(t, json.NewDecoder(resp.Body).Decode(&p))
return p
}
{
// Call once
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
// Assertions tested in TestHTTP_Peering_GenerateToken
}
if !t.Run("generate token called again", func(t *testing.T) {
before := getPeering(t)
require.Equal(t, api.PeeringStatePending, before.State)
// Call again
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
after := getPeering(t)
assert.NotEqual(t, before.ModifyIndex, after.ModifyIndex)
// blank out modify index so we can compare rest of struct
before.ModifyIndex, after.ModifyIndex = 0, 0
assert.Equal(t, before, after)
}) {
t.FailNow()
}
} }
func TestHTTP_Peering_Establish(t *testing.T) { func TestHTTP_Peering_Establish(t *testing.T) {

View File

@ -92,7 +92,7 @@ func TestPreparedQuery_Create(t *testing.T) {
Session: "my-session", Session: "my-session",
Service: structs.ServiceQuery{ Service: structs.ServiceQuery{
Service: "my-service", Service: "my-service",
Failover: structs.QueryDatacenterOptions{ Failover: structs.QueryFailoverOptions{
NearestN: 4, NearestN: 4,
Datacenters: []string{"dc1", "dc2"}, Datacenters: []string{"dc1", "dc2"},
}, },
@ -883,7 +883,7 @@ func TestPreparedQuery_Update(t *testing.T) {
Session: "my-session", Session: "my-session",
Service: structs.ServiceQuery{ Service: structs.ServiceQuery{
Service: "my-service", Service: "my-service",
Failover: structs.QueryDatacenterOptions{ Failover: structs.QueryFailoverOptions{
NearestN: 4, NearestN: 4,
Datacenters: []string{"dc1", "dc2"}, Datacenters: []string{"dc1", "dc2"},
}, },

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/consul/agent/proxycfg" "github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/submatview" "github.com/hashicorp/consul/agent/submatview"
"github.com/hashicorp/consul/proto/pbcommon"
"github.com/hashicorp/consul/proto/pbconfigentry" "github.com/hashicorp/consul/proto/pbconfigentry"
"github.com/hashicorp/consul/proto/pbsubscribe" "github.com/hashicorp/consul/proto/pbsubscribe"
) )
@ -19,6 +20,7 @@ import (
// ServerDataSourceDeps contains the dependencies needed for sourcing data from // ServerDataSourceDeps contains the dependencies needed for sourcing data from
// server-local sources (e.g. materialized views). // server-local sources (e.g. materialized views).
type ServerDataSourceDeps struct { type ServerDataSourceDeps struct {
Datacenter string
ViewStore *submatview.Store ViewStore *submatview.Store
EventPublisher *stream.EventPublisher EventPublisher *stream.EventPublisher
Logger hclog.Logger Logger hclog.Logger
@ -193,7 +195,7 @@ func (v *configEntryListView) Result(index uint64) any {
} }
func (v *configEntryListView) Update(events []*pbsubscribe.Event) error { func (v *configEntryListView) Update(events []*pbsubscribe.Event) error {
for _, event := range v.filterByEnterpriseMeta(events) { for _, event := range filterByEnterpriseMeta(events, v.entMeta) {
update := event.GetConfigEntry() update := event.GetConfigEntry()
configEntry := pbconfigentry.ConfigEntryToStructs(update.ConfigEntry) configEntry := pbconfigentry.ConfigEntryToStructs(update.ConfigEntry)
name := structs.NewServiceName(configEntry.GetName(), configEntry.GetEnterpriseMeta()).String() name := structs.NewServiceName(configEntry.GetName(), configEntry.GetEnterpriseMeta()).String()
@ -212,22 +214,26 @@ func (v *configEntryListView) Update(events []*pbsubscribe.Event) error {
// don't match the request's enterprise meta - this is necessary because when // don't match the request's enterprise meta - this is necessary because when
// subscribing to a topic with SubjectWildcard we'll get events for resources // subscribing to a topic with SubjectWildcard we'll get events for resources
// in all partitions and namespaces. // in all partitions and namespaces.
func (v *configEntryListView) filterByEnterpriseMeta(events []*pbsubscribe.Event) []*pbsubscribe.Event { func filterByEnterpriseMeta(events []*pbsubscribe.Event, entMeta acl.EnterpriseMeta) []*pbsubscribe.Event {
partition := v.entMeta.PartitionOrDefault() partition := entMeta.PartitionOrDefault()
namespace := v.entMeta.NamespaceOrDefault() namespace := entMeta.NamespaceOrDefault()
filtered := make([]*pbsubscribe.Event, 0, len(events)) filtered := make([]*pbsubscribe.Event, 0, len(events))
for _, event := range events { for _, event := range events {
configEntry := event.GetConfigEntry().GetConfigEntry() var eventEntMeta *pbcommon.EnterpriseMeta
if configEntry == nil { switch payload := event.Payload.(type) {
case *pbsubscribe.Event_ConfigEntry:
eventEntMeta = payload.ConfigEntry.ConfigEntry.GetEnterpriseMeta()
case *pbsubscribe.Event_Service:
eventEntMeta = payload.Service.GetEnterpriseMeta()
default:
continue continue
} }
entMeta := configEntry.GetEnterpriseMeta() if partition != acl.WildcardName && !acl.EqualPartitions(partition, eventEntMeta.GetPartition()) {
if partition != acl.WildcardName && !acl.EqualPartitions(partition, entMeta.GetPartition()) {
continue continue
} }
if namespace != acl.WildcardName && !acl.EqualNamespaces(namespace, entMeta.GetNamespace()) { if namespace != acl.WildcardName && !acl.EqualNamespaces(namespace, eventEntMeta.GetNamespace()) {
continue continue
} }

View File

@ -0,0 +1,95 @@
package proxycfgglue
import (
"context"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/consul/discoverychain"
"github.com/hashicorp/consul/agent/consul/watch"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
)
// CacheCompiledDiscoveryChain satisfies the proxycfg.CompiledDiscoveryChain
// interface by sourcing data from the agent cache.
func CacheCompiledDiscoveryChain(c *cache.Cache) proxycfg.CompiledDiscoveryChain {
return &cacheProxyDataSource[*structs.DiscoveryChainRequest]{c, cachetype.CompiledDiscoveryChainName}
}
// ServerCompiledDiscoveryChain satisfies the proxycfg.CompiledDiscoveryChain
// interface by sourcing data from a blocking query against the server's state
// store.
//
// Requests for services in remote datacenters will be delegated to the given
// remoteSource (i.e. CacheCompiledDiscoveryChain).
func ServerCompiledDiscoveryChain(deps ServerDataSourceDeps, remoteSource proxycfg.CompiledDiscoveryChain) proxycfg.CompiledDiscoveryChain {
return &serverCompiledDiscoveryChain{deps, remoteSource}
}
type serverCompiledDiscoveryChain struct {
deps ServerDataSourceDeps
remoteSource proxycfg.CompiledDiscoveryChain
}
func (s serverCompiledDiscoveryChain) Notify(ctx context.Context, req *structs.DiscoveryChainRequest, correlationID string, ch chan<- proxycfg.UpdateEvent) error {
if req.Datacenter != s.deps.Datacenter {
return s.remoteSource.Notify(ctx, req, correlationID, ch)
}
entMeta := req.GetEnterpriseMeta()
evalDC := req.EvaluateInDatacenter
if evalDC == "" {
evalDC = s.deps.Datacenter
}
compileReq := discoverychain.CompileRequest{
ServiceName: req.Name,
EvaluateInNamespace: entMeta.NamespaceOrDefault(),
EvaluateInPartition: entMeta.PartitionOrDefault(),
EvaluateInDatacenter: evalDC,
OverrideMeshGateway: req.OverrideMeshGateway,
OverrideProtocol: req.OverrideProtocol,
OverrideConnectTimeout: req.OverrideConnectTimeout,
}
return watch.ServerLocalNotify(ctx, correlationID, s.deps.GetStore,
func(ws memdb.WatchSet, store Store) (uint64, *structs.DiscoveryChainResponse, error) {
var authzContext acl.AuthorizerContext
authz, err := s.deps.ACLResolver.ResolveTokenAndDefaultMeta(req.Token, req.GetEnterpriseMeta(), &authzContext)
if err != nil {
return 0, nil, err
}
if err := authz.ToAllowAuthorizer().ServiceReadAllowed(req.Name, &authzContext); err != nil {
// TODO(agentless): the agent cache handles acl.IsErrNotFound specially to
// prevent endlessly retrying if an ACL token is deleted. We should probably
// do this in watch.ServerLocalNotify too.
return 0, nil, err
}
index, chain, entries, err := store.ServiceDiscoveryChain(ws, req.Name, entMeta, compileReq)
if err != nil {
return 0, nil, err
}
rsp := &structs.DiscoveryChainResponse{
Chain: chain,
QueryMeta: structs.QueryMeta{
Backend: structs.QueryBackendBlocking,
Index: index,
},
}
// TODO(boxofrad): Check with @mkeeler that this is the correct thing to do.
if entries.IsEmpty() {
return index, rsp, watch.ErrorNotFound
}
return index, rsp, nil
},
dispatchBlockingQueryUpdate[*structs.DiscoveryChainResponse](ch),
)
}

View File

@ -0,0 +1,114 @@
package proxycfgglue
import (
"context"
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
)
func TestServerCompiledDiscoveryChain(t *testing.T) {
t.Run("remote queries are delegated to the remote source", func(t *testing.T) {
var (
ctx = context.Background()
req = &structs.DiscoveryChainRequest{Datacenter: "dc2"}
correlationID = "correlation-id"
ch = make(chan<- proxycfg.UpdateEvent)
result = errors.New("KABOOM")
)
remoteSource := newMockCompiledDiscoveryChain(t)
remoteSource.On("Notify", ctx, req, correlationID, ch).Return(result)
dataSource := ServerCompiledDiscoveryChain(ServerDataSourceDeps{Datacenter: "dc1"}, remoteSource)
err := dataSource.Notify(ctx, req, correlationID, ch)
require.Equal(t, result, err)
})
t.Run("local queries are served from the state store", func(t *testing.T) {
const (
serviceName = "web"
datacenter = "dc1"
index = 123
)
store := state.NewStateStore(nil)
require.NoError(t, store.CASetConfig(index, &structs.CAConfiguration{ClusterID: "cluster-id"}))
require.NoError(t, store.EnsureConfigEntry(index, &structs.ServiceConfigEntry{
Name: serviceName,
Kind: structs.ServiceDefaults,
}))
req := &structs.DiscoveryChainRequest{
Name: serviceName,
Datacenter: datacenter,
}
resolver := newStaticResolver(
policyAuthorizer(t, fmt.Sprintf(`service "%s" { policy = "read" }`, serviceName)),
)
dataSource := ServerCompiledDiscoveryChain(ServerDataSourceDeps{
ACLResolver: resolver,
Datacenter: datacenter,
GetStore: func() Store { return store },
}, nil)
eventCh := make(chan proxycfg.UpdateEvent)
err := dataSource.Notify(context.Background(), req, "", eventCh)
require.NoError(t, err)
// Check we get an event with the initial state.
result := getEventResult[*structs.DiscoveryChainResponse](t, eventCh)
require.NotNil(t, result.Chain)
// Change the protocol to HTTP and check we get a recompiled chain.
require.NoError(t, store.EnsureConfigEntry(index+1, &structs.ServiceConfigEntry{
Name: serviceName,
Kind: structs.ServiceDefaults,
Protocol: "http",
}))
result = getEventResult[*structs.DiscoveryChainResponse](t, eventCh)
require.NotNil(t, result.Chain)
require.Equal(t, "http", result.Chain.Protocol)
// Revoke access to the service.
resolver.SwapAuthorizer(acl.DenyAll())
// Write another config entry.
require.NoError(t, store.EnsureConfigEntry(index+2, &structs.ServiceConfigEntry{
Name: serviceName,
Kind: structs.ServiceDefaults,
MaxInboundConnections: 1,
}))
// Should no longer receive events for this service.
expectNoEvent(t, eventCh)
})
}
func newMockCompiledDiscoveryChain(t *testing.T) *mockCompiledDiscoveryChain {
mock := &mockCompiledDiscoveryChain{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
type mockCompiledDiscoveryChain struct {
mock.Mock
}
func (m *mockCompiledDiscoveryChain) Notify(ctx context.Context, req *structs.DiscoveryChainRequest, correlationID string, ch chan<- proxycfg.UpdateEvent) error {
return m.Called(ctx, req, correlationID, ch).Error(0)
}

View File

@ -0,0 +1,60 @@
package proxycfgglue
import (
"context"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/consul/watch"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/structs/aclfilter"
)
// CacheExportedPeeredServices satisfies the proxycfg.ExportedPeeredServices
// interface by sourcing data from the agent cache.
func CacheExportedPeeredServices(c *cache.Cache) proxycfg.ExportedPeeredServices {
return &cacheProxyDataSource[*structs.DCSpecificRequest]{c, cachetype.ExportedPeeredServicesName}
}
// ServerExportedPeeredServices satisifies the proxycfg.ExportedPeeredServices
// interface by sourcing data from a blocking query against the server's state
// store.
func ServerExportedPeeredServices(deps ServerDataSourceDeps) proxycfg.ExportedPeeredServices {
return &serverExportedPeeredServices{deps}
}
type serverExportedPeeredServices struct {
deps ServerDataSourceDeps
}
func (s *serverExportedPeeredServices) Notify(ctx context.Context, req *structs.DCSpecificRequest, correlationID string, ch chan<- proxycfg.UpdateEvent) error {
return watch.ServerLocalNotify(ctx, correlationID, s.deps.GetStore,
func(ws memdb.WatchSet, store Store) (uint64, *structs.IndexedExportedServiceList, error) {
// TODO(peering): acls: mesh gateway needs appropriate wildcard service:read
authz, err := s.deps.ACLResolver.ResolveTokenAndDefaultMeta(req.Token, &req.EnterpriseMeta, nil)
if err != nil {
return 0, nil, err
}
index, serviceMap, err := store.ExportedServicesForAllPeersByName(ws, req.EnterpriseMeta)
if err != nil {
return 0, nil, err
}
result := &structs.IndexedExportedServiceList{
Services: serviceMap,
QueryMeta: structs.QueryMeta{
Backend: structs.QueryBackendBlocking,
Index: index,
},
}
aclfilter.New(authz, s.deps.Logger).Filter(result)
return index, result, nil
},
dispatchBlockingQueryUpdate[*structs.IndexedExportedServiceList](ch),
)
}

View File

@ -0,0 +1,113 @@
package proxycfgglue
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/consul/sdk/testutil"
)
func TestServerExportedPeeredServices(t *testing.T) {
nextIndex := indexGenerator()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
store := state.NewStateStore(nil)
for _, peer := range []string{"peer-1", "peer-2", "peer-3"} {
require.NoError(t, store.PeeringWrite(nextIndex(), &pbpeering.Peering{
ID: testUUID(t),
Name: peer,
State: pbpeering.PeeringState_ACTIVE,
}))
}
require.NoError(t, store.EnsureConfigEntry(nextIndex(), &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "web",
Consumers: []structs.ServiceConsumer{
{PeerName: "peer-1"},
},
},
{
Name: "db",
Consumers: []structs.ServiceConsumer{
{PeerName: "peer-2"},
},
},
},
}))
authz := policyAuthorizer(t, `
service "web" { policy = "read" }
service "api" { policy = "read" }
service "db" { policy = "deny" }
`)
eventCh := make(chan proxycfg.UpdateEvent)
dataSource := ServerExportedPeeredServices(ServerDataSourceDeps{
GetStore: func() Store { return store },
ACLResolver: newStaticResolver(authz),
})
require.NoError(t, dataSource.Notify(ctx, &structs.DCSpecificRequest{}, "", eventCh))
testutil.RunStep(t, "initial state", func(t *testing.T) {
result := getEventResult[*structs.IndexedExportedServiceList](t, eventCh)
require.Equal(t,
map[string]structs.ServiceList{
"peer-1": {structs.NewServiceName("web", nil)},
},
result.Services,
)
})
testutil.RunStep(t, "update exported services", func(t *testing.T) {
require.NoError(t, store.EnsureConfigEntry(nextIndex(), &structs.ExportedServicesConfigEntry{
Name: "default",
Services: []structs.ExportedService{
{
Name: "web",
Consumers: []structs.ServiceConsumer{
{PeerName: "peer-1"},
},
},
{
Name: "db",
Consumers: []structs.ServiceConsumer{
{PeerName: "peer-2"},
},
},
{
Name: "api",
Consumers: []structs.ServiceConsumer{
{PeerName: "peer-1"},
{PeerName: "peer-3"},
},
},
},
}))
result := getEventResult[*structs.IndexedExportedServiceList](t, eventCh)
require.Equal(t,
map[string]structs.ServiceList{
"peer-1": {
structs.NewServiceName("api", nil),
structs.NewServiceName("web", nil),
},
"peer-3": {
structs.NewServiceName("api", nil),
},
},
result.Services,
)
})
}

View File

@ -0,0 +1,67 @@
package proxycfgglue
import (
"context"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/consul/watch"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/structs/aclfilter"
)
// CacheFederationStateListMeshGateways satisfies the proxycfg.FederationStateListMeshGateways
// interface by sourcing data from the agent cache.
func CacheFederationStateListMeshGateways(c *cache.Cache) proxycfg.FederationStateListMeshGateways {
return &cacheProxyDataSource[*structs.DCSpecificRequest]{c, cachetype.FederationStateListMeshGatewaysName}
}
// ServerFederationStateListMeshGateways satisfies the proxycfg.FederationStateListMeshGateways
// interface by sourcing data from a blocking query against the server's state
// store.
func ServerFederationStateListMeshGateways(deps ServerDataSourceDeps) proxycfg.FederationStateListMeshGateways {
return &serverFederationStateListMeshGateways{deps}
}
type serverFederationStateListMeshGateways struct {
deps ServerDataSourceDeps
}
func (s *serverFederationStateListMeshGateways) Notify(ctx context.Context, req *structs.DCSpecificRequest, correlationID string, ch chan<- proxycfg.UpdateEvent) error {
return watch.ServerLocalNotify(ctx, correlationID, s.deps.GetStore,
func(ws memdb.WatchSet, store Store) (uint64, *structs.DatacenterIndexedCheckServiceNodes, error) {
authz, err := s.deps.ACLResolver.ResolveTokenAndDefaultMeta(req.Token, &req.EnterpriseMeta, nil)
if err != nil {
return 0, nil, err
}
index, fedStates, err := store.FederationStateList(ws)
if err != nil {
return 0, nil, err
}
results := make(map[string]structs.CheckServiceNodes)
for _, fs := range fedStates {
if gws := fs.MeshGateways; len(gws) != 0 {
// Shallow clone to prevent ACL filtering manipulating the slice in memdb.
results[fs.Datacenter] = gws.ShallowClone()
}
}
rsp := &structs.DatacenterIndexedCheckServiceNodes{
DatacenterNodes: results,
QueryMeta: structs.QueryMeta{
Index: index,
Backend: structs.QueryBackendBlocking,
},
}
aclfilter.New(authz, s.deps.Logger).Filter(rsp)
return index, rsp, nil
},
dispatchBlockingQueryUpdate[*structs.DatacenterIndexedCheckServiceNodes](ch),
)
}

View File

@ -0,0 +1,103 @@
package proxycfgglue
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/testutil"
)
func TestServerFederationStateListMeshGateways(t *testing.T) {
const index uint64 = 123
store := state.NewStateStore(nil)
authz := policyAuthorizer(t, `
service_prefix "dc2-" { policy = "read" }
node_prefix "dc2-" { policy = "read" }
service_prefix "dc3-" { policy = "read" }
node_prefix "dc3-" { policy = "read" }
`)
require.NoError(t, store.FederationStateSet(index, &structs.FederationState{
Datacenter: "dc2",
MeshGateways: structs.CheckServiceNodes{
{
Service: &structs.NodeService{Service: "dc2-gw1"},
Node: &structs.Node{Node: "dc2-gw1"},
},
},
}))
// No access to this DC, we shouldn't see it in results.
require.NoError(t, store.FederationStateSet(index, &structs.FederationState{
Datacenter: "dc4",
MeshGateways: structs.CheckServiceNodes{
{
Service: &structs.NodeService{Service: "dc4-gw1"},
Node: &structs.Node{Node: "dc4-gw1"},
},
},
}))
dataSource := ServerFederationStateListMeshGateways(ServerDataSourceDeps{
ACLResolver: newStaticResolver(authz),
GetStore: func() Store { return store },
})
eventCh := make(chan proxycfg.UpdateEvent)
require.NoError(t, dataSource.Notify(context.Background(), &structs.DCSpecificRequest{Datacenter: "dc1"}, "", eventCh))
testutil.RunStep(t, "initial state", func(t *testing.T) {
result := getEventResult[*structs.DatacenterIndexedCheckServiceNodes](t, eventCh)
require.Equal(t, map[string]structs.CheckServiceNodes{
"dc2": {
{
Service: &structs.NodeService{Service: "dc2-gw1"},
Node: &structs.Node{Node: "dc2-gw1"},
},
},
}, result.DatacenterNodes)
})
testutil.RunStep(t, "add new datacenter", func(t *testing.T) {
require.NoError(t, store.FederationStateSet(index+1, &structs.FederationState{
Datacenter: "dc3",
MeshGateways: structs.CheckServiceNodes{
{
Service: &structs.NodeService{Service: "dc3-gw1"},
Node: &structs.Node{Node: "dc3-gw1"},
},
},
}))
result := getEventResult[*structs.DatacenterIndexedCheckServiceNodes](t, eventCh)
require.Equal(t, map[string]structs.CheckServiceNodes{
"dc2": {
{
Service: &structs.NodeService{Service: "dc2-gw1"},
Node: &structs.Node{Node: "dc2-gw1"},
},
},
"dc3": {
{
Service: &structs.NodeService{Service: "dc3-gw1"},
Node: &structs.Node{Node: "dc3-gw1"},
},
},
}, result.DatacenterNodes)
})
testutil.RunStep(t, "delete datacenter", func(t *testing.T) {
require.NoError(t, store.FederationStateDelete(index+2, "dc3"))
result := getEventResult[*structs.DatacenterIndexedCheckServiceNodes](t, eventCh)
require.NotContains(t, result.DatacenterNodes, "dc3")
})
}

View File

@ -0,0 +1,63 @@
package proxycfgglue
import (
"context"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/consul/watch"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/structs/aclfilter"
)
// CacheGatewayServices satisfies the proxycfg.GatewayServices interface by
// sourcing data from the agent cache.
func CacheGatewayServices(c *cache.Cache) proxycfg.GatewayServices {
return &cacheProxyDataSource[*structs.ServiceSpecificRequest]{c, cachetype.GatewayServicesName}
}
// ServerGatewayServices satisfies the proxycfg.GatewayServices interface by
// sourcing data from a blocking query against the server's state store.
func ServerGatewayServices(deps ServerDataSourceDeps) proxycfg.GatewayServices {
return &serverGatewayServices{deps}
}
type serverGatewayServices struct {
deps ServerDataSourceDeps
}
func (s *serverGatewayServices) Notify(ctx context.Context, req *structs.ServiceSpecificRequest, correlationID string, ch chan<- proxycfg.UpdateEvent) error {
return watch.ServerLocalNotify(ctx, correlationID, s.deps.GetStore,
func(ws memdb.WatchSet, store Store) (uint64, *structs.IndexedGatewayServices, error) {
var authzContext acl.AuthorizerContext
authz, err := s.deps.ACLResolver.ResolveTokenAndDefaultMeta(req.Token, &req.EnterpriseMeta, &authzContext)
if err != nil {
return 0, nil, err
}
if err := authz.ToAllowAuthorizer().ServiceReadAllowed(req.ServiceName, &authzContext); err != nil {
return 0, nil, err
}
index, services, err := store.GatewayServices(ws, req.ServiceName, &req.EnterpriseMeta)
if err != nil {
return 0, nil, err
}
response := &structs.IndexedGatewayServices{
Services: services,
QueryMeta: structs.QueryMeta{
Backend: structs.QueryBackendBlocking,
Index: index,
},
}
aclfilter.New(authz, s.deps.Logger).Filter(response)
return index, response, nil
},
dispatchBlockingQueryUpdate[*structs.IndexedGatewayServices](ch),
)
}

View File

@ -0,0 +1,155 @@
package proxycfgglue
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/testutil"
)
func TestServerGatewayServices(t *testing.T) {
const index uint64 = 123
t.Run("ingress gateway", func(t *testing.T) {
store := state.NewStateStore(nil)
authz := policyAuthorizer(t, `
service "igw" { policy = "read" }
service "web" { policy = "read" }
service "db" { policy = "read" }
`)
require.NoError(t, store.EnsureConfigEntry(index, &structs.IngressGatewayConfigEntry{
Name: "igw",
Listeners: []structs.IngressListener{
{
Protocol: "tcp",
Services: []structs.IngressService{
{Name: "web"},
},
},
{
Protocol: "tcp",
Services: []structs.IngressService{
{Name: "db"},
},
},
{
Protocol: "tcp",
Services: []structs.IngressService{
{Name: "no-access"},
},
},
},
}))
dataSource := ServerGatewayServices(ServerDataSourceDeps{
ACLResolver: newStaticResolver(authz),
GetStore: func() Store { return store },
})
eventCh := make(chan proxycfg.UpdateEvent)
require.NoError(t, dataSource.Notify(context.Background(), &structs.ServiceSpecificRequest{ServiceName: "igw"}, "", eventCh))
testutil.RunStep(t, "initial state", func(t *testing.T) {
result := getEventResult[*structs.IndexedGatewayServices](t, eventCh)
require.Len(t, result.Services, 2)
})
testutil.RunStep(t, "remove service mapping", func(t *testing.T) {
require.NoError(t, store.EnsureConfigEntry(index+1, &structs.IngressGatewayConfigEntry{
Name: "igw",
Listeners: []structs.IngressListener{
{
Protocol: "tcp",
Services: []structs.IngressService{
{Name: "web"},
},
},
},
}))
result := getEventResult[*structs.IndexedGatewayServices](t, eventCh)
require.Len(t, result.Services, 1)
})
})
t.Run("terminating gateway", func(t *testing.T) {
store := state.NewStateStore(nil)
authz := policyAuthorizer(t, `
service "tgw" { policy = "read" }
service "web" { policy = "read" }
service "db" { policy = "read" }
`)
require.NoError(t, store.EnsureConfigEntry(index, &structs.TerminatingGatewayConfigEntry{
Name: "tgw",
Services: []structs.LinkedService{
{Name: "web"},
{Name: "db"},
{Name: "no-access"},
},
}))
dataSource := ServerGatewayServices(ServerDataSourceDeps{
ACLResolver: newStaticResolver(authz),
GetStore: func() Store { return store },
})
eventCh := make(chan proxycfg.UpdateEvent)
require.NoError(t, dataSource.Notify(context.Background(), &structs.ServiceSpecificRequest{ServiceName: "tgw"}, "", eventCh))
testutil.RunStep(t, "initial state", func(t *testing.T) {
result := getEventResult[*structs.IndexedGatewayServices](t, eventCh)
require.Len(t, result.Services, 2)
})
testutil.RunStep(t, "remove service mapping", func(t *testing.T) {
require.NoError(t, store.EnsureConfigEntry(index+1, &structs.TerminatingGatewayConfigEntry{
Name: "tgw",
Services: []structs.LinkedService{
{Name: "web"},
},
}))
result := getEventResult[*structs.IndexedGatewayServices](t, eventCh)
require.Len(t, result.Services, 1)
})
})
t.Run("no access to gateway", func(t *testing.T) {
store := state.NewStateStore(nil)
authz := policyAuthorizer(t, `
service "tgw" { policy = "deny" }
service "web" { policy = "read" }
service "db" { policy = "read" }
`)
require.NoError(t, store.EnsureConfigEntry(index, &structs.TerminatingGatewayConfigEntry{
Name: "tgw",
Services: []structs.LinkedService{
{Name: "web"},
{Name: "db"},
},
}))
dataSource := ServerGatewayServices(ServerDataSourceDeps{
ACLResolver: newStaticResolver(authz),
GetStore: func() Store { return store },
})
eventCh := make(chan proxycfg.UpdateEvent)
require.NoError(t, dataSource.Notify(context.Background(), &structs.ServiceSpecificRequest{ServiceName: "tgw"}, "", eventCh))
err := getEventError(t, eventCh)
require.True(t, acl.IsErrPermissionDenied(err), "expected permission denied error")
})
}

View File

@ -3,23 +3,33 @@ package proxycfgglue
import ( import (
"context" "context"
"github.com/hashicorp/consul/proto/pbpeering"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types" cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/configentry"
"github.com/hashicorp/consul/agent/consul/discoverychain"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/watch" "github.com/hashicorp/consul/agent/consul/watch"
"github.com/hashicorp/consul/agent/proxycfg" "github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/rpcclient/health"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/proto/pbpeering"
) )
// Store is the state store interface required for server-local data sources. // Store is the state store interface required for server-local data sources.
type Store interface { type Store interface {
watch.StateStore watch.StateStore
ExportedServicesForAllPeersByName(ws memdb.WatchSet, entMeta acl.EnterpriseMeta) (uint64, map[string]structs.ServiceList, error)
FederationStateList(ws memdb.WatchSet) (uint64, []*structs.FederationState, error)
GatewayServices(ws memdb.WatchSet, gateway string, entMeta *acl.EnterpriseMeta) (uint64, structs.GatewayServices, error)
IntentionTopology(ws memdb.WatchSet, target structs.ServiceName, downstreams bool, defaultDecision acl.EnforcementDecision, intentionTarget structs.IntentionTargetType) (uint64, structs.ServiceList, error) IntentionTopology(ws memdb.WatchSet, target structs.ServiceName, downstreams bool, defaultDecision acl.EnforcementDecision, intentionTarget structs.IntentionTargetType) (uint64, structs.ServiceList, error)
ServiceDiscoveryChain(ws memdb.WatchSet, serviceName string, entMeta *acl.EnterpriseMeta, req discoverychain.CompileRequest) (uint64, *structs.CompiledDiscoveryChain, *configentry.DiscoveryChainSet, error)
PeeringTrustBundleRead(ws memdb.WatchSet, q state.Query) (uint64, *pbpeering.PeeringTrustBundle, error)
PeeringTrustBundleList(ws memdb.WatchSet, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.PeeringTrustBundle, error)
TrustBundleListByService(ws memdb.WatchSet, service, dc string, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.PeeringTrustBundle, error)
VirtualIPsForAllImportedServices(ws memdb.WatchSet, entMeta acl.EnterpriseMeta) (uint64, []state.ServiceVirtualIP, error)
} }
// CacheCARoots satisfies the proxycfg.CARoots interface by sourcing data from // CacheCARoots satisfies the proxycfg.CARoots interface by sourcing data from
@ -28,12 +38,6 @@ func CacheCARoots(c *cache.Cache) proxycfg.CARoots {
return &cacheProxyDataSource[*structs.DCSpecificRequest]{c, cachetype.ConnectCARootName} return &cacheProxyDataSource[*structs.DCSpecificRequest]{c, cachetype.ConnectCARootName}
} }
// CacheCompiledDiscoveryChain satisfies the proxycfg.CompiledDiscoveryChain
// interface by sourcing data from the agent cache.
func CacheCompiledDiscoveryChain(c *cache.Cache) proxycfg.CompiledDiscoveryChain {
return &cacheProxyDataSource[*structs.DiscoveryChainRequest]{c, cachetype.CompiledDiscoveryChainName}
}
// CacheConfigEntry satisfies the proxycfg.ConfigEntry interface by sourcing // CacheConfigEntry satisfies the proxycfg.ConfigEntry interface by sourcing
// data from the agent cache. // data from the agent cache.
func CacheConfigEntry(c *cache.Cache) proxycfg.ConfigEntry { func CacheConfigEntry(c *cache.Cache) proxycfg.ConfigEntry {
@ -52,16 +56,10 @@ func CacheDatacenters(c *cache.Cache) proxycfg.Datacenters {
return &cacheProxyDataSource[*structs.DatacentersRequest]{c, cachetype.CatalogDatacentersName} return &cacheProxyDataSource[*structs.DatacentersRequest]{c, cachetype.CatalogDatacentersName}
} }
// CacheFederationStateListMeshGateways satisfies the proxycfg.FederationStateListMeshGateways // CacheServiceGateways satisfies the proxycfg.ServiceGateways interface by
// interface by sourcing data from the agent cache.
func CacheFederationStateListMeshGateways(c *cache.Cache) proxycfg.FederationStateListMeshGateways {
return &cacheProxyDataSource[*structs.DCSpecificRequest]{c, cachetype.FederationStateListMeshGatewaysName}
}
// CacheGatewayServices satisfies the proxycfg.GatewayServices interface by
// sourcing data from the agent cache. // sourcing data from the agent cache.
func CacheGatewayServices(c *cache.Cache) proxycfg.GatewayServices { func CacheServiceGateways(c *cache.Cache) proxycfg.GatewayServices {
return &cacheProxyDataSource[*structs.ServiceSpecificRequest]{c, cachetype.GatewayServicesName} return &cacheProxyDataSource[*structs.ServiceSpecificRequest]{c, cachetype.ServiceGatewaysName}
} }
// CacheHTTPChecks satisifies the proxycfg.HTTPChecks interface by sourcing // CacheHTTPChecks satisifies the proxycfg.HTTPChecks interface by sourcing
@ -76,6 +74,12 @@ func CacheIntentionUpstreams(c *cache.Cache) proxycfg.IntentionUpstreams {
return &cacheProxyDataSource[*structs.ServiceSpecificRequest]{c, cachetype.IntentionUpstreamsName} return &cacheProxyDataSource[*structs.ServiceSpecificRequest]{c, cachetype.IntentionUpstreamsName}
} }
// CacheIntentionUpstreamsDestination satisfies the proxycfg.IntentionUpstreamsDestination interface
// by sourcing data from the agent cache.
func CacheIntentionUpstreamsDestination(c *cache.Cache) proxycfg.IntentionUpstreams {
return &cacheProxyDataSource[*structs.ServiceSpecificRequest]{c, cachetype.IntentionUpstreamsDestinationName}
}
// CacheInternalServiceDump satisfies the proxycfg.InternalServiceDump // CacheInternalServiceDump satisfies the proxycfg.InternalServiceDump
// interface by sourcing data from the agent cache. // interface by sourcing data from the agent cache.
func CacheInternalServiceDump(c *cache.Cache) proxycfg.InternalServiceDump { func CacheInternalServiceDump(c *cache.Cache) proxycfg.InternalServiceDump {
@ -88,12 +92,6 @@ func CacheLeafCertificate(c *cache.Cache) proxycfg.LeafCertificate {
return &cacheProxyDataSource[*cachetype.ConnectCALeafRequest]{c, cachetype.ConnectCALeafName} return &cacheProxyDataSource[*cachetype.ConnectCALeafRequest]{c, cachetype.ConnectCALeafName}
} }
// CachePeeredUpstreams satisfies the proxycfg.PeeredUpstreams interface
// by sourcing data from the agent cache.
func CachePeeredUpstreams(c *cache.Cache) proxycfg.PeeredUpstreams {
return &cacheProxyDataSource[*structs.PartitionSpecificRequest]{c, cachetype.PeeredUpstreamsName}
}
// CachePrepraredQuery satisfies the proxycfg.PreparedQuery interface by // CachePrepraredQuery satisfies the proxycfg.PreparedQuery interface by
// sourcing data from the agent cache. // sourcing data from the agent cache.
func CachePrepraredQuery(c *cache.Cache) proxycfg.PreparedQuery { func CachePrepraredQuery(c *cache.Cache) proxycfg.PreparedQuery {
@ -106,30 +104,6 @@ func CacheResolvedServiceConfig(c *cache.Cache) proxycfg.ResolvedServiceConfig {
return &cacheProxyDataSource[*structs.ServiceConfigRequest]{c, cachetype.ResolvedServiceConfigName} return &cacheProxyDataSource[*structs.ServiceConfigRequest]{c, cachetype.ResolvedServiceConfigName}
} }
// CacheServiceList satisfies the proxycfg.ServiceList interface by sourcing
// data from the agent cache.
func CacheServiceList(c *cache.Cache) proxycfg.ServiceList {
return &cacheProxyDataSource[*structs.DCSpecificRequest]{c, cachetype.CatalogServiceListName}
}
// CacheTrustBundle satisfies the proxycfg.TrustBundle interface by sourcing
// data from the agent cache.
func CacheTrustBundle(c *cache.Cache) proxycfg.TrustBundle {
return &cacheProxyDataSource[*pbpeering.TrustBundleReadRequest]{c, cachetype.TrustBundleReadName}
}
// CacheTrustBundleList satisfies the proxycfg.TrustBundleList interface by sourcing
// data from the agent cache.
func CacheTrustBundleList(c *cache.Cache) proxycfg.TrustBundleList {
return &cacheProxyDataSource[*pbpeering.TrustBundleListByServiceRequest]{c, cachetype.TrustBundleListName}
}
// CacheExportedPeeredServices satisfies the proxycfg.ExportedPeeredServices
// interface by sourcing data from the agent cache.
func CacheExportedPeeredServices(c *cache.Cache) proxycfg.ExportedPeeredServices {
return &cacheProxyDataSource[*structs.DCSpecificRequest]{c, cachetype.ExportedPeeredServicesName}
}
// cacheProxyDataSource implements a generic wrapper around the agent cache to // cacheProxyDataSource implements a generic wrapper around the agent cache to
// provide data to the proxycfg.Manager. // provide data to the proxycfg.Manager.
type cacheProxyDataSource[ReqType cache.Request] struct { type cacheProxyDataSource[ReqType cache.Request] struct {
@ -148,25 +122,6 @@ func (c *cacheProxyDataSource[ReqType]) Notify(
return c.c.NotifyCallback(ctx, c.t, req, correlationID, dispatchCacheUpdate(ch)) return c.c.NotifyCallback(ctx, c.t, req, correlationID, dispatchCacheUpdate(ch))
} }
// Health wraps health.Client so that the proxycfg package doesn't need to
// reference cache.UpdateEvent directly.
func Health(client *health.Client) proxycfg.Health {
return &healthWrapper{client}
}
type healthWrapper struct {
client *health.Client
}
func (h *healthWrapper) Notify(
ctx context.Context,
req *structs.ServiceSpecificRequest,
correlationID string,
ch chan<- proxycfg.UpdateEvent,
) error {
return h.client.Notify(ctx, *req, correlationID, dispatchCacheUpdate(ch))
}
func dispatchCacheUpdate(ch chan<- proxycfg.UpdateEvent) cache.Callback { func dispatchCacheUpdate(ch chan<- proxycfg.UpdateEvent) cache.Callback {
return func(ctx context.Context, e cache.UpdateEvent) { return func(ctx context.Context, e cache.UpdateEvent) {
u := proxycfg.UpdateEvent{ u := proxycfg.UpdateEvent{

View File

@ -0,0 +1,82 @@
package proxycfgglue
import (
"context"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/rpcclient/health"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/submatview"
)
// ClientHealth satisfies the proxycfg.Health interface by sourcing data from
// the given health.Client.
func ClientHealth(client *health.Client) proxycfg.Health {
return &clientHealth{client}
}
type clientHealth struct {
client *health.Client
}
func (h *clientHealth) Notify(
ctx context.Context,
req *structs.ServiceSpecificRequest,
correlationID string,
ch chan<- proxycfg.UpdateEvent,
) error {
return h.client.Notify(ctx, *req, correlationID, dispatchCacheUpdate(ch))
}
// ServerHealth satisfies the proxycfg.Health interface by sourcing data from
// a local materialized view (backed by an EventPublisher subscription).
//
// Requests for services in remote datacenters will be delegated to the given
// remoteSource (i.e. ClientHealth).
func ServerHealth(deps ServerDataSourceDeps, remoteSource proxycfg.Health) proxycfg.Health {
return &serverHealth{deps, remoteSource}
}
type serverHealth struct {
deps ServerDataSourceDeps
remoteSource proxycfg.Health
}
func (h *serverHealth) Notify(ctx context.Context, req *structs.ServiceSpecificRequest, correlationID string, ch chan<- proxycfg.UpdateEvent) error {
if req.Datacenter != h.deps.Datacenter {
return h.remoteSource.Notify(ctx, req, correlationID, ch)
}
return h.deps.ViewStore.NotifyCallback(
ctx,
&healthRequest{h.deps, *req},
correlationID,
dispatchCacheUpdate(ch),
)
}
type healthRequest struct {
deps ServerDataSourceDeps
req structs.ServiceSpecificRequest
}
func (r *healthRequest) CacheInfo() cache.RequestInfo { return r.req.CacheInfo() }
func (r *healthRequest) NewMaterializer() (submatview.Materializer, error) {
view, err := health.NewHealthView(r.req)
if err != nil {
return nil, err
}
return submatview.NewLocalMaterializer(submatview.LocalMaterializerDeps{
Backend: r.deps.EventPublisher,
ACLResolver: r.deps.ACLResolver,
Deps: submatview.Deps{
View: view,
Logger: r.deps.Logger,
Request: health.NewMaterializerRequest(r.req),
},
}), nil
}
func (r *healthRequest) Type() string { return "proxycfgglue.Health" }

View File

@ -0,0 +1,149 @@
package proxycfgglue
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/proxycfg"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/submatview"
"github.com/hashicorp/consul/proto/pbsubscribe"
"github.com/hashicorp/consul/sdk/testutil"
)
func TestServerHealth(t *testing.T) {
t.Run("remote queries are delegated to the remote source", func(t *testing.T) {
var (
ctx = context.Background()
req = &structs.ServiceSpecificRequest{Datacenter: "dc2"}
correlationID = "correlation-id"
ch = make(chan<- proxycfg.UpdateEvent)
result = errors.New("KABOOM")
)
remoteSource := newMockHealth(t)
remoteSource.On("Notify", ctx, req, correlationID, ch).Return(result)
dataSource := ServerHealth(ServerDataSourceDeps{Datacenter: "dc1"}, remoteSource)
err := dataSource.Notify(ctx, req, correlationID, ch)
require.Equal(t, result, err)
})
t.Run("local queries are served from a materialized view", func(t *testing.T) {
// Note: the view is tested more thoroughly in the agent/rpcclient/health
// package, so this is more of a high-level integration test with the local
// materializer.
const (
index uint64 = 123
datacenter = "dc1"
serviceName = "web"
)
logger := testutil.Logger(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
store := submatview.NewStore(logger)
go store.Run(ctx)
publisher := stream.NewEventPublisher(10 * time.Second)
publisher.RegisterHandler(pbsubscribe.Topic_ServiceHealth,
func(stream.SubscribeRequest, stream.SnapshotAppender) (uint64, error) { return index, nil },
true)
go publisher.Run(ctx)
dataSource := ServerHealth(ServerDataSourceDeps{
Datacenter: datacenter,
ACLResolver: newStaticResolver(acl.ManageAll()),
ViewStore: store,
EventPublisher: publisher,
Logger: logger,
}, nil)
eventCh := make(chan proxycfg.UpdateEvent)
require.NoError(t, dataSource.Notify(ctx, &structs.ServiceSpecificRequest{
Datacenter: datacenter,
ServiceName: serviceName,
}, "", eventCh))
testutil.RunStep(t, "initial state", func(t *testing.T) {
result := getEventResult[*structs.IndexedCheckServiceNodes](t, eventCh)
require.Empty(t, result.Nodes)
})
testutil.RunStep(t, "register services", func(t *testing.T) {
publisher.Publish([]stream.Event{
{
Index: index + 1,
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: &state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register,
Value: &structs.CheckServiceNode{
Node: &structs.Node{Node: "node1"},
Service: &structs.NodeService{Service: serviceName},
},
},
},
{
Index: index + 1,
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: &state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Register,
Value: &structs.CheckServiceNode{
Node: &structs.Node{Node: "node2"},
Service: &structs.NodeService{Service: serviceName},
},
},
},
})
result := getEventResult[*structs.IndexedCheckServiceNodes](t, eventCh)
require.Len(t, result.Nodes, 2)
})
testutil.RunStep(t, "deregister service", func(t *testing.T) {
publisher.Publish([]stream.Event{
{
Index: index + 2,
Topic: pbsubscribe.Topic_ServiceHealth,
Payload: &state.EventPayloadCheckServiceNode{
Op: pbsubscribe.CatalogOp_Deregister,
Value: &structs.CheckServiceNode{
Node: &structs.Node{Node: "node2"},
Service: &structs.NodeService{Service: serviceName},
},
},
},
})
result := getEventResult[*structs.IndexedCheckServiceNodes](t, eventCh)
require.Len(t, result.Nodes, 1)
})
})
}
func newMockHealth(t *testing.T) *mockHealth {
mock := &mockHealth{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
type mockHealth struct {
mock.Mock
}
func (m *mockHealth) Notify(ctx context.Context, req *structs.ServiceSpecificRequest, correlationID string, ch chan<- proxycfg.UpdateEvent) error {
return m.Called(ctx, req, correlationID, ch).Error(0)
}

View File

@ -0,0 +1,56 @@
package proxycfgglue
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/proxycfg"
)
func indexGenerator() func() uint64 {
var idx uint64
return func() uint64 {
idx++
return idx
}
}
func getEventResult[ResultType any](t *testing.T, eventCh <-chan proxycfg.UpdateEvent) ResultType {
t.Helper()
select {
case event := <-eventCh:
require.NoError(t, event.Err, "event should not have an error")
result, ok := event.Result.(ResultType)
require.Truef(t, ok, "unexpected result type: %T", event.Result)
return result
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout waiting for event")
}
panic("this should never be reached")
}
func expectNoEvent(t *testing.T, eventCh <-chan proxycfg.UpdateEvent) {
select {
case <-eventCh:
t.Fatal("expected no event")
case <-time.After(100 * time.Millisecond):
}
}
func getEventError(t *testing.T, eventCh <-chan proxycfg.UpdateEvent) error {
t.Helper()
select {
case event := <-eventCh:
require.Error(t, event.Err)
return event.Err
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout waiting for event")
}
panic("this should never be reached")
}

View File

@ -3,7 +3,6 @@ package proxycfgglue
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -62,7 +61,7 @@ func TestServerIntentionUpstreams(t *testing.T) {
authz := policyAuthorizer(t, `service "db" { policy = "read" }`) authz := policyAuthorizer(t, `service "db" { policy = "read" }`)
dataSource := ServerIntentionUpstreams(ServerDataSourceDeps{ dataSource := ServerIntentionUpstreams(ServerDataSourceDeps{
ACLResolver: staticResolver{authz}, ACLResolver: newStaticResolver(authz),
GetStore: func() Store { return store }, GetStore: func() Store { return store },
}) })
@ -70,28 +69,16 @@ func TestServerIntentionUpstreams(t *testing.T) {
err := dataSource.Notify(ctx, &structs.ServiceSpecificRequest{ServiceName: serviceName}, "", ch) err := dataSource.Notify(ctx, &structs.ServiceSpecificRequest{ServiceName: serviceName}, "", ch)
require.NoError(t, err) require.NoError(t, err)
select { result := getEventResult[*structs.IndexedServiceList](t, ch)
case event := <-ch:
result, ok := event.Result.(*structs.IndexedServiceList)
require.Truef(t, ok, "expected IndexedServiceList, got: %T", event.Result)
require.Len(t, result.Services, 0) require.Len(t, result.Services, 0)
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout waiting for event")
}
// Create an allow intention for the db service. This should *not* be filtered // Create an allow intention for the db service. This should *not* be filtered
// out because the ACL token *does* have read access on it. // out because the ACL token *does* have read access on it.
createIntention("db") createIntention("db")
select { result = getEventResult[*structs.IndexedServiceList](t, ch)
case event := <-ch:
result, ok := event.Result.(*structs.IndexedServiceList)
require.Truef(t, ok, "expected IndexedServiceList, got: %T", event.Result)
require.Len(t, result.Services, 1) require.Len(t, result.Services, 1)
require.Equal(t, "db", result.Services[0].Name) require.Equal(t, "db", result.Services[0].Name)
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout waiting for event")
}
} }
func disableLegacyIntentions(t *testing.T, store *state.Store) { func disableLegacyIntentions(t *testing.T, store *state.Store) {

Some files were not shown because too many files have changed in this diff Show More