Merge branch 'main' of github.com:hashicorp/consul into what_service_mesh

This commit is contained in:
Karl Cardenas 2022-02-04 09:00:14 -07:00
commit 3665e95f99
No known key found for this signature in database
GPG Key ID: 0AC61D76B41F1EDC
495 changed files with 19266 additions and 11587 deletions

3
.changelog/11827.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:breaking-change
sdk: several changes to the testutil configuration structs (removed `ACLMasterToken`, renamed `Master` to `InitialManagement`, and `AgentMaster` to `AgentRecovery`)
```

3
.changelog/12080.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:enhancement
streaming: Improved performance when the server is handling many concurrent subscriptions and has a high number of CPU cores
```

3
.changelog/12081.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
ui: Fixed a bug with creating multiple nested KVs in one interaction
```

3
.changelog/12126.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
sdk: Add support for `Partition` and `RetryJoin` to the TestServerConfig struct.
```

3
.changelog/12166.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:deprecation
acl: The `consul.acl.ResolveTokenToIdentity` metric is no longer reported. The values that were previous reported as part of this metric will now be part of the `consul.acl.ResolveToken` metric.
```

3
.changelog/12174.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
xds: fix for delta xDS reconnect bug in LDS/CDS
```

3
.changelog/12176.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:enhancement
systemd: Support starting/stopping the systemd service for linux packages when the optional EnvironmentFile does not exist.
```

3
.changelog/12209.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:enhancement
ui: Use @hashicorp/flight icons for all our icons.
```

3
.changelog/_1502.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
partitions: **(Enterprise only)** Do not leave a serf partition when the partition is deleted
```

View File

@ -359,7 +359,7 @@ jobs:
path: /tmp/jsonfile
- run: *notify-slack-failure
# build all distros
# build is a templated job for build-x
build-distros: &build-distros
docker:
- image: *GOLANG_IMAGE
@ -367,7 +367,13 @@ jobs:
<<: *ENVIRONMENT
steps:
- checkout
- run: ./build-support/scripts/build-local.sh
- run:
name: Build
command: |
for os in $XC_OS; do
target="./pkg/bin/${GOOS}_${GOARCH}/"
GOOS="$os" CGO_ENABLED=0 go build -o "$target" -ldflags "$(GOLDFLAGS)" -tags "$(GOTAGS)"
done
# save dev build to CircleCI
- store_artifacts:
@ -380,7 +386,7 @@ jobs:
environment:
<<: *build-env
XC_OS: "freebsd linux windows"
XC_ARCH: "386"
GOARCH: "386"
# build all amd64 architecture supported OS binaries
build-amd64:
@ -388,7 +394,7 @@ jobs:
environment:
<<: *build-env
XC_OS: "darwin freebsd linux solaris windows"
XC_ARCH: "amd64"
GOARCH: "amd64"
# build all arm/arm64 architecture supported OS binaries
build-arm:
@ -433,7 +439,11 @@ jobs:
- attach_workspace: # this normally runs as the first job and has nothing to attach; only used in main branch after rebuilding UI
at: .
- run:
command: make dev
name: Build
command: |
make dev
mkdir -p /home/circleci/go/bin
cp ./bin/consul /home/circleci/go/bin/consul
# save dev build to pass to downstream jobs
- persist_to_workspace:
@ -689,11 +699,20 @@ jobs:
if ! git diff --quiet --exit-code HEAD^! ui/; then
git config --local user.email "github-team-consul-core@hashicorp.com"
git config --local user.name "hc-github-team-consul-core"
# stash newly built bindata_assetfs.go
git stash push
# checkout the CI branch and merge latest from main
git checkout ci/main-assetfs-build
git merge --no-edit main
git stash pop
short_sha=$(git rev-parse --short HEAD)
git add agent/uiserver/bindata_assetfs.go
git commit -m "auto-updated agent/uiserver/bindata_assetfs.go from commit ${short_sha}"
git push origin main
git push origin ci/main-assetfs-build
else
echo "no UI changes so no static assets to publish"
fi

View File

@ -1,4 +1,5 @@
# Contributing to Consul
>**Note:** We take Consul's security and our users' trust very seriously.
>If you believe you have found a security issue in Consul, please responsibly
>disclose by contacting us at security@hashicorp.com.
@ -14,7 +15,9 @@ talk to us! A great way to do this is in issues themselves. When you want to
work on an issue, comment on it first and tell us the approach you want to take.
## Getting Started
### Some Ways to Contribute
* Report potential bugs.
* Suggest product enhancements.
* Increase our test coverage.
@ -24,7 +27,8 @@ work on an issue, comment on it first and tell us the approach you want to take.
are deployed from this repo.
* Respond to questions about usage on the issue tracker or the Consul section of the [HashiCorp forum]: (https://discuss.hashicorp.com/c/consul)
### Reporting an Issue:
### Reporting an Issue
>Note: Issues on GitHub for Consul are intended to be related to bugs or feature requests.
>Questions should be directed to other community resources such as the: [Discuss Forum](https://discuss.hashicorp.com/c/consul/29), [FAQ](https://www.consul.io/docs/faq.html), or [Guides](https://www.consul.io/docs/guides/index.html).
@ -53,42 +57,47 @@ issue. Stale issues will be closed.
4. The issue is addressed in a pull request or commit. The issue will be
referenced in the commit message so that the code that fixes it is clearly
linked.
linked. Any change a Consul user might need to know about will include a
changelog entry in the PR.
5. The issue is closed.
## Building Consul
If you wish to work on Consul itself, you'll first need [Go](https://golang.org)
installed (The version of Go should match the one of our [CI config's](https://github.com/hashicorp/consul/blob/main/.circleci/config.yml) Go image).
Next, clone this repository and then run `make dev`. In a few moments, you'll have a working
`consul` executable in `consul/bin` and `$GOPATH/bin`:
>Note: `make dev` will build for your local machine's os/architecture. If you wish to build for all os/architecture combinations use `make`.
## Making Changes to Consul
The first step to making changes is to fork Consul. Afterwards, the easiest way
to work on the fork is to set it as a remote of the Consul project:
### Prerequisites
1. Navigate to `$GOPATH/src/github.com/hashicorp/consul`
2. Rename the existing remote's name: `git remote rename origin upstream`.
3. Add your fork as a remote by running
`git remote add origin <github url of fork>`. For example:
`git remote add origin https://github.com/myusername/consul`.
4. Checkout a feature branch: `git checkout -t -b new-feature`
5. Make changes
6. Push changes to the fork when ready to submit PR:
`git push -u origin new-feature`
If you wish to work on Consul itself, you'll first need to:
- install [Go](https://golang.org) (the version should match that of our
[CI config's](https://github.com/hashicorp/consul/blob/main/.circleci/config.yml) Go image).
- [fork the Consul repo](../docs/contributing/fork-the-project.md)
By following these steps you can push to your fork to create a PR, but the code on disk still
lives in the spot where the go cli tools are expecting to find it.
### Building Consul
>Note: If you make any changes to the code, run `gofmt -s -w` to automatically format the code according to Go standards.
To build Consul, run `make dev`. In a few moments, you'll have a working
`consul` executable in `consul/bin` and `$GOPATH/bin`:
## Testing
>Note: `make dev` will build for your local machine's os/architecture. If you wish to build for all os/architecture combinations, use `make`.
### Modifying the Code
#### Code Formatting
Go provides [tooling to apply consistent code formatting](https://golang.org/doc/effective_go#formatting).
If you make any changes to the code, run `gofmt -s -w` to automatically format the code according to Go standards.
#### Updating Go Module Dependencies
If a dependency is added or change, run `go mod tidy` to update `go.mod` and `go.sum`.
#### Developer Documentation
Developer-focused documentation about the Consul code base is under [./docs],
and godoc package document can be read at [pkg.go.dev/github.com/hashicorp/consul].
[./docs]: ../docs/README.md
[pkg.go.dev/github.com/hashicorp/consul]: https://pkg.go.dev/github.com/hashicorp/consul
### Testing
During development, it may be more convenient to check your work-in-progress by running only the tests which you expect to be affected by your changes, as the full test suite can take several minutes to execute. [Go's built-in test tool](https://golang.org/pkg/cmd/go/internal/test/) allows specifying a list of packages to test and the `-run` option to only include test names matching a regular expression.
The `go test -short` flag can also be used to skip slower tests.
@ -99,22 +108,44 @@ Examples (run from the repository root):
When a pull request is opened CI will run all tests and lint to verify the change.
## Go Module Dependencies
### Submitting a Pull Request
If a dependency is added or change, run `go mod tidy` to update `go.mod` and `go.sum`.
Before writing any code, we recommend:
- Create a Github issue if none already exists for the code change you'd like to make.
- Write a comment on the Github issue indicating you're interested in contributing so
maintainers can provide their perspective if needed.
## Developer Documentation
Keep your pull requests (PRs) small and open them early so you can get feedback on
approach from maintainers before investing your time in larger changes. For example,
see how [applying URL-decoding of resource names across the whole HTTP API](https://github.com/hashicorp/consul/issues/11258)
started with [iterating on the right approach for a few endpoints](https://github.com/hashicorp/consul/pull/11335)
before applying more broadly.
Documentation about the Consul code base is under [./docs],
and godoc package document can be read at [pkg.go.dev/github.com/hashicorp/consul].
When you're ready to submit a pull request:
1. Review the [list of checklists](#checklists) for common changes and follow any
that apply to your work.
2. Include evidence that your changes work as intended (e.g., add/modify unit tests;
describe manual tests you ran, in what environment,
and the results including screenshots or terminal output).
3. Open the PR from your fork against base repository `hashicorp/consul` and branch `main`.
- [Link the PR to its associated issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue).
4. Include any specific questions that you have for the reviewer in the PR description
or as a PR comment in Github.
- If there's anything you find the need to explain or clarify in the PR, consider
whether that explanation should be added in the source code as comments.
- You can submit a [draft PR](https://github.blog/2019-02-14-introducing-draft-pull-requests/)
if your changes aren't finalized but would benefit from in-process feedback.
5. If there's any reason Consul users might need to know about this change,
[add a changelog entry](../docs/contributing/add-a-changelog-entry.md).
6. After you submit, the Consul maintainers team needs time to carefully review your
contribution and ensure it is production-ready, considering factors such as: security,
backwards-compatibility, potential regressions, etc.
7. After you address Consul maintainer feedback and the PR is approved, a Consul maintainer
will merge it. Your contribution will be available from the next major release (e.g., 1.x)
unless explicitly backported to an existing or previous major release by the maintainer.
[./docs]: ../docs/README.md
[pkg.go.dev/github.com/hashicorp/consul]: https://pkg.go.dev/github.com/hashicorp/consul
#### Checklists
### Checklists
Some common changes that many PRs require such as adding config fields, are
documented through checklists.
Please check in [docs/](../docs/) for any `checklist-*.md` files that might help
with your change.
Some common changes that many PRs require are documented through checklists as
`checklist-*.md` files in [docs/](../docs/), including:
- [Adding config fields](../docs/config/checklist-adding-config-fields.md)

View File

@ -3,7 +3,7 @@ name: build
on:
push:
# Sequence of patterns matched against refs/heads
branches: [
branches: [
"main"
]
@ -145,6 +145,7 @@ jobs:
config_dir: ".release/linux/package"
preinstall: ".release/linux/preinstall"
postinstall: ".release/linux/postinstall"
preremove: ".release/linux/preremove"
postremove: ".release/linux/postremove"
- name: Set Package Names
@ -218,7 +219,7 @@ jobs:
GOLDFLAGS: "${{ needs.get-product-version.outputs.shared-ldflags }}"
run: |
mkdir dist out
go build -ldflags="$GOLDFLAGS" -o dist/ .
go build -ldflags="$GOLDFLAGS" -tags netcgo -o dist/ .
zip -r -j out/${{ env.PKG_NAME }}_${{ needs.get-product-version.outputs.product-version }}_${{ matrix.goos }}_${{ matrix.goarch }}.zip dist/
- uses: actions/upload-artifact@v2

View File

@ -21,8 +21,10 @@ on:
jobs:
# checks that a 'type/docs-cherrypick' label is attached to PRs with website/ changes
website-check:
# If there's a `type/docs-cherrypick` label we ignore this check
if: "!contains(github.event.pull_request.labels.*.name, 'type/docs-cherrypick')"
# If there's already a `type/docs-cherrypick` label or an explicit `pr/no-docs` label, we ignore this check
if: >-
!contains(github.event.pull_request.labels.*.name, 'type/docs-cherrypick') ||
!contains(github.event.pull_request.labels.*.name, 'pr/no-docs')
runs-on: ubuntu-latest
steps:
@ -40,7 +42,7 @@ jobs:
# post PR comment to GitHub to check if a 'type/docs-cherrypick' label needs to be applied to the PR
echo "website-check: Did not find a 'type/docs-cherrypick' label, posting a reminder in the PR"
github_message="🤔 This PR has changes in the \`website/\` directory but does not have a \`type/docs-cherrypick\` label. If the changes are for the next version, this can be ignored. If they are updates to current docs, attach the label to auto cherrypick to the \`stable-website\` branch after merging."
curl -f -s -H "Authorization: token ${{ secrets.PR_COMMENT_TOKEN }}" \
curl -s -H "Authorization: token ${{ secrets.PR_COMMENT_TOKEN }}" \
-X POST \
-d "{ \"body\": \"${github_message}\"}" \
"https://api.github.com/repos/${GITHUB_REPOSITORY}/issues/${{ github.event.pull_request.number }}/comments"

View File

@ -7,6 +7,7 @@ linters:
- staticcheck
- ineffassign
- unparam
- forbidigo
issues:
# Disable the default exclude list so that all excludes are explicitly
@ -57,6 +58,14 @@ issues:
linters-settings:
gofmt:
simplify: true
forbidigo:
# Forbid the following identifiers (list of regexp).
forbid:
- '\brequire\.New\b(# Use package-level functions with explicit TestingT)?'
- '\bassert\.New\b(# Use package-level functions with explicit TestingT)?'
# Exclude godoc examples from forbidigo checks.
# Default: true
exclude_godoc_examples: false
run:
timeout: 10m

View File

@ -6,7 +6,7 @@ After=network-online.target
ConditionFileNotEmpty=/etc/consul.d/consul.hcl
[Service]
EnvironmentFile=/etc/consul.d/consul.env
EnvironmentFile=-/etc/consul.d/consul.env
User=consul
Group=consul
ExecStart=/usr/bin/consul agent -config-dir=/etc/consul.d/

View File

@ -1,14 +1,19 @@
#!/bin/bash
if [ "$1" = "purge" ]
then
userdel consul
if [ -d "/run/systemd/system" ]; then
systemctl --system daemon-reload >/dev/null || :
fi
if [ "$1" == "upgrade" ] && [ -d /run/systemd/system ]; then
systemctl --system daemon-reload >/dev/null || true
systemctl restart consul >/dev/null || true
fi
case "$1" in
purge | 0)
userdel consul
;;
upgrade | [1-9]*)
if [ -d "/run/systemd/system" ]; then
systemctl try-restart consul.service >/dev/null || :
fi
;;
esac
exit 0

11
.release/linux/preremove Normal file
View File

@ -0,0 +1,11 @@
#!/bin/bash
case "$1" in
remove | 0)
if [ -d "/run/systemd/system" ]; then
systemctl --no-reload disable consul.service > /dev/null || :
systemctl stop consul.service > /dev/null || :
fi
;;
esac
exit 0

View File

@ -1,3 +1,6 @@
# For documentation on building consul from source, refer to:
# https://www.consul.io/docs/install#compiling-from-source
SHELL = bash
GOGOVERSION?=$(shell grep github.com/gogo/protobuf go.mod | awk '{print $$2}')
GOTOOLS = \
@ -12,8 +15,6 @@ GOTOOLS = \
github.com/hashicorp/lint-consul-retry@master
GOTAGS ?=
GOOS?=$(shell go env GOOS)
GOARCH?=$(shell go env GOARCH)
GOPATH=$(shell go env GOPATH)
MAIN_GOPATH=$(shell go env GOPATH | cut -d: -f1)
@ -134,20 +135,17 @@ ifdef SKIP_DOCKER_BUILD
ENVOY_INTEG_DEPS=noop
endif
# all builds binaries for all targets
all: bin
all: dev-build
# used to make integration dependencies conditional
noop: ;
bin: tools
@$(SHELL) $(CURDIR)/build-support/scripts/build-local.sh
# dev creates binaries for testing locally - these are put into ./bin and $GOPATH
# dev creates binaries for testing locally - these are put into ./bin
dev: dev-build
dev-build:
@$(SHELL) $(CURDIR)/build-support/scripts/build-local.sh -o $(GOOS) -a $(GOARCH)
mkdir -p bin
CGO_ENABLED=0 go build -o ./bin -ldflags "$(GOLDFLAGS)" -tags "$(GOTAGS)"
dev-docker: linux
@echo "Pulling consul container image - $(CONSUL_IMAGE_VERSION)"
@ -175,9 +173,10 @@ ifeq ($(CIRCLE_BRANCH), main)
@docker push $(CI_DEV_DOCKER_NAMESPACE)/$(CI_DEV_DOCKER_IMAGE_NAME):latest
endif
# linux builds a linux package independent of the source platform
# linux builds a linux binary independent of the source platform
linux:
@$(SHELL) $(CURDIR)/build-support/scripts/build-local.sh -o linux -a amd64
mkdir -p bin
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o ./bin -ldflags "$(GOLDFLAGS)" -tags "$(GOTAGS)"
# dist builds binaries for all platforms and packages them for distribution
dist:

View File

@ -1,11 +1,20 @@
# Consul [![CircleCI](https://circleci.com/gh/hashicorp/consul/tree/main.svg?style=svg)](https://circleci.com/gh/hashicorp/consul/tree/main) [![Discuss](https://img.shields.io/badge/discuss-consul-ca2171.svg?style=flat)](https://discuss.hashicorp.com/c/consul)
# Consul
<p>
<a href="https://consul.io" title="Consul website">
<img src="./website/public/img/logo-hashicorp.svg" alt="HashiCorp Consul logo" width="200px">
</a>
</p>
[![Docker Pulls](https://img.shields.io/docker/pulls/_/consul.svg)](https://hub.docker.com/_/consul)
[![Go Report Card](https://goreportcard.com/badge/github.com/hashicorp/consul)](https://goreportcard.com/report/github.com/hashicorp/consul)
Consul is a distributed, highly available, and data center aware solution to connect and configure applications across dynamic, distributed infrastructure.
* Website: https://www.consul.io
* Tutorials: [HashiCorp Learn](https://learn.hashicorp.com/consul)
* Forum: [Discuss](https://discuss.hashicorp.com/c/consul)
Consul is a distributed, highly available, and data center aware solution to connect and configure applications across dynamic, distributed infrastructure.
Consul provides several key features:
* **Multi-Datacenter** - Consul is built to be datacenter aware, and can

View File

@ -1516,30 +1516,28 @@ func TestMergePolicies(t *testing.T) {
},
}
req := require.New(t)
for _, tcase := range tests {
t.Run(tcase.name, func(t *testing.T) {
act := MergePolicies(tcase.input)
exp := tcase.expected
req.Equal(exp.ACL, act.ACL)
req.Equal(exp.Keyring, act.Keyring)
req.Equal(exp.Operator, act.Operator)
req.Equal(exp.Mesh, act.Mesh)
req.ElementsMatch(exp.Agents, act.Agents)
req.ElementsMatch(exp.AgentPrefixes, act.AgentPrefixes)
req.ElementsMatch(exp.Events, act.Events)
req.ElementsMatch(exp.EventPrefixes, act.EventPrefixes)
req.ElementsMatch(exp.Keys, act.Keys)
req.ElementsMatch(exp.KeyPrefixes, act.KeyPrefixes)
req.ElementsMatch(exp.Nodes, act.Nodes)
req.ElementsMatch(exp.NodePrefixes, act.NodePrefixes)
req.ElementsMatch(exp.PreparedQueries, act.PreparedQueries)
req.ElementsMatch(exp.PreparedQueryPrefixes, act.PreparedQueryPrefixes)
req.ElementsMatch(exp.Services, act.Services)
req.ElementsMatch(exp.ServicePrefixes, act.ServicePrefixes)
req.ElementsMatch(exp.Sessions, act.Sessions)
req.ElementsMatch(exp.SessionPrefixes, act.SessionPrefixes)
require.Equal(t, exp.ACL, act.ACL)
require.Equal(t, exp.Keyring, act.Keyring)
require.Equal(t, exp.Operator, act.Operator)
require.Equal(t, exp.Mesh, act.Mesh)
require.ElementsMatch(t, exp.Agents, act.Agents)
require.ElementsMatch(t, exp.AgentPrefixes, act.AgentPrefixes)
require.ElementsMatch(t, exp.Events, act.Events)
require.ElementsMatch(t, exp.EventPrefixes, act.EventPrefixes)
require.ElementsMatch(t, exp.Keys, act.Keys)
require.ElementsMatch(t, exp.KeyPrefixes, act.KeyPrefixes)
require.ElementsMatch(t, exp.Nodes, act.Nodes)
require.ElementsMatch(t, exp.NodePrefixes, act.NodePrefixes)
require.ElementsMatch(t, exp.PreparedQueries, act.PreparedQueries)
require.ElementsMatch(t, exp.PreparedQueryPrefixes, act.PreparedQueryPrefixes)
require.ElementsMatch(t, exp.Services, act.Services)
require.ElementsMatch(t, exp.ServicePrefixes, act.ServicePrefixes)
require.ElementsMatch(t, exp.Sessions, act.Sessions)
require.ElementsMatch(t, exp.SessionPrefixes, act.SessionPrefixes)
})
}

View File

@ -15,7 +15,7 @@ import (
// critical purposes, such as logging. Therefore we interpret all errors as empty-string
// so we can safely log it without handling non-critical errors at the usage site.
func (a *Agent) aclAccessorID(secretID string) string {
ident, err := a.delegate.ResolveTokenToIdentity(secretID)
ident, err := a.delegate.ResolveTokenAndDefaultMeta(secretID, nil, nil)
if acl.IsErrNotFound(err) {
return ""
}
@ -23,10 +23,7 @@ func (a *Agent) aclAccessorID(secretID string) string {
a.logger.Debug("non-critical error resolving acl token accessor for logging", "error", err)
return ""
}
if ident == nil {
return ""
}
return ident.ID()
return ident.AccessorID()
}
// vetServiceRegister makes sure the service registration action is allowed by
@ -174,7 +171,7 @@ func (a *Agent) filterMembers(token string, members *[]serf.Member) error {
if authz.NodeRead(node, &authzContext) == acl.Allow {
continue
}
accessorID := a.aclAccessorID(token)
accessorID := authz.AccessorID()
a.logger.Debug("dropping node from result due to ACLs", "node", node, "accessorID", accessorID)
m = append(m[:i], m[i+1:]...)
i--

View File

@ -849,10 +849,10 @@ func TestACL_HTTP(t *testing.T) {
tokens, ok := raw.(structs.ACLTokenListStubs)
require.True(t, ok)
// 3 tokens created but 1 was deleted + master token + anon token
// 3 tokens created but 1 was deleted + initial management token + anon token
require.Len(t, tokens, 4)
// this loop doesn't verify anything about the master token
// this loop doesn't verify anything about the initial management token
for tokenID, expected := range tokenMap {
found := false
for _, actual := range tokens {
@ -1880,7 +1880,7 @@ func TestACL_Authorize(t *testing.T) {
var localToken structs.ACLToken
require.NoError(t, a2.RPC("ACL.TokenSet", &localTokenReq, &localToken))
t.Run("master-token", func(t *testing.T) {
t.Run("initial-management-token", func(t *testing.T) {
request := []structs.ACLAuthorizationRequest{
{
Resource: "acl",
@ -2016,7 +2016,7 @@ func TestACL_Authorize(t *testing.T) {
resp := responses[idx]
require.Equal(t, req, resp.ACLAuthorizationRequest)
require.True(t, resp.Allow, "should have allowed all access for master token")
require.True(t, resp.Allow, "should have allowed all access for initial management token")
}
})
}
@ -2277,7 +2277,7 @@ func TestACL_Authorize(t *testing.T) {
type rpcFn func(string, interface{}, interface{}) error
func upsertTestCustomizedAuthMethod(
rpc rpcFn, masterToken string, datacenter string,
rpc rpcFn, initialManagementToken string, datacenter string,
modify func(method *structs.ACLAuthMethod),
) (*structs.ACLAuthMethod, error) {
name, err := uuid.GenerateUUID()
@ -2291,7 +2291,7 @@ func upsertTestCustomizedAuthMethod(
Name: "test-method-" + name,
Type: "testing",
},
WriteRequest: structs.WriteRequest{Token: masterToken},
WriteRequest: structs.WriteRequest{Token: initialManagementToken},
}
if modify != nil {
@ -2308,11 +2308,11 @@ func upsertTestCustomizedAuthMethod(
return &out, nil
}
func upsertTestCustomizedBindingRule(rpc rpcFn, masterToken string, datacenter string, modify func(rule *structs.ACLBindingRule)) (*structs.ACLBindingRule, error) {
func upsertTestCustomizedBindingRule(rpc rpcFn, initialManagementToken string, datacenter string, modify func(rule *structs.ACLBindingRule)) (*structs.ACLBindingRule, error) {
req := structs.ACLBindingRuleSetRequest{
Datacenter: datacenter,
BindingRule: structs.ACLBindingRule{},
WriteRequest: structs.WriteRequest{Token: masterToken},
WriteRequest: structs.WriteRequest{Token: initialManagementToken},
}
if modify != nil {

View File

@ -39,6 +39,12 @@ type TestACLAgent struct {
func NewTestACLAgent(t *testing.T, name string, hcl string, resolveAuthz authzResolver, resolveIdent identResolver) *TestACLAgent {
t.Helper()
if resolveIdent == nil {
resolveIdent = func(s string) (structs.ACLIdentity, error) {
return nil, nil
}
}
a := &TestACLAgent{resolveAuthzFn: resolveAuthz, resolveIdentFn: resolveIdent}
dataDir := testutil.TempDir(t, "acl-agent")
@ -86,26 +92,15 @@ func (a *TestACLAgent) ResolveToken(secretID string) (acl.Authorizer, error) {
return authz, err
}
func (a *TestACLAgent) ResolveTokenToIdentityAndAuthorizer(secretID string) (structs.ACLIdentity, acl.Authorizer, error) {
if a.resolveAuthzFn == nil {
return nil, nil, fmt.Errorf("ResolveTokenToIdentityAndAuthorizer call is unexpected - no authz resolver callback set")
}
return a.resolveAuthzFn(secretID)
}
func (a *TestACLAgent) ResolveTokenToIdentity(secretID string) (structs.ACLIdentity, error) {
if a.resolveIdentFn == nil {
return nil, fmt.Errorf("ResolveTokenToIdentity call is unexpected - no ident resolver callback set")
}
return a.resolveIdentFn(secretID)
}
func (a *TestACLAgent) ResolveTokenAndDefaultMeta(secretID string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (acl.Authorizer, error) {
identity, authz, err := a.ResolveTokenToIdentityAndAuthorizer(secretID)
func (a *TestACLAgent) ResolveTokenAndDefaultMeta(secretID string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (consul.ACLResolveResult, error) {
authz, err := a.ResolveToken(secretID)
if err != nil {
return nil, err
return consul.ACLResolveResult{}, err
}
identity, err := a.resolveIdentFn(secretID)
if err != nil {
return consul.ACLResolveResult{}, err
}
// Default the EnterpriseMeta based on the Tokens meta or actual defaults
@ -119,7 +114,7 @@ func (a *TestACLAgent) ResolveTokenAndDefaultMeta(secretID string, entMeta *stru
// Use the meta to fill in the ACL authorization context
entMeta.FillAuthzContext(authzContext)
return authz, err
return consul.ACLResolveResult{Authorizer: authz, ACLIdentity: identity}, err
}
// All of these are stubs to satisfy the interface
@ -523,22 +518,3 @@ func TestACL_filterChecksWithAuthorizer(t *testing.T) {
_, ok = checks["my-other"]
require.False(t, ok)
}
// TODO: remove?
func TestACL_ResolveIdentity(t *testing.T) {
t.Parallel()
a := NewTestACLAgent(t, t.Name(), TestACLConfig(), nil, catalogIdent)
// this test is meant to ensure we are calling the correct function
// which is ResolveTokenToIdentity on the Agent delegate. Our
// nil authz resolver will cause it to emit an error if used
ident, err := a.delegate.ResolveTokenToIdentity(nodeROSecret)
require.NoError(t, err)
require.NotNil(t, ident)
// just double checkingto ensure if we had used the wrong function
// that an error would be produced
_, err = a.delegate.ResolveTokenAndDefaultMeta(nodeROSecret, nil, nil)
require.Error(t, err)
}

View File

@ -167,14 +167,11 @@ type delegate interface {
// RemoveFailedNode is used to remove a failed node from the cluster.
RemoveFailedNode(node string, prune bool, entMeta *structs.EnterpriseMeta) error
// TODO: replace this method with consul.ACLResolver
ResolveTokenToIdentity(token string) (structs.ACLIdentity, error)
// ResolveTokenAndDefaultMeta returns an acl.Authorizer which authorizes
// actions based on the permissions granted to the token.
// If either entMeta or authzContext are non-nil they will be populated with the
// default partition and namespace from the token.
ResolveTokenAndDefaultMeta(token string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (acl.Authorizer, error)
ResolveTokenAndDefaultMeta(token string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (consul.ACLResolveResult, error)
RPC(method string, args interface{}, reply interface{}) error
SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error
@ -209,9 +206,6 @@ type Agent struct {
// depending on the configuration
delegate delegate
// aclMasterAuthorizer is an object that helps manage local ACL enforcement.
aclMasterAuthorizer acl.Authorizer
// state stores a local representation of the node,
// services and checks. Used for anti-entropy.
State *local.State

File diff suppressed because it is too large Load Diff

View File

@ -1855,7 +1855,6 @@ func TestAgent_AddCheck_Alias(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
@ -1869,19 +1868,19 @@ func TestAgent_AddCheck_Alias(t *testing.T) {
AliasService: "foo",
}
err := a.AddCheck(health, chk, false, "", ConfigSourceLocal)
require.NoError(err)
require.NoError(t, err)
// Ensure we have a check mapping
sChk := requireCheckExists(t, a, "aliashealth")
require.Equal(api.HealthCritical, sChk.Status)
require.Equal(t, api.HealthCritical, sChk.Status)
chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)]
require.True(ok, "missing aliashealth check")
require.Equal("", chkImpl.RPCReq.Token)
require.True(t, ok, "missing aliashealth check")
require.Equal(t, "", chkImpl.RPCReq.Token)
cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil))
require.NotNil(cs)
require.Equal("", cs.Token)
require.NotNil(t, cs)
require.Equal(t, "", cs.Token)
}
func TestAgent_AddCheck_Alias_setToken(t *testing.T) {
@ -1891,7 +1890,6 @@ func TestAgent_AddCheck_Alias_setToken(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
@ -1905,15 +1903,15 @@ func TestAgent_AddCheck_Alias_setToken(t *testing.T) {
AliasService: "foo",
}
err := a.AddCheck(health, chk, false, "foo", ConfigSourceLocal)
require.NoError(err)
require.NoError(t, err)
cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil))
require.NotNil(cs)
require.Equal("foo", cs.Token)
require.NotNil(t, cs)
require.Equal(t, "foo", cs.Token)
chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)]
require.True(ok, "missing aliashealth check")
require.Equal("foo", chkImpl.RPCReq.Token)
require.True(t, ok, "missing aliashealth check")
require.Equal(t, "foo", chkImpl.RPCReq.Token)
}
func TestAgent_AddCheck_Alias_userToken(t *testing.T) {
@ -1923,7 +1921,6 @@ func TestAgent_AddCheck_Alias_userToken(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, `
acl_token = "hello"
`)
@ -1939,15 +1936,15 @@ acl_token = "hello"
AliasService: "foo",
}
err := a.AddCheck(health, chk, false, "", ConfigSourceLocal)
require.NoError(err)
require.NoError(t, err)
cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil))
require.NotNil(cs)
require.Equal("", cs.Token) // State token should still be empty
require.NotNil(t, cs)
require.Equal(t, "", cs.Token) // State token should still be empty
chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)]
require.True(ok, "missing aliashealth check")
require.Equal("hello", chkImpl.RPCReq.Token) // Check should use the token
require.True(t, ok, "missing aliashealth check")
require.Equal(t, "hello", chkImpl.RPCReq.Token) // Check should use the token
}
func TestAgent_AddCheck_Alias_userAndSetToken(t *testing.T) {
@ -1957,7 +1954,6 @@ func TestAgent_AddCheck_Alias_userAndSetToken(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, `
acl_token = "hello"
`)
@ -1973,15 +1969,15 @@ acl_token = "hello"
AliasService: "foo",
}
err := a.AddCheck(health, chk, false, "goodbye", ConfigSourceLocal)
require.NoError(err)
require.NoError(t, err)
cs := a.State.CheckState(structs.NewCheckID("aliashealth", nil))
require.NotNil(cs)
require.Equal("goodbye", cs.Token)
require.NotNil(t, cs)
require.Equal(t, "goodbye", cs.Token)
chkImpl, ok := a.checkAliases[structs.NewCheckID("aliashealth", nil)]
require.True(ok, "missing aliashealth check")
require.Equal("goodbye", chkImpl.RPCReq.Token)
require.True(t, ok, "missing aliashealth check")
require.Equal(t, "goodbye", chkImpl.RPCReq.Token)
}
func TestAgent_RemoveCheck(t *testing.T) {

View File

@ -11,7 +11,6 @@ import (
)
func TestCatalogServices(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &CatalogServices{RPC: rpc}
@ -22,10 +21,10 @@ func TestCatalogServices(t *testing.T) {
rpc.On("RPC", "Catalog.ServiceNodes", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.ServiceSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal("web", req.ServiceName)
require.True(req.AllowStale)
require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal(t, "web", req.ServiceName)
require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.IndexedServiceNodes)
reply.ServiceNodes = []*structs.ServiceNode{
@ -44,15 +43,14 @@ func TestCatalogServices(t *testing.T) {
ServiceName: "web",
ServiceTags: []string{"tag1", "tag2"},
})
require.NoError(err)
require.Equal(cache.FetchResult{
require.NoError(t, err)
require.Equal(t, cache.FetchResult{
Value: resp,
Index: 48,
}, resultA)
}
func TestCatalogServices_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &CatalogServices{RPC: rpc}
@ -60,7 +58,7 @@ func TestCatalogServices_badReqType(t *testing.T) {
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err)
require.Contains(err.Error(), "wrong type")
require.Error(t, err)
require.Contains(t, err.Error(), "wrong type")
}

View File

@ -123,23 +123,22 @@ func TestCalculateSoftExpire(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
now, err := time.Parse("2006-01-02 15:04:05", tc.now)
require.NoError(err)
require.NoError(t, err)
issued, err := time.Parse("2006-01-02 15:04:05", tc.issued)
require.NoError(err)
require.NoError(t, err)
wantMin, err := time.Parse("2006-01-02 15:04:05", tc.wantMin)
require.NoError(err)
require.NoError(t, err)
wantMax, err := time.Parse("2006-01-02 15:04:05", tc.wantMax)
require.NoError(err)
require.NoError(t, err)
min, max := calculateSoftExpiry(now, &structs.IssuedCert{
ValidAfter: issued,
ValidBefore: issued.Add(tc.lifetime),
})
require.Equal(wantMin, min)
require.Equal(wantMax, max)
require.Equal(t, wantMin, min)
require.Equal(t, wantMax, max)
})
}
}
@ -156,7 +155,6 @@ func TestConnectCALeaf_changingRoots(t *testing.T) {
}
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
@ -211,8 +209,8 @@ func TestConnectCALeaf_changingRoots(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(uint64(1), v.Index)
require.Equal(t, resp, v.Value)
require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -244,9 +242,9 @@ func TestConnectCALeaf_changingRoots(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(t, resp, v.Value)
// 3 since the second CA "update" used up 2
require.Equal(uint64(3), v.Index)
require.Equal(t, uint64(3), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
opts.MinIndex = 3
@ -267,7 +265,6 @@ func TestConnectCALeaf_changingRoots(t *testing.T) {
func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
@ -323,8 +320,8 @@ func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(uint64(1), v.Index)
require.Equal(t, resp, v.Value)
require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -378,24 +375,24 @@ func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
if v.Index > uint64(1) {
// Got a new cert
require.Equal(resp, v.Value)
require.Equal(uint64(3), v.Index)
require.Equal(t, resp, v.Value)
require.Equal(t, uint64(3), v.Index)
// Should not have been delivered before the delay
require.True(time.Since(earliestRootDelivery) > typ.TestOverrideCAChangeInitialDelay)
require.True(t, time.Since(earliestRootDelivery) > typ.TestOverrideCAChangeInitialDelay)
// All good. We are done!
rootsDelivered = true
} else {
// Should be the cached cert
require.Equal(resp, v.Value)
require.Equal(uint64(1), v.Index)
require.Equal(t, resp, v.Value)
require.Equal(t, uint64(1), v.Index)
// Sanity check we blocked for the whole timeout
require.Truef(timeTaken > opts.Timeout,
require.Truef(t, timeTaken > opts.Timeout,
"should block for at least %s, returned after %s",
opts.Timeout, timeTaken)
// Sanity check that the forceExpireAfter state was set correctly
shouldExpireAfter = v.State.(*fetchState).forceExpireAfter
require.True(shouldExpireAfter.After(time.Now()))
require.True(shouldExpireAfter.Before(time.Now().Add(typ.TestOverrideCAChangeInitialDelay)))
require.True(t, shouldExpireAfter.After(time.Now()))
require.True(t, shouldExpireAfter.Before(time.Now().Add(typ.TestOverrideCAChangeInitialDelay)))
}
// Set the LastResult for subsequent fetches
opts.LastResult = &v
@ -406,8 +403,7 @@ func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
// Sanity check that we've not gone way beyond the deadline without a
// new cert. We give some leeway to make it less brittle.
require.Falsef(
time.Now().After(shouldExpireAfter.Add(100*time.Millisecond)),
require.Falsef(t, time.Now().After(shouldExpireAfter.Add(100*time.Millisecond)),
"waited extra 100ms and delayed CA rotate renew didn't happen")
}
}
@ -416,7 +412,6 @@ func TestConnectCALeaf_changingRootsJitterBetweenCalls(t *testing.T) {
func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
@ -461,8 +456,8 @@ func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(uint64(1), v.Index)
require.Equal(t, resp, v.Value)
require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -475,11 +470,11 @@ func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) {
t.Fatal("shouldn't block for too long waiting for fetch")
case result := <-fetchCh:
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(t, resp, v.Value)
// Still the initial cached result
require.Equal(uint64(1), v.Index)
require.Equal(t, uint64(1), v.Index)
// Sanity check that it waited
require.True(time.Since(start) > opts.Timeout)
require.True(t, time.Since(start) > opts.Timeout)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -507,11 +502,11 @@ func TestConnectCALeaf_changingRootsBetweenBlockingCalls(t *testing.T) {
t.Fatal("shouldn't block too long waiting for fetch")
case result := <-fetchCh:
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(t, resp, v.Value)
// Index should be 3 since root change consumed 2
require.Equal(uint64(3), v.Index)
require.Equal(t, uint64(3), v.Index)
// Sanity check that we didn't wait too long
require.True(time.Since(earliestRootDelivery) < opts.Timeout)
require.True(t, time.Since(earliestRootDelivery) < opts.Timeout)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -525,7 +520,6 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
@ -594,8 +588,8 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
case result := <-fetchCh:
switch v := result.(type) {
case error:
require.Error(v)
require.Equal(consul.ErrRateLimited.Error(), v.Error())
require.Error(t, v)
require.Equal(t, consul.ErrRateLimited.Error(), v.Error())
case cache.FetchResult:
t.Fatalf("Expected error")
}
@ -608,8 +602,8 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(uint64(1), v.Index)
require.Equal(t, resp, v.Value)
require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
// Set MinIndex
@ -633,7 +627,7 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
earliestRootDelivery := time.Now()
// Sanity check state
require.Equal(uint64(1), atomic.LoadUint64(&rateLimitedRPCs))
require.Equal(t, uint64(1), atomic.LoadUint64(&rateLimitedRPCs))
// After root rotation jitter has been waited out, a new CSR will
// be attempted but will fail and return the previous cached result with no
@ -646,14 +640,14 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
// We should block for _at least_ one jitter period since we set that to
// 100ms and in test override mode we always pick the max jitter not a
// random amount.
require.True(time.Since(earliestRootDelivery) > 100*time.Millisecond)
require.Equal(uint64(2), atomic.LoadUint64(&rateLimitedRPCs))
require.True(t, time.Since(earliestRootDelivery) > 100*time.Millisecond)
require.Equal(t, uint64(2), atomic.LoadUint64(&rateLimitedRPCs))
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(t, resp, v.Value)
// 1 since this should still be the original cached result as we failed to
// get a new cert.
require.Equal(uint64(1), v.Index)
require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -667,14 +661,14 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
t.Fatal("shouldn't block too long waiting for fetch")
case result := <-fetchCh:
// We should block for _at least_ two jitter periods now.
require.True(time.Since(earliestRootDelivery) > 200*time.Millisecond)
require.Equal(uint64(3), atomic.LoadUint64(&rateLimitedRPCs))
require.True(t, time.Since(earliestRootDelivery) > 200*time.Millisecond)
require.Equal(t, uint64(3), atomic.LoadUint64(&rateLimitedRPCs))
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(t, resp, v.Value)
// 1 since this should still be the original cached result as we failed to
// get a new cert.
require.Equal(uint64(1), v.Index)
require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -689,13 +683,13 @@ func TestConnectCALeaf_CSRRateLimiting(t *testing.T) {
t.Fatal("shouldn't block too long waiting for fetch")
case result := <-fetchCh:
// We should block for _at least_ three jitter periods now.
require.True(time.Since(earliestRootDelivery) > 300*time.Millisecond)
require.Equal(uint64(3), atomic.LoadUint64(&rateLimitedRPCs))
require.True(t, time.Since(earliestRootDelivery) > 300*time.Millisecond)
require.Equal(t, uint64(3), atomic.LoadUint64(&rateLimitedRPCs))
v := mustFetchResult(t, result)
require.Equal(resp, v.Value)
require.Equal(t, resp, v.Value)
// 3 since the rootCA change used 2
require.Equal(uint64(3), v.Index)
require.Equal(t, uint64(3), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -909,7 +903,6 @@ func TestConnectCALeaf_expiringLeaf(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
@ -963,10 +956,10 @@ func TestConnectCALeaf_expiringLeaf(t *testing.T) {
case result := <-fetchCh:
switch v := result.(type) {
case error:
require.NoError(v)
require.NoError(t, v)
case cache.FetchResult:
require.Equal(resp, v.Value)
require.Equal(uint64(1), v.Index)
require.Equal(t, resp, v.Value)
require.Equal(t, uint64(1), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -981,10 +974,10 @@ func TestConnectCALeaf_expiringLeaf(t *testing.T) {
case result := <-fetchCh:
switch v := result.(type) {
case error:
require.NoError(v)
require.NoError(t, v)
case cache.FetchResult:
require.Equal(resp, v.Value)
require.Equal(uint64(2), v.Index)
require.Equal(t, resp, v.Value)
require.Equal(t, uint64(2), v.Index)
// Set the LastResult for subsequent fetches
opts.LastResult = &v
}
@ -1004,7 +997,6 @@ func TestConnectCALeaf_expiringLeaf(t *testing.T) {
func TestConnectCALeaf_DNSSANForService(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
@ -1040,12 +1032,12 @@ func TestConnectCALeaf_DNSSANForService(t *testing.T) {
DNSSAN: []string{"test.example.com"},
}
_, err := typ.Fetch(opts, req)
require.NoError(err)
require.NoError(t, err)
pemBlock, _ := pem.Decode([]byte(caReq.CSR))
csr, err := x509.ParseCertificateRequest(pemBlock.Bytes)
require.NoError(err)
require.Equal(csr.DNSNames, []string{"test.example.com"})
require.NoError(t, err)
require.Equal(t, csr.DNSNames, []string{"test.example.com"})
}
// testConnectCaRoot wraps ConnectCARoot to disable refresh so that the gated

View File

@ -11,7 +11,6 @@ import (
)
func TestConnectCARoot(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &ConnectCARoot{RPC: rpc}
@ -22,8 +21,8 @@ func TestConnectCARoot(t *testing.T) {
rpc.On("RPC", "ConnectCA.Roots", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.DCSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
reply := args.Get(2).(*structs.IndexedCARoots)
reply.QueryMeta.Index = 48
@ -35,15 +34,14 @@ func TestConnectCARoot(t *testing.T) {
MinIndex: 24,
Timeout: 1 * time.Second,
}, &structs.DCSpecificRequest{Datacenter: "dc1"})
require.Nil(err)
require.Equal(cache.FetchResult{
require.Nil(t, err)
require.Equal(t, cache.FetchResult{
Value: resp,
Index: 48,
}, result)
}
func TestConnectCARoot_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &ConnectCARoot{RPC: rpc}
@ -51,7 +49,7 @@ func TestConnectCARoot_badReqType(t *testing.T) {
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.NotNil(err)
require.Contains(err.Error(), "wrong type")
require.NotNil(t, err)
require.Contains(t, err.Error(), "wrong type")
}

View File

@ -11,7 +11,6 @@ import (
)
func TestHealthServices(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &HealthServices{RPC: rpc}
@ -22,10 +21,10 @@ func TestHealthServices(t *testing.T) {
rpc.On("RPC", "Health.ServiceNodes", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.ServiceSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal("web", req.ServiceName)
require.True(req.AllowStale)
require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal(t, "web", req.ServiceName)
require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.IndexedCheckServiceNodes)
reply.Nodes = []structs.CheckServiceNode{
@ -44,15 +43,14 @@ func TestHealthServices(t *testing.T) {
ServiceName: "web",
ServiceTags: []string{"tag1", "tag2"},
})
require.NoError(err)
require.Equal(cache.FetchResult{
require.NoError(t, err)
require.Equal(t, cache.FetchResult{
Value: resp,
Index: 48,
}, resultA)
}
func TestHealthServices_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &HealthServices{RPC: rpc}
@ -60,7 +58,7 @@ func TestHealthServices_badReqType(t *testing.T) {
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err)
require.Contains(err.Error(), "wrong type")
require.Error(t, err)
require.Contains(t, err.Error(), "wrong type")
}

View File

@ -11,7 +11,6 @@ import (
)
func TestIntentionMatch(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &IntentionMatch{RPC: rpc}
@ -22,8 +21,8 @@ func TestIntentionMatch(t *testing.T) {
rpc.On("RPC", "Intention.Match", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.IntentionQueryRequest)
require.Equal(uint64(24), req.MinQueryIndex)
require.Equal(1*time.Second, req.MaxQueryTime)
require.Equal(t, uint64(24), req.MinQueryIndex)
require.Equal(t, 1*time.Second, req.MaxQueryTime)
reply := args.Get(2).(*structs.IndexedIntentionMatches)
reply.Index = 48
@ -35,15 +34,14 @@ func TestIntentionMatch(t *testing.T) {
MinIndex: 24,
Timeout: 1 * time.Second,
}, &structs.IntentionQueryRequest{Datacenter: "dc1"})
require.NoError(err)
require.Equal(cache.FetchResult{
require.NoError(t, err)
require.Equal(t, cache.FetchResult{
Value: resp,
Index: 48,
}, result)
}
func TestIntentionMatch_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &IntentionMatch{RPC: rpc}
@ -51,7 +49,7 @@ func TestIntentionMatch_badReqType(t *testing.T) {
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err)
require.Contains(err.Error(), "wrong type")
require.Error(t, err)
require.Contains(t, err.Error(), "wrong type")
}

View File

@ -11,7 +11,6 @@ import (
)
func TestNodeServices(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &NodeServices{RPC: rpc}
@ -22,10 +21,10 @@ func TestNodeServices(t *testing.T) {
rpc.On("RPC", "Catalog.NodeServices", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.NodeSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal("node-01", req.Node)
require.True(req.AllowStale)
require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal(t, "node-01", req.Node)
require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.IndexedNodeServices)
reply.NodeServices = &structs.NodeServices{
@ -49,15 +48,14 @@ func TestNodeServices(t *testing.T) {
Datacenter: "dc1",
Node: "node-01",
})
require.NoError(err)
require.Equal(cache.FetchResult{
require.NoError(t, err)
require.Equal(t, cache.FetchResult{
Value: resp,
Index: 48,
}, resultA)
}
func TestNodeServices_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &NodeServices{RPC: rpc}
@ -65,7 +63,7 @@ func TestNodeServices_badReqType(t *testing.T) {
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err)
require.Contains(err.Error(), "wrong type")
require.Error(t, err)
require.Contains(t, err.Error(), "wrong type")
}

View File

@ -10,7 +10,6 @@ import (
)
func TestPreparedQuery(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &PreparedQuery{RPC: rpc}
@ -21,9 +20,9 @@ func TestPreparedQuery(t *testing.T) {
rpc.On("RPC", "PreparedQuery.Execute", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.PreparedQueryExecuteRequest)
require.Equal("geo-db", req.QueryIDOrName)
require.Equal(10, req.Limit)
require.True(req.AllowStale)
require.Equal(t, "geo-db", req.QueryIDOrName)
require.Equal(t, 10, req.Limit)
require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.PreparedQueryExecuteResponse)
reply.QueryMeta.Index = 48
@ -36,15 +35,14 @@ func TestPreparedQuery(t *testing.T) {
QueryIDOrName: "geo-db",
Limit: 10,
})
require.NoError(err)
require.Equal(cache.FetchResult{
require.NoError(t, err)
require.Equal(t, cache.FetchResult{
Value: resp,
Index: 48,
}, result)
}
func TestPreparedQuery_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &PreparedQuery{RPC: rpc}
@ -52,6 +50,6 @@ func TestPreparedQuery_badReqType(t *testing.T) {
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err)
require.Contains(err.Error(), "wrong type")
require.Error(t, err)
require.Contains(t, err.Error(), "wrong type")
}

View File

@ -11,7 +11,6 @@ import (
)
func TestResolvedServiceConfig(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &ResolvedServiceConfig{RPC: rpc}
@ -22,10 +21,10 @@ func TestResolvedServiceConfig(t *testing.T) {
rpc.On("RPC", "ConfigEntry.ResolveServiceConfig", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.ServiceConfigRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal("foo", req.Name)
require.True(req.AllowStale)
require.Equal(t, uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(t, 1*time.Second, req.QueryOptions.MaxQueryTime)
require.Equal(t, "foo", req.Name)
require.True(t, req.AllowStale)
reply := args.Get(2).(*structs.ServiceConfigResponse)
reply.ProxyConfig = map[string]interface{}{
@ -49,15 +48,14 @@ func TestResolvedServiceConfig(t *testing.T) {
Datacenter: "dc1",
Name: "foo",
})
require.NoError(err)
require.Equal(cache.FetchResult{
require.NoError(t, err)
require.Equal(t, cache.FetchResult{
Value: resp,
Index: 48,
}, resultA)
}
func TestResolvedServiceConfig_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &ResolvedServiceConfig{RPC: rpc}
@ -65,7 +63,7 @@ func TestResolvedServiceConfig_badReqType(t *testing.T) {
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err)
require.Contains(err.Error(), "wrong type")
require.Error(t, err)
require.Contains(t, err.Error(), "wrong type")
}

View File

@ -24,8 +24,6 @@ import (
func TestCacheGet_noIndex(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := New(Options{})
@ -37,15 +35,15 @@ func TestCacheGet_noIndex(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Get, should not fetch since we already have a satisfying value
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.True(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
@ -57,8 +55,6 @@ func TestCacheGet_noIndex(t *testing.T) {
func TestCacheGet_initError(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := New(Options{})
@ -71,15 +67,15 @@ func TestCacheGet_initError(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.Error(err)
require.Nil(result)
require.False(meta.Hit)
require.Error(t, err)
require.Nil(t, result)
require.False(t, meta.Hit)
// Get, should fetch again since our last fetch was an error
result, meta, err = c.Get(context.Background(), "t", req)
require.Error(err)
require.Nil(result)
require.False(meta.Hit)
require.Error(t, err)
require.Nil(t, result)
require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
@ -96,8 +92,6 @@ func TestCacheGet_cachedErrorsDontStick(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := New(Options{})
@ -115,15 +109,15 @@ func TestCacheGet_cachedErrorsDontStick(t *testing.T) {
// Get, should fetch and get error
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.Error(err)
require.Nil(result)
require.False(meta.Hit)
require.Error(t, err)
require.Nil(t, result)
require.False(t, meta.Hit)
// Get, should fetch again since our last fetch was an error, but get success
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Now get should block until timeout and then get the same response NOT the
// cached error.
@ -157,8 +151,6 @@ func TestCacheGet_cachedErrorsDontStick(t *testing.T) {
func TestCacheGet_blankCacheKey(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := New(Options{})
@ -170,15 +162,15 @@ func TestCacheGet_blankCacheKey(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: ""})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Get, should not fetch since we already have a satisfying value
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
@ -225,8 +217,6 @@ func TestCacheGet_blockingInitSameKey(t *testing.T) {
func TestCacheGet_blockingInitDiffKeys(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := New(Options{})
@ -269,7 +259,7 @@ func TestCacheGet_blockingInitDiffKeys(t *testing.T) {
// Verify proper keys
sort.Strings(keys)
require.Equal([]string{"goodbye", "hello"}, keys)
require.Equal(t, []string{"goodbye", "hello"}, keys)
}
// Test a get with an index set will wait until an index that is higher
@ -414,8 +404,6 @@ func TestCacheGet_emptyFetchResult(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := New(Options{})
@ -429,29 +417,29 @@ func TestCacheGet_emptyFetchResult(t *testing.T) {
typ.Static(FetchResult{Value: nil, State: 32}, nil).Run(func(args mock.Arguments) {
// We should get back the original state
opts := args.Get(0).(FetchOptions)
require.NotNil(opts.LastResult)
require.NotNil(t, opts.LastResult)
stateCh <- opts.LastResult.State.(int)
})
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Get, should not fetch since we already have a satisfying value
req = TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// State delivered to second call should be the result from first call.
select {
case state := <-stateCh:
require.Equal(31, state)
require.Equal(t, 31, state)
case <-time.After(20 * time.Millisecond):
t.Fatal("timed out")
}
@ -461,12 +449,12 @@ func TestCacheGet_emptyFetchResult(t *testing.T) {
req = TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
select {
case state := <-stateCh:
require.Equal(32, state)
require.Equal(t, 32, state)
case <-time.After(20 * time.Millisecond):
t.Fatal("timed out")
}
@ -737,8 +725,6 @@ func TestCacheGet_noIndexSetsOne(t *testing.T) {
func TestCacheGet_fetchTimeout(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := &MockType{}
timeout := 10 * time.Minute
typ.On("RegisterOptions").Return(RegisterOptions{
@ -761,12 +747,12 @@ func TestCacheGet_fetchTimeout(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Test the timeout
require.Equal(timeout, actual)
require.Equal(t, timeout, actual)
}
// Test that entries expire
@ -777,8 +763,6 @@ func TestCacheGet_expire(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{
LastGetTTL: 400 * time.Millisecond,
@ -795,9 +779,9 @@ func TestCacheGet_expire(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Wait for a non-trivial amount of time to sanity check the age increases at
// least this amount. Note that this is not a fudge for some timing-dependent
@ -808,10 +792,10 @@ func TestCacheGet_expire(t *testing.T) {
// Get, should not fetch, verified via the mock assertions above
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
require.True(meta.Age > 5*time.Millisecond)
require.NoError(t, err)
require.Equal(t, 42, result)
require.True(t, meta.Hit)
require.True(t, meta.Age > 5*time.Millisecond)
// Sleep for the expiry
time.Sleep(500 * time.Millisecond)
@ -819,9 +803,9 @@ func TestCacheGet_expire(t *testing.T) {
// Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen then verify
// that we still only got the one call
@ -837,8 +821,6 @@ func TestCacheGet_expire(t *testing.T) {
func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{
LastGetTTL: 400 * time.Millisecond,
@ -879,18 +861,18 @@ func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(8, result)
require.Equal(uint64(4), meta.Index)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 8, result)
require.Equal(t, uint64(4), meta.Index)
require.False(t, meta.Hit)
// Get, should not fetch, verified via the mock assertions above
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(8, result)
require.Equal(uint64(4), meta.Index)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 8, result)
require.Equal(t, uint64(4), meta.Index)
require.True(t, meta.Hit)
// Sleep for the expiry
time.Sleep(500 * time.Millisecond)
@ -898,10 +880,10 @@ func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) {
// Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(8, result)
require.Equal(uint64(4), meta.Index)
require.False(meta.Hit, "the fetch should not have re-populated the cache "+
require.NoError(t, err)
require.Equal(t, 8, result)
require.Equal(t, uint64(4), meta.Index)
require.False(t, meta.Hit, "the fetch should not have re-populated the cache "+
"entry after it expired so this get should be a miss")
// Sleep a tiny bit just to let maybe some background calls happen
@ -915,8 +897,6 @@ func TestCacheGet_expireBackgroudRefreshCancel(t *testing.T) {
func TestCacheGet_expireBackgroudRefresh(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{
LastGetTTL: 400 * time.Millisecond,
@ -948,18 +928,18 @@ func TestCacheGet_expireBackgroudRefresh(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(8, result)
require.Equal(uint64(4), meta.Index)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 8, result)
require.Equal(t, uint64(4), meta.Index)
require.False(t, meta.Hit)
// Get, should not fetch, verified via the mock assertions above
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(8, result)
require.Equal(uint64(4), meta.Index)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 8, result)
require.Equal(t, uint64(4), meta.Index)
require.True(t, meta.Hit)
// Sleep for the expiry
time.Sleep(500 * time.Millisecond)
@ -971,10 +951,10 @@ func TestCacheGet_expireBackgroudRefresh(t *testing.T) {
// re-insert the value back into the cache and make it live forever).
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(8, result)
require.Equal(uint64(4), meta.Index)
require.False(meta.Hit, "the fetch should not have re-populated the cache "+
require.NoError(t, err)
require.Equal(t, 8, result)
require.Equal(t, uint64(4), meta.Index)
require.False(t, meta.Hit, "the fetch should not have re-populated the cache "+
"entry after it expired so this get should be a miss")
// Sleep a tiny bit just to let maybe some background calls happen
@ -991,8 +971,6 @@ func TestCacheGet_expireResetGet(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{
LastGetTTL: 150 * time.Millisecond,
@ -1009,9 +987,9 @@ func TestCacheGet_expireResetGet(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Fetch multiple times, where the total time is well beyond
// the TTL. We should not trigger any fetches during this time.
@ -1022,9 +1000,9 @@ func TestCacheGet_expireResetGet(t *testing.T) {
// Get, should not fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.True(t, meta.Hit)
}
time.Sleep(200 * time.Millisecond)
@ -1032,9 +1010,9 @@ func TestCacheGet_expireResetGet(t *testing.T) {
// Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
@ -1046,8 +1024,6 @@ func TestCacheGet_expireResetGet(t *testing.T) {
func TestCacheGet_expireResetGetNoChange(t *testing.T) {
t.Parallel()
require := require.New(t)
// Create a closer so we can tell if the entry gets evicted.
closer := &testCloser{}
@ -1080,19 +1056,19 @@ func TestCacheGet_expireResetGetNoChange(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.Equal(uint64(10), meta.Index)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.Equal(t, uint64(10), meta.Index)
require.False(t, meta.Hit)
// Do a blocking watch of the value that won't time out until after the TTL.
start := time.Now()
req = TestRequest(t, RequestInfo{Key: "hello", MinIndex: 10, Timeout: 300 * time.Millisecond})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.Equal(uint64(10), meta.Index)
require.GreaterOrEqual(time.Since(start).Milliseconds(), int64(300))
require.NoError(t, err)
require.Equal(t, 42, result)
require.Equal(t, uint64(10), meta.Index)
require.GreaterOrEqual(t, time.Since(start).Milliseconds(), int64(300))
// This is the point of this test! Even though we waited for a change for
// longer than the TTL, we should have been updating the TTL so that the cache
@ -1100,7 +1076,7 @@ func TestCacheGet_expireResetGetNoChange(t *testing.T) {
// since that is not set for blocking Get calls but we can assert that the
// entry was never closed (which assuming the test for eviction closing is
// also passing is a reliable signal).
require.False(closer.isClosed(), "cache entry should not have been evicted")
require.False(t, closer.isClosed(), "cache entry should not have been evicted")
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
@ -1116,8 +1092,6 @@ func TestCacheGet_expireClose(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := &MockType{}
defer typ.AssertExpectations(t)
c := New(Options{})
@ -1137,16 +1111,16 @@ func TestCacheGet_expireClose(t *testing.T) {
ctx := context.Background()
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(ctx, "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.False(state.isClosed())
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
require.False(t, state.isClosed())
// Sleep for the expiry
time.Sleep(200 * time.Millisecond)
// state.Close() should have been called
require.True(state.isClosed())
require.True(t, state.isClosed())
}
type testCloser struct {
@ -1171,8 +1145,6 @@ func (t *testCloser) isClosed() bool {
func TestCacheGet_duplicateKeyDifferentType(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
typ2 := TestType(t)
@ -1189,23 +1161,23 @@ func TestCacheGet_duplicateKeyDifferentType(t *testing.T) {
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "foo"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(100, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 100, result)
require.False(t, meta.Hit)
// Get from t2 with same key, should fetch
req = TestRequest(t, RequestInfo{Key: "foo"})
result, meta, err = c.Get(context.Background(), "t2", req)
require.NoError(err)
require.Equal(200, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 200, result)
require.False(t, meta.Hit)
// Get from t again with same key, should cache
req = TestRequest(t, RequestInfo{Key: "foo"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(100, result)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 100, result)
require.True(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
@ -1283,8 +1255,6 @@ func TestCacheGet_refreshAge(t *testing.T) {
}
t.Parallel()
require := require.New(t)
typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{
Refresh: true,
@ -1330,11 +1300,11 @@ func TestCacheGet_refreshAge(t *testing.T) {
// Fetch again, non-blocking
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err)
require.Equal(8, result)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 8, result)
require.True(t, meta.Hit)
// Age should be zero since background refresh was "active"
require.Equal(time.Duration(0), meta.Age)
require.Equal(t, time.Duration(0), meta.Age)
}
// Now fail the next background sync
@ -1350,21 +1320,21 @@ func TestCacheGet_refreshAge(t *testing.T) {
var lastAge time.Duration
{
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err)
require.Equal(8, result)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 8, result)
require.True(t, meta.Hit)
// Age should be non-zero since background refresh was "active"
require.True(meta.Age > 0)
require.True(t, meta.Age > 0)
lastAge = meta.Age
}
// Wait a bit longer - age should increase by at least this much
time.Sleep(5 * time.Millisecond)
{
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err)
require.Equal(8, result)
require.True(meta.Hit)
require.True(meta.Age > (lastAge + (1 * time.Millisecond)))
require.NoError(t, err)
require.Equal(t, 8, result)
require.True(t, meta.Hit)
require.True(t, meta.Age > (lastAge+(1*time.Millisecond)))
}
// Now unfail the background refresh
@ -1384,18 +1354,18 @@ func TestCacheGet_refreshAge(t *testing.T) {
time.Sleep(100 * time.Millisecond)
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
// Should never error even if background is failing as we have cached value
require.NoError(err)
require.True(meta.Hit)
require.NoError(t, err)
require.True(t, meta.Hit)
// Got the new value!
if result == 10 {
// Age should be zero since background refresh is "active" again
t.Logf("Succeeded after %d attempts", attempts)
require.Equal(time.Duration(0), meta.Age)
require.Equal(t, time.Duration(0), meta.Age)
timeout = false
break
}
}
require.False(timeout, "failed to observe update after %s", time.Since(t0))
require.False(t, timeout, "failed to observe update after %s", time.Since(t0))
}
func TestCacheGet_nonRefreshAge(t *testing.T) {
@ -1405,8 +1375,6 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := &MockType{}
typ.On("RegisterOptions").Return(RegisterOptions{
Refresh: false,
@ -1440,10 +1408,10 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
// Fetch again, non-blocking
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err)
require.Equal(8, result)
require.True(meta.Hit)
require.True(meta.Age > (5 * time.Millisecond))
require.NoError(t, err)
require.Equal(t, 8, result)
require.True(t, meta.Hit)
require.True(t, meta.Age > (5*time.Millisecond))
lastAge = meta.Age
}
@ -1452,11 +1420,11 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
{
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err)
require.Equal(8, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 8, result)
require.False(t, meta.Hit)
// Age should smaller again
require.True(meta.Age < lastAge)
require.True(t, meta.Age < lastAge)
}
{
@ -1468,10 +1436,10 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
// Fetch again, non-blocking
result, meta, err := c.Get(context.Background(), "t", TestRequest(t, RequestInfo{Key: "hello"}))
require.NoError(err)
require.Equal(8, result)
require.True(meta.Hit)
require.True(meta.Age > (5 * time.Millisecond))
require.NoError(t, err)
require.Equal(t, 8, result)
require.True(t, meta.Hit)
require.True(t, meta.Age > (5*time.Millisecond))
lastAge = meta.Age
}
@ -1481,11 +1449,11 @@ func TestCacheGet_nonRefreshAge(t *testing.T) {
Key: "hello",
MaxAge: 1 * time.Millisecond,
}))
require.NoError(err)
require.Equal(8, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 8, result)
require.False(t, meta.Hit)
// Age should smaller again
require.True(meta.Age < lastAge)
require.True(t, meta.Age < lastAge)
}
}
@ -1505,21 +1473,19 @@ func TestCacheGet_nonBlockingType(t *testing.T) {
require.Equal(t, uint64(0), opts.MinIndex)
})
require := require.New(t)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.False(t, meta.Hit)
// Get, should not fetch since we have a cached value
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.True(t, meta.Hit)
// Get, should not attempt to fetch with blocking even if requested. The
// assertions below about the value being the same combined with the fact the
@ -1531,25 +1497,25 @@ func TestCacheGet_nonBlockingType(t *testing.T) {
Timeout: 10 * time.Minute,
})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
require.NoError(t, err)
require.Equal(t, 42, result)
require.True(t, meta.Hit)
time.Sleep(10 * time.Millisecond)
// Get with a max age should fetch again
req = TestRequest(t, RequestInfo{Key: "hello", MaxAge: 5 * time.Millisecond})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(43, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 43, result)
require.False(t, meta.Hit)
// Get with a must revalidate should fetch again even without a delay.
req = TestRequest(t, RequestInfo{Key: "hello", MustRevalidate: true})
result, meta, err = c.Get(context.Background(), "t", req)
require.NoError(err)
require.Equal(43, result)
require.False(meta.Hit)
require.NoError(t, err)
require.Equal(t, 43, result)
require.False(t, meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call

View File

@ -51,15 +51,13 @@ func TestCacheNotify(t *testing.T) {
// after cancellation as if it had timed out.
typ.Static(FetchResult{Value: 42, Index: 8}, nil).WaitUntil(trigger[4])
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := make(chan UpdateEvent)
err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test", ch)
require.NoError(err)
require.NoError(t, err)
// Should receive the error with index == 0 first.
TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -70,7 +68,7 @@ func TestCacheNotify(t *testing.T) {
})
// There should be no more updates delivered yet
require.Len(ch, 0)
require.Len(t, ch, 0)
// Trigger blocking query to return a "change"
close(trigger[0])
@ -102,7 +100,7 @@ func TestCacheNotify(t *testing.T) {
// requests to the "backend"
// - that multiple watchers can distinguish their results using correlationID
err = c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test2", ch)
require.NoError(err)
require.NoError(t, err)
// Should get test2 notify immediately, and it should be a cache hit
TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -121,7 +119,7 @@ func TestCacheNotify(t *testing.T) {
// it's only a sanity check, if we somehow _do_ get the change delivered later
// than 10ms the next value assertion will fail anyway.
time.Sleep(10 * time.Millisecond)
require.Len(ch, 0)
require.Len(t, ch, 0)
// Trigger final update
close(trigger[3])
@ -183,15 +181,13 @@ func TestCacheNotifyPolling(t *testing.T) {
typ.Static(FetchResult{Value: 12, Index: 1}, nil).Once()
typ.Static(FetchResult{Value: 42, Index: 1}, nil).Once()
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := make(chan UpdateEvent)
err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test", ch)
require.NoError(err)
require.NoError(t, err)
// Should receive the first result pretty soon
TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -202,32 +198,32 @@ func TestCacheNotifyPolling(t *testing.T) {
})
// There should be no more updates delivered yet
require.Len(ch, 0)
require.Len(t, ch, 0)
// make sure the updates do not come too quickly
select {
case <-time.After(50 * time.Millisecond):
case <-ch:
require.Fail("Received update too early")
require.Fail(t, "Received update too early")
}
// make sure we get the update not too far out.
select {
case <-time.After(100 * time.Millisecond):
require.Fail("Didn't receive the notification")
require.Fail(t, "Didn't receive the notification")
case result := <-ch:
require.Equal(result.Result, 12)
require.Equal(result.CorrelationID, "test")
require.Equal(result.Meta.Hit, false)
require.Equal(result.Meta.Index, uint64(1))
require.Equal(t, result.Result, 12)
require.Equal(t, result.CorrelationID, "test")
require.Equal(t, result.Meta.Hit, false)
require.Equal(t, result.Meta.Index, uint64(1))
// pretty conservative check it should be even newer because without a second
// notifier each value returned will have been executed just then and not served
// from the cache.
require.True(result.Meta.Age < 50*time.Millisecond)
require.NoError(result.Err)
require.True(t, result.Meta.Age < 50*time.Millisecond)
require.NoError(t, result.Err)
}
require.Len(ch, 0)
require.Len(t, ch, 0)
// Register a second observer using same chan and request. Note that this is
// testing a few things implicitly:
@ -235,7 +231,7 @@ func TestCacheNotifyPolling(t *testing.T) {
// requests to the "backend"
// - that multiple watchers can distinguish their results using correlationID
err = c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test2", ch)
require.NoError(err)
require.NoError(t, err)
// Should get test2 notify immediately, and it should be a cache hit
TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -245,7 +241,7 @@ func TestCacheNotifyPolling(t *testing.T) {
Err: nil,
})
require.Len(ch, 0)
require.Len(t, ch, 0)
// wait for the next batch of responses
events := make([]UpdateEvent, 0)
@ -255,25 +251,25 @@ func TestCacheNotifyPolling(t *testing.T) {
for i := 0; i < 2; i++ {
select {
case <-timeout:
require.Fail("UpdateEvent not received in time")
require.Fail(t, "UpdateEvent not received in time")
case eve := <-ch:
events = append(events, eve)
}
}
require.Equal(events[0].Result, 42)
require.Equal(events[0].Meta.Hit && events[1].Meta.Hit, false)
require.Equal(events[0].Meta.Index, uint64(1))
require.True(events[0].Meta.Age < 50*time.Millisecond)
require.NoError(events[0].Err)
require.Equal(events[1].Result, 42)
require.Equal(t, events[0].Result, 42)
require.Equal(t, events[0].Meta.Hit && events[1].Meta.Hit, false)
require.Equal(t, events[0].Meta.Index, uint64(1))
require.True(t, events[0].Meta.Age < 50*time.Millisecond)
require.NoError(t, events[0].Err)
require.Equal(t, events[1].Result, 42)
// Sometimes this would be a hit and others not. It all depends on when the various getWithIndex calls got fired.
// If both are done concurrently then it will not be a cache hit but the request gets single flighted and both
// get notified at the same time.
// require.Equal(events[1].Meta.Hit, true)
require.Equal(events[1].Meta.Index, uint64(1))
require.True(events[1].Meta.Age < 100*time.Millisecond)
require.NoError(events[1].Err)
// require.Equal(t,events[1].Meta.Hit, true)
require.Equal(t, events[1].Meta.Index, uint64(1))
require.True(t, events[1].Meta.Age < 100*time.Millisecond)
require.NoError(t, events[1].Err)
}
// Test that a refresh performs a backoff.
@ -298,15 +294,13 @@ func TestCacheWatch_ErrorBackoff(t *testing.T) {
atomic.AddUint32(&retries, 1)
})
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := make(chan UpdateEvent)
err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello"}), "test", ch)
require.NoError(err)
require.NoError(t, err)
// Should receive the first result pretty soon
TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -331,15 +325,15 @@ OUT:
break OUT
case u := <-ch:
numErrors++
require.Error(u.Err)
require.Error(t, u.Err)
}
}
// Must be fewer than 10 failures in that time
require.True(numErrors < 10, fmt.Sprintf("numErrors: %d", numErrors))
require.True(t, numErrors < 10, fmt.Sprintf("numErrors: %d", numErrors))
// Check the number of RPCs as a sanity check too
actual := atomic.LoadUint32(&retries)
require.True(actual < 10, fmt.Sprintf("actual: %d", actual))
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
}
// Test that a refresh performs a backoff.
@ -363,15 +357,13 @@ func TestCacheWatch_ErrorBackoffNonBlocking(t *testing.T) {
atomic.AddUint32(&retries, 1)
})
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := make(chan UpdateEvent)
err := c.Notify(ctx, "t", TestRequest(t, RequestInfo{Key: "hello", MaxAge: 100 * time.Millisecond}), "test", ch)
require.NoError(err)
require.NoError(t, err)
// Should receive the first result pretty soon
TestCacheNotifyChResult(t, ch, UpdateEvent{
@ -399,13 +391,13 @@ OUT:
break OUT
case u := <-ch:
numErrors++
require.Error(u.Err)
require.Error(t, u.Err)
}
}
// Must be fewer than 10 failures in that time
require.True(numErrors < 10, fmt.Sprintf("numErrors: %d", numErrors))
require.True(t, numErrors < 10, fmt.Sprintf("numErrors: %d", numErrors))
// Check the number of RPCs as a sanity check too
actual := atomic.LoadUint32(&retries)
require.True(actual < 10, fmt.Sprintf("actual: %d", actual))
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
}

View File

@ -136,9 +136,7 @@ func (s *HTTPHandlers) CatalogRegister(resp http.ResponseWriter, req *http.Reque
}
if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Setup the default DC if not provided
@ -168,9 +166,7 @@ func (s *HTTPHandlers) CatalogDeregister(resp http.ResponseWriter, req *http.Req
return nil, err
}
if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Setup the default DC if not provided
@ -367,9 +363,7 @@ func (s *HTTPHandlers) catalogServiceNodes(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")
return nil, nil
return nil, BadRequestError{Reason: "Missing service name"}
}
// Make the RPC request
@ -444,9 +438,7 @@ func (s *HTTPHandlers) CatalogNodeServices(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.Node == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing node name")
return nil, nil
return nil, BadRequestError{Reason: "Missing node name"}
}
// Make the RPC request
@ -511,9 +503,7 @@ func (s *HTTPHandlers) CatalogNodeServiceList(resp http.ResponseWriter, req *htt
return nil, err
}
if args.Node == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing node name")
return nil, nil
return nil, BadRequestError{Reason: "Missing node name"}
}
// Make the RPC request
@ -564,9 +554,7 @@ func (s *HTTPHandlers) CatalogGatewayServices(resp http.ResponseWriter, req *htt
return nil, err
}
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing gateway name")
return nil, nil
return nil, BadRequestError{Reason: "Missing gateway name"}
}
// Make the RPC request

View File

@ -635,9 +635,6 @@ func TestCatalogServiceNodes(t *testing.T) {
a := NewTestAgent(t, "")
defer a.Shutdown()
assert := assert.New(t)
require := require.New(t)
// Make sure an empty list is returned, not a nil
{
req, _ := http.NewRequest("GET", "/v1/catalog/service/api?tag=a", nil)
@ -691,12 +688,12 @@ func TestCatalogServiceNodes(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/catalog/service/api?cached", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogServiceNodes(resp, req)
require.NoError(err)
require.NoError(t, err)
nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1)
assert.Len(t, nodes, 1)
// Should be a cache miss
assert.Equal("MISS", resp.Header().Get("X-Cache"))
assert.Equal(t, "MISS", resp.Header().Get("X-Cache"))
}
{
@ -704,13 +701,13 @@ func TestCatalogServiceNodes(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/catalog/service/api?cached", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogServiceNodes(resp, req)
require.NoError(err)
require.NoError(t, err)
nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1)
assert.Len(t, nodes, 1)
// Should be a cache HIT now!
assert.Equal("HIT", resp.Header().Get("X-Cache"))
assert.Equal("0", resp.Header().Get("Age"))
assert.Equal(t, "HIT", resp.Header().Get("X-Cache"))
assert.Equal(t, "0", resp.Header().Get("Age"))
}
// Ensure background refresh works
@ -719,7 +716,7 @@ func TestCatalogServiceNodes(t *testing.T) {
args2 := args
args2.Node = "bar"
args2.Address = "127.0.0.2"
require.NoError(a.RPC("Catalog.Register", args, &out))
require.NoError(t, a.RPC("Catalog.Register", args, &out))
retry.Run(t, func(r *retry.R) {
// List it again
@ -1057,7 +1054,6 @@ func TestCatalogServiceNodes_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -1065,19 +1061,19 @@ func TestCatalogServiceNodes_ConnectProxy(t *testing.T) {
// Register
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
assert.Nil(t, a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/service/%s", args.Service.Service), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogServiceNodes(resp, req)
assert.Nil(err)
assert.Nil(t, err)
assertIndex(t, resp)
nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1)
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(args.Service.Proxy, nodes[0].ServiceProxy)
assert.Len(t, nodes, 1)
assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(t, args.Service.Proxy, nodes[0].ServiceProxy)
}
// Test that the Connect-compatible endpoints can be queried for a
@ -1089,7 +1085,6 @@ func TestCatalogConnectServiceNodes_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
@ -1098,20 +1093,20 @@ func TestCatalogConnectServiceNodes_good(t *testing.T) {
args := structs.TestRegisterRequestProxy(t)
args.Service.Address = "127.0.0.55"
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
assert.Nil(t, a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/connect/%s", args.Service.Proxy.DestinationServiceName), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogConnectServiceNodes(resp, req)
assert.Nil(err)
assert.Nil(t, err)
assertIndex(t, resp)
nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1)
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(args.Service.Address, nodes[0].ServiceAddress)
assert.Equal(args.Service.Proxy, nodes[0].ServiceProxy)
assert.Len(t, nodes, 1)
assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(t, args.Service.Address, nodes[0].ServiceAddress)
assert.Equal(t, args.Service.Proxy, nodes[0].ServiceProxy)
}
func TestCatalogConnectServiceNodes_Filter(t *testing.T) {
@ -1307,7 +1302,6 @@ func TestCatalogNodeServices_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -1315,19 +1309,19 @@ func TestCatalogNodeServices_ConnectProxy(t *testing.T) {
// Register
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
assert.Nil(t, a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/node/%s", args.Node), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodeServices(resp, req)
assert.Nil(err)
assert.Nil(t, err)
assertIndex(t, resp)
ns := obj.(*structs.NodeServices)
assert.Len(ns.Services, 1)
assert.Len(t, ns.Services, 1)
v := ns.Services[args.Service.Service]
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
assert.Equal(t, structs.ServiceKindConnectProxy, v.Kind)
}
func TestCatalogNodeServices_WanTranslation(t *testing.T) {

View File

@ -88,7 +88,6 @@ enable_acl_replication = true
func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) {
t.Run("top-level fields", func(t *testing.T) {
require := require.New(t)
opts := LoadOpts{
HCL: []string{`
@ -101,21 +100,20 @@ func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) {
patchLoadOptsShims(&opts)
result, err := Load(opts)
require.NoError(err)
require.NoError(t, err)
expectWarns := []string{
deprecationWarning("acl_master_token", "acl.tokens.initial_management"),
deprecationWarning("acl_agent_master_token", "acl.tokens.agent_recovery"),
}
require.ElementsMatch(expectWarns, result.Warnings)
require.ElementsMatch(t, expectWarns, result.Warnings)
rt := result.RuntimeConfig
require.Equal("token1", rt.ACLInitialManagementToken)
require.Equal("token2", rt.ACLTokens.ACLAgentRecoveryToken)
require.Equal(t, "token1", rt.ACLInitialManagementToken)
require.Equal(t, "token2", rt.ACLTokens.ACLAgentRecoveryToken)
})
t.Run("embedded in tokens struct", func(t *testing.T) {
require := require.New(t)
opts := LoadOpts{
HCL: []string{`
@ -132,21 +130,20 @@ func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) {
patchLoadOptsShims(&opts)
result, err := Load(opts)
require.NoError(err)
require.NoError(t, err)
expectWarns := []string{
deprecationWarning("acl.tokens.master", "acl.tokens.initial_management"),
deprecationWarning("acl.tokens.agent_master", "acl.tokens.agent_recovery"),
}
require.ElementsMatch(expectWarns, result.Warnings)
require.ElementsMatch(t, expectWarns, result.Warnings)
rt := result.RuntimeConfig
require.Equal("token1", rt.ACLInitialManagementToken)
require.Equal("token2", rt.ACLTokens.ACLAgentRecoveryToken)
require.Equal(t, "token1", rt.ACLInitialManagementToken)
require.Equal(t, "token2", rt.ACLTokens.ACLAgentRecoveryToken)
})
t.Run("both", func(t *testing.T) {
require := require.New(t)
opts := LoadOpts{
HCL: []string{`
@ -166,10 +163,10 @@ func TestLoad_DeprecatedConfig_ACLMasterTokens(t *testing.T) {
patchLoadOptsShims(&opts)
result, err := Load(opts)
require.NoError(err)
require.NoError(t, err)
rt := result.RuntimeConfig
require.Equal("token3", rt.ACLInitialManagementToken)
require.Equal("token4", rt.ACLTokens.ACLAgentRecoveryToken)
require.Equal(t, "token3", rt.ACLInitialManagementToken)
require.Equal(t, "token4", rt.ACLTokens.ACLAgentRecoveryToken)
})
}

View File

@ -90,16 +90,12 @@ func (s *HTTPHandlers) configDelete(resp http.ResponseWriter, req *http.Request)
pathArgs := strings.SplitN(kindAndName, "/", 2)
if len(pathArgs) != 2 {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp, "Must provide both a kind and name to delete")
return nil, nil
return nil, NotFoundError{Reason: "Must provide both a kind and name to delete"}
}
entry, err := structs.MakeConfigEntry(pathArgs[0], pathArgs[1])
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "%v", err)
return nil, nil
return nil, BadRequestError{Reason: err.Error()}
}
args.Entry = entry
// Parse enterprise meta.

View File

@ -149,7 +149,6 @@ func TestConfig_Delete(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -171,7 +170,7 @@ func TestConfig_Delete(t *testing.T) {
}
for _, req := range reqs {
out := false
require.NoError(a.RPC("ConfigEntry.Apply", &req, &out))
require.NoError(t, a.RPC("ConfigEntry.Apply", &req, &out))
}
// Delete an entry.
@ -179,7 +178,7 @@ func TestConfig_Delete(t *testing.T) {
req, _ := http.NewRequest("DELETE", "/v1/config/service-defaults/bar", nil)
resp := httptest.NewRecorder()
_, err := a.srv.Config(resp, req)
require.NoError(err)
require.NoError(t, err)
}
// Get the remaining entry.
{
@ -188,11 +187,11 @@ func TestConfig_Delete(t *testing.T) {
Datacenter: "dc1",
}
var out structs.IndexedConfigEntries
require.NoError(a.RPC("ConfigEntry.List", &args, &out))
require.Equal(structs.ServiceDefaults, out.Kind)
require.Len(out.Entries, 1)
require.NoError(t, a.RPC("ConfigEntry.List", &args, &out))
require.Equal(t, structs.ServiceDefaults, out.Kind)
require.Len(t, out.Entries, 1)
entry := out.Entries[0].(*structs.ServiceConfigEntry)
require.Equal(entry.Name, "foo")
require.Equal(t, entry.Name, "foo")
}
}
@ -202,8 +201,6 @@ func TestConfig_Delete_CAS(t *testing.T) {
}
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -214,20 +211,20 @@ func TestConfig_Delete_CAS(t *testing.T) {
Name: "foo",
}
var created bool
require.NoError(a.RPC("ConfigEntry.Apply", &structs.ConfigEntryRequest{
require.NoError(t, a.RPC("ConfigEntry.Apply", &structs.ConfigEntryRequest{
Datacenter: "dc1",
Entry: entry,
}, &created))
require.True(created)
require.True(t, created)
// Read it back to get its ModifyIndex.
var out structs.ConfigEntryResponse
require.NoError(a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
Datacenter: "dc1",
Kind: entry.Kind,
Name: entry.Name,
}, &out))
require.NotNil(out.Entry)
require.NotNil(t, out.Entry)
modifyIndex := out.Entry.GetRaftIndex().ModifyIndex
@ -238,20 +235,20 @@ func TestConfig_Delete_CAS(t *testing.T) {
nil,
)
rawRsp, err := a.srv.Config(httptest.NewRecorder(), req)
require.NoError(err)
require.NoError(t, err)
deleted, isBool := rawRsp.(bool)
require.True(isBool, "response should be a boolean")
require.False(deleted, "entry should not have been deleted")
require.True(t, isBool, "response should be a boolean")
require.False(t, deleted, "entry should not have been deleted")
// Verify it was not deleted.
var out structs.ConfigEntryResponse
require.NoError(a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
Datacenter: "dc1",
Kind: entry.Kind,
Name: entry.Name,
}, &out))
require.NotNil(out.Entry)
require.NotNil(t, out.Entry)
})
t.Run("attempt to delete with a valid index", func(t *testing.T) {
@ -261,20 +258,20 @@ func TestConfig_Delete_CAS(t *testing.T) {
nil,
)
rawRsp, err := a.srv.Config(httptest.NewRecorder(), req)
require.NoError(err)
require.NoError(t, err)
deleted, isBool := rawRsp.(bool)
require.True(isBool, "response should be a boolean")
require.True(deleted, "entry should have been deleted")
require.True(t, isBool, "response should be a boolean")
require.True(t, deleted, "entry should have been deleted")
// Verify it was deleted.
var out structs.ConfigEntryResponse
require.NoError(a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
require.NoError(t, a.RPC("ConfigEntry.Get", &structs.ConfigEntryQuery{
Datacenter: "dc1",
Kind: entry.Kind,
Name: entry.Name,
}, &out))
require.Nil(out.Entry)
require.Nil(t, out.Entry)
})
}
@ -285,7 +282,6 @@ func TestConfig_Apply(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -301,7 +297,7 @@ func TestConfig_Apply(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/config", body)
resp := httptest.NewRecorder()
_, err := a.srv.ConfigApply(resp, req)
require.NoError(err)
require.NoError(t, err)
if resp.Code != 200 {
t.Fatalf(resp.Body.String())
}
@ -314,10 +310,10 @@ func TestConfig_Apply(t *testing.T) {
Datacenter: "dc1",
}
var out structs.ConfigEntryResponse
require.NoError(a.RPC("ConfigEntry.Get", &args, &out))
require.NotNil(out.Entry)
require.NoError(t, a.RPC("ConfigEntry.Get", &args, &out))
require.NotNil(t, out.Entry)
entry := out.Entry.(*structs.ServiceConfigEntry)
require.Equal(entry.Name, "foo")
require.Equal(t, entry.Name, "foo")
}
}
@ -503,7 +499,6 @@ func TestConfig_Apply_CAS(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -519,7 +514,7 @@ func TestConfig_Apply_CAS(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/config", body)
resp := httptest.NewRecorder()
_, err := a.srv.ConfigApply(resp, req)
require.NoError(err)
require.NoError(t, err)
if resp.Code != 200 {
t.Fatalf(resp.Body.String())
}
@ -532,8 +527,8 @@ func TestConfig_Apply_CAS(t *testing.T) {
}
out := &structs.ConfigEntryResponse{}
require.NoError(a.RPC("ConfigEntry.Get", &args, out))
require.NotNil(out.Entry)
require.NoError(t, a.RPC("ConfigEntry.Get", &args, out))
require.NotNil(t, out.Entry)
entry := out.Entry.(*structs.ServiceConfigEntry)
body = bytes.NewBuffer([]byte(`
@ -546,11 +541,11 @@ func TestConfig_Apply_CAS(t *testing.T) {
req, _ = http.NewRequest("PUT", "/v1/config?cas=0", body)
resp = httptest.NewRecorder()
writtenRaw, err := a.srv.ConfigApply(resp, req)
require.NoError(err)
require.NoError(t, err)
written, ok := writtenRaw.(bool)
require.True(ok)
require.False(written)
require.EqualValues(200, resp.Code, resp.Body.String())
require.True(t, ok)
require.False(t, written)
require.EqualValues(t, 200, resp.Code, resp.Body.String())
body = bytes.NewBuffer([]byte(`
{
@ -562,11 +557,11 @@ func TestConfig_Apply_CAS(t *testing.T) {
req, _ = http.NewRequest("PUT", fmt.Sprintf("/v1/config?cas=%d", entry.GetRaftIndex().ModifyIndex), body)
resp = httptest.NewRecorder()
writtenRaw, err = a.srv.ConfigApply(resp, req)
require.NoError(err)
require.NoError(t, err)
written, ok = writtenRaw.(bool)
require.True(ok)
require.True(written)
require.EqualValues(200, resp.Code, resp.Body.String())
require.True(t, ok)
require.True(t, written)
require.EqualValues(t, 200, resp.Code, resp.Body.String())
// Get the entry remaining entry.
args = structs.ConfigEntryQuery{
@ -576,10 +571,10 @@ func TestConfig_Apply_CAS(t *testing.T) {
}
out = &structs.ConfigEntryResponse{}
require.NoError(a.RPC("ConfigEntry.Get", &args, out))
require.NotNil(out.Entry)
require.NoError(t, a.RPC("ConfigEntry.Get", &args, out))
require.NotNil(t, out.Entry)
newEntry := out.Entry.(*structs.ServiceConfigEntry)
require.NotEqual(entry.GetRaftIndex(), newEntry.GetRaftIndex())
require.NotEqual(t, entry.GetRaftIndex(), newEntry.GetRaftIndex())
}
func TestConfig_Apply_Decoding(t *testing.T) {

View File

@ -34,34 +34,13 @@ func (_m *MockProvider) ActiveIntermediate() (string, error) {
return r0, r1
}
// ActiveRoot provides a mock function with given fields:
func (_m *MockProvider) ActiveRoot() (string, error) {
ret := _m.Called()
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Cleanup provides a mock function with given fields: providerTypeChange, config
func (_m *MockProvider) Cleanup(providerTypeChange bool, config map[string]interface{}) error {
ret := _m.Called(providerTypeChange, config)
// Cleanup provides a mock function with given fields: providerTypeChange, otherConfig
func (_m *MockProvider) Cleanup(providerTypeChange bool, otherConfig map[string]interface{}) error {
ret := _m.Called(providerTypeChange, otherConfig)
var r0 error
if rf, ok := ret.Get(0).(func(bool, map[string]interface{}) error); ok {
r0 = rf(providerTypeChange, config)
r0 = rf(providerTypeChange, otherConfig)
} else {
r0 = ret.Error(0)
}
@ -147,17 +126,24 @@ func (_m *MockProvider) GenerateIntermediateCSR() (string, error) {
}
// GenerateRoot provides a mock function with given fields:
func (_m *MockProvider) GenerateRoot() error {
func (_m *MockProvider) GenerateRoot() (RootResult, error) {
ret := _m.Called()
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
var r0 RootResult
if rf, ok := ret.Get(0).(func() RootResult); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
r0 = ret.Get(0).(RootResult)
}
return r0
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// SetIntermediate provides a mock function with given fields: intermediatePEM, rootPEM

View File

@ -118,17 +118,18 @@ type Provider interface {
}
type PrimaryProvider interface {
// GenerateRoot causes the creation of a new root certificate for this provider.
// This can also be a no-op if a root certificate already exists for the given
// config. If IsPrimary is false, calling this method is an error.
GenerateRoot() error
// ActiveRoot returns the currently active root CA for this
// provider. This should be a parent of the certificate returned by
// ActiveIntermediate()
// GenerateRoot is called:
// * to initialize the CA system when a server is elected as a raft leader
// * when the CA configuration is updated in a way that might require
// generating a new root certificate.
//
// TODO: currently called from secondaries, but shouldn't be so is on PrimaryProvider
ActiveRoot() (string, error)
// In both cases GenerateRoot is always called on a newly created provider
// after calling Provider.Configure, and before any other calls to the
// provider.
//
// The provider should return an existing root certificate if one exists,
// otherwise it should generate a new root certificate and return it.
GenerateRoot() (RootResult, error)
// GenerateIntermediate returns a new intermediate signing cert and sets it to
// the active intermediate. If multiple intermediates are needed to complete
@ -181,6 +182,14 @@ type SecondaryProvider interface {
SetIntermediate(intermediatePEM, rootPEM string) error
}
// RootResult is the result returned by PrimaryProvider.GenerateRoot.
//
// TODO: rename this struct
type RootResult struct {
// PEM encoded certificate that will be used as the primary CA.
PEM string
}
// NeedsStop is an optional interface that allows a CA to define a function
// to be called when the CA instance is no longer in use. This is different
// from Cleanup(), as only the local provider instance is being shut down

View File

@ -134,12 +134,19 @@ func (a *AWSProvider) State() (map[string]string, error) {
}
// GenerateRoot implements Provider
func (a *AWSProvider) GenerateRoot() error {
func (a *AWSProvider) GenerateRoot() (RootResult, error) {
if !a.isPrimary {
return fmt.Errorf("provider is not the root certificate authority")
return RootResult{}, fmt.Errorf("provider is not the root certificate authority")
}
return a.ensureCA()
if err := a.ensureCA(); err != nil {
return RootResult{}, err
}
if a.rootPEM == "" {
return RootResult{}, fmt.Errorf("AWS CA provider not fully Initialized")
}
return RootResult{PEM: a.rootPEM}, nil
}
// ensureCA loads the CA resource to check it exists if configured by User or in
@ -489,19 +496,6 @@ func (a *AWSProvider) signCSR(csrPEM string, templateARN string, ttl time.Durati
})
}
// ActiveRoot implements Provider
func (a *AWSProvider) ActiveRoot() (string, error) {
err := a.ensureCA()
if err != nil {
return "", err
}
if a.rootPEM == "" {
return "", fmt.Errorf("Secondary AWS CA provider not fully Initialized")
}
return a.rootPEM, nil
}
// GenerateIntermediateCSR implements Provider
func (a *AWSProvider) GenerateIntermediateCSR() (string, error) {
if a.isPrimary {

View File

@ -38,7 +38,6 @@ func TestAWSBootstrapAndSignPrimary(t *testing.T) {
for _, tc := range KeyTestCases {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
cfg := map[string]interface{}{
"PrivateKeyType": tc.KeyType,
"PrivateKeyBits": tc.KeyBits,
@ -47,34 +46,31 @@ func TestAWSBootstrapAndSignPrimary(t *testing.T) {
provider := testAWSProvider(t, testProviderConfigPrimary(t, cfg))
defer provider.Cleanup(true, nil)
// Generate the root
require.NoError(provider.GenerateRoot())
// Fetch Active Root
rootPEM, err := provider.ActiveRoot()
require.NoError(err)
root, err := provider.GenerateRoot()
require.NoError(t, err)
rootPEM := root.PEM
// Generate Intermediate (not actually needed for this provider for now
// but this simulates the calls in Server.initializeRoot).
interPEM, err := provider.GenerateIntermediate()
require.NoError(err)
require.NoError(t, err)
// Should be the same for now
require.Equal(rootPEM, interPEM)
require.Equal(t, rootPEM, interPEM)
// Ensure they use the right key type
rootCert, err := connect.ParseCert(rootPEM)
require.NoError(err)
require.NoError(t, err)
keyType, keyBits, err := connect.KeyInfoFromCert(rootCert)
require.NoError(err)
require.Equal(tc.KeyType, keyType)
require.Equal(tc.KeyBits, keyBits)
require.NoError(t, err)
require.Equal(t, tc.KeyType, keyType)
require.Equal(t, tc.KeyBits, keyBits)
// Ensure that the root cert ttl is withing the configured value
// computation is similar to how we are passing the TTL thru the aws client
expectedTime := time.Now().AddDate(0, 0, int(8761*60*time.Minute/day)).UTC()
require.WithinDuration(expectedTime, rootCert.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
require.WithinDuration(t, expectedTime, rootCert.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
// Sign a leaf with it
testSignAndValidate(t, provider, rootPEM, nil)
@ -82,16 +78,12 @@ func TestAWSBootstrapAndSignPrimary(t *testing.T) {
}
t.Run("Test default root ttl for aws ca provider", func(t *testing.T) {
provider := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer provider.Cleanup(true, nil)
// Generate the root
require.NoError(t, provider.GenerateRoot())
// Fetch Active Root
rootPEM, err := provider.ActiveRoot()
root, err := provider.GenerateRoot()
require.NoError(t, err)
rootPEM := root.PEM
// Ensure they use the right key type
rootCert, err := connect.ParseCert(rootPEM)
@ -124,8 +116,9 @@ func TestAWSBootstrapAndSignSecondary(t *testing.T) {
p1 := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer p1.Cleanup(true, nil)
rootPEM, err := p1.ActiveRoot()
root, err := p1.GenerateRoot()
require.NoError(t, err)
rootPEM := root.PEM
p2 := testAWSProvider(t, testProviderConfigSecondary(t, nil))
defer p2.Cleanup(true, nil)
@ -152,8 +145,9 @@ func TestAWSBootstrapAndSignSecondary(t *testing.T) {
cfg1 := testProviderConfigPrimary(t, nil)
cfg1.State = p1State
p1 = testAWSProvider(t, cfg1)
newRootPEM, err := p1.ActiveRoot()
root, err := p1.GenerateRoot()
require.NoError(t, err)
newRootPEM := root.PEM
cfg2 := testProviderConfigPrimary(t, nil)
cfg2.State = p2State
@ -185,8 +179,9 @@ func TestAWSBootstrapAndSignSecondary(t *testing.T) {
"ExistingARN": p1State[AWSStateCAARNKey],
})
p1 = testAWSProvider(t, cfg1)
newRootPEM, err := p1.ActiveRoot()
root, err := p1.GenerateRoot()
require.NoError(t, err)
newRootPEM := root.PEM
cfg2 := testProviderConfigPrimary(t, map[string]interface{}{
"ExistingARN": p2State[AWSStateCAARNKey],
@ -223,8 +218,9 @@ func TestAWSBootstrapAndSignSecondary(t *testing.T) {
p2 = testAWSProvider(t, cfg2)
require.NoError(t, p2.SetIntermediate(newIntPEM, newRootPEM))
newRootPEM, err = p1.ActiveRoot()
root, err = p1.GenerateRoot()
require.NoError(t, err)
newRootPEM = root.PEM
newIntPEM, err = p2.ActiveIntermediate()
require.NoError(t, err)
@ -244,7 +240,8 @@ func TestAWSBootstrapAndSignSecondaryConsul(t *testing.T) {
p1 := TestConsulProvider(t, delegate)
cfg := testProviderConfig(conf)
require.NoError(t, p1.Configure(cfg))
require.NoError(t, p1.GenerateRoot())
_, err := p1.GenerateRoot()
require.NoError(t, err)
p2 := testAWSProvider(t, testProviderConfigSecondary(t, nil))
defer p2.Cleanup(true, nil)
@ -255,7 +252,9 @@ func TestAWSBootstrapAndSignSecondaryConsul(t *testing.T) {
t.Run("pri=aws,sec=consul", func(t *testing.T) {
p1 := testAWSProvider(t, testProviderConfigPrimary(t, nil))
defer p1.Cleanup(true, nil)
require.NoError(t, p1.GenerateRoot())
_, err := p1.GenerateRoot()
require.NoError(t, err)
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
@ -316,11 +315,13 @@ func TestAWSProvider_Cleanup(t *testing.T) {
}
requirePCADeleted := func(t *testing.T, provider *AWSProvider) {
t.Helper()
deleted, err := describeCA(t, provider)
require.True(t, err != nil || deleted, "The AWS PCA instance has not been deleted")
}
requirePCANotDeleted := func(t *testing.T, provider *AWSProvider) {
t.Helper()
deleted, err := describeCA(t, provider)
require.NoError(t, err)
require.False(t, deleted, "The AWS PCA instance should not have been deleted")

View File

@ -149,29 +149,18 @@ func (c *ConsulProvider) State() (map[string]string, error) {
return c.testState, nil
}
// ActiveRoot returns the active root CA certificate.
func (c *ConsulProvider) ActiveRoot() (string, error) {
// GenerateRoot initializes a new root certificate and private key if needed.
func (c *ConsulProvider) GenerateRoot() (RootResult, error) {
providerState, err := c.getState()
if err != nil {
return "", err
}
return providerState.RootCert, nil
}
// GenerateRoot initializes a new root certificate and private key
// if needed.
func (c *ConsulProvider) GenerateRoot() error {
providerState, err := c.getState()
if err != nil {
return err
return RootResult{}, err
}
if !c.isPrimary {
return fmt.Errorf("provider is not the root certificate authority")
return RootResult{}, fmt.Errorf("provider is not the root certificate authority")
}
if providerState.RootCert != "" {
return nil
return RootResult{PEM: providerState.RootCert}, nil
}
// Generate a private key if needed
@ -179,7 +168,7 @@ func (c *ConsulProvider) GenerateRoot() error {
if c.config.PrivateKey == "" {
_, pk, err := connect.GeneratePrivateKeyWithConfig(c.config.PrivateKeyType, c.config.PrivateKeyBits)
if err != nil {
return err
return RootResult{}, err
}
newState.PrivateKey = pk
} else {
@ -190,12 +179,12 @@ func (c *ConsulProvider) GenerateRoot() error {
if c.config.RootCert == "" {
nextSerial, err := c.incrementAndGetNextSerialNumber()
if err != nil {
return fmt.Errorf("error computing next serial number: %v", err)
return RootResult{}, fmt.Errorf("error computing next serial number: %v", err)
}
ca, err := c.generateCA(newState.PrivateKey, nextSerial, c.config.RootCertTTL)
if err != nil {
return fmt.Errorf("error generating CA: %v", err)
return RootResult{}, fmt.Errorf("error generating CA: %v", err)
}
newState.RootCert = ca
} else {
@ -208,10 +197,10 @@ func (c *ConsulProvider) GenerateRoot() error {
ProviderState: &newState,
}
if _, err := c.Delegate.ApplyCARequest(args); err != nil {
return err
return RootResult{}, err
}
return nil
return RootResult{PEM: newState.RootCert}, nil
}
// GenerateIntermediateCSR creates a private key and generates a CSR
@ -288,18 +277,15 @@ func (c *ConsulProvider) SetIntermediate(intermediatePEM, rootPEM string) error
return nil
}
// We aren't maintaining separate root/intermediate CAs for the builtin
// provider, so just return the root.
func (c *ConsulProvider) ActiveIntermediate() (string, error) {
if c.isPrimary {
return c.ActiveRoot()
}
providerState, err := c.getState()
if err != nil {
return "", err
}
if c.isPrimary {
return providerState.RootCert, nil
}
return providerState.IntermediateCert, nil
}

View File

@ -78,26 +78,24 @@ func requireNotEncoded(t *testing.T, v []byte) {
func TestConsulCAProvider_Bootstrap(t *testing.T) {
t.Parallel()
require := require.New(t)
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
provider := TestConsulProvider(t, delegate)
require.NoError(provider.Configure(testProviderConfig(conf)))
require.NoError(provider.GenerateRoot())
require.NoError(t, provider.Configure(testProviderConfig(conf)))
root, err := provider.ActiveRoot()
require.NoError(err)
root, err := provider.GenerateRoot()
require.NoError(t, err)
// Intermediate should be the same cert.
inter, err := provider.ActiveIntermediate()
require.NoError(err)
require.Equal(root, inter)
require.NoError(t, err)
require.Equal(t, root.PEM, inter)
// Should be a valid cert
parsed, err := connect.ParseCert(root)
require.NoError(err)
require.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", conf.ClusterID))
parsed, err := connect.ParseCert(root.PEM)
require.NoError(t, err)
require.Equal(t, parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", conf.ClusterID))
requireNotEncoded(t, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.AuthorityKeyId)
@ -105,16 +103,15 @@ func TestConsulCAProvider_Bootstrap(t *testing.T) {
// notice that we allow a margin of "error" of 10 minutes between the
// generateCA() creation and this check
defaultRootCertTTL, err := time.ParseDuration(structs.DefaultRootCertTTL)
require.NoError(err)
require.NoError(t, err)
expectedNotAfter := time.Now().Add(defaultRootCertTTL).UTC()
require.WithinDuration(expectedNotAfter, parsed.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
require.WithinDuration(t, expectedNotAfter, parsed.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
}
func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) {
t.Parallel()
// Make sure setting a custom private key/root cert works.
require := require.New(t)
rootCA := connect.TestCAWithTTL(t, nil, 5*time.Hour)
conf := testConsulCAConfig()
conf.Config = map[string]interface{}{
@ -124,24 +121,23 @@ func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) {
delegate := newMockDelegate(t, conf)
provider := TestConsulProvider(t, delegate)
require.NoError(provider.Configure(testProviderConfig(conf)))
require.NoError(provider.GenerateRoot())
require.NoError(t, provider.Configure(testProviderConfig(conf)))
root, err := provider.ActiveRoot()
require.NoError(err)
require.Equal(root, rootCA.RootCert)
root, err := provider.GenerateRoot()
require.NoError(t, err)
require.Equal(t, root.PEM, rootCA.RootCert)
// Should be a valid cert
parsed, err := connect.ParseCert(root)
require.NoError(err)
parsed, err := connect.ParseCert(root.PEM)
require.NoError(t, err)
// test that the default root cert ttl was not applied to the provided cert
defaultRootCertTTL, err := time.ParseDuration(structs.DefaultRootCertTTL)
require.NoError(err)
require.NoError(t, err)
defaultNotAfter := time.Now().Add(defaultRootCertTTL).UTC()
// we can't compare given the "delta" between the time the cert is generated
// and when we start the test; so just look at the years for now, given different years
require.NotEqualf(defaultNotAfter.Year(), parsed.NotAfter.Year(), "parsed cert ttl expected to be different from default root cert ttl")
require.NotEqualf(t, defaultNotAfter.Year(), parsed.NotAfter.Year(), "parsed cert ttl expected to be different from default root cert ttl")
}
func TestConsulCAProvider_SignLeaf(t *testing.T) {
@ -154,7 +150,6 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
for _, tc := range KeyTestCases {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
conf := testConsulCAConfig()
conf.Config["LeafCertTTL"] = "1h"
conf.Config["PrivateKeyType"] = tc.KeyType
@ -162,8 +157,9 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
delegate := newMockDelegate(t, conf)
provider := TestConsulProvider(t, delegate)
require.NoError(provider.Configure(testProviderConfig(conf)))
require.NoError(provider.GenerateRoot())
require.NoError(t, provider.Configure(testProviderConfig(conf)))
_, err := provider.GenerateRoot()
require.NoError(t, err)
spiffeService := &connect.SpiffeIDService{
Host: connect.TestClusterID + ".consul",
@ -177,26 +173,26 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
require.NoError(t, err)
cert, err := provider.Sign(csr)
require.NoError(err)
require.NoError(t, err)
requireTrailingNewline(t, cert)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(spiffeService.URI(), parsed.URIs[0])
require.Empty(parsed.Subject.CommonName)
require.Equal(uint64(3), parsed.SerialNumber.Uint64())
require.NoError(t, err)
require.Equal(t, spiffeService.URI(), parsed.URIs[0])
require.Empty(t, parsed.Subject.CommonName)
require.Equal(t, uint64(3), parsed.SerialNumber.Uint64())
subjectKeyID, err := connect.KeyId(csr.PublicKey)
require.NoError(err)
require.Equal(subjectKeyID, parsed.SubjectKeyId)
require.NoError(t, err)
require.Equal(t, subjectKeyID, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.AuthorityKeyId)
// Ensure the cert is valid now and expires within the correct limit.
now := time.Now()
require.True(parsed.NotAfter.Sub(now) < time.Hour)
require.True(parsed.NotBefore.Before(now))
require.True(t, parsed.NotAfter.Sub(now) < time.Hour)
require.True(t, parsed.NotBefore.Before(now))
}
// Generate a new cert for another service and make sure
@ -206,22 +202,22 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
require.NoError(t, err)
cert, err := provider.Sign(csr)
require.NoError(err)
require.NoError(t, err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(spiffeService.URI(), parsed.URIs[0])
require.Empty(parsed.Subject.CommonName)
require.Equal(uint64(4), parsed.SerialNumber.Uint64())
require.NoError(t, err)
require.Equal(t, spiffeService.URI(), parsed.URIs[0])
require.Empty(t, parsed.Subject.CommonName)
require.Equal(t, uint64(4), parsed.SerialNumber.Uint64())
requireNotEncoded(t, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.AuthorityKeyId)
// Ensure the cert is valid now and expires within the correct limit.
require.True(time.Until(parsed.NotAfter) < 3*24*time.Hour)
require.True(parsed.NotBefore.Before(time.Now()))
require.True(t, time.Until(parsed.NotAfter) < 3*24*time.Hour)
require.True(t, parsed.NotBefore.Before(time.Now()))
}
spiffeAgent := &connect.SpiffeIDAgent{
@ -234,23 +230,23 @@ func TestConsulCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeAgent)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
require.NoError(t, err)
cert, err := provider.Sign(csr)
require.NoError(err)
require.NoError(t, err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(spiffeAgent.URI(), parsed.URIs[0])
require.Empty(parsed.Subject.CommonName)
require.Equal(uint64(5), parsed.SerialNumber.Uint64())
require.NoError(t, err)
require.Equal(t, spiffeAgent.URI(), parsed.URIs[0])
require.Empty(t, parsed.Subject.CommonName)
require.Equal(t, uint64(5), parsed.SerialNumber.Uint64())
requireNotEncoded(t, parsed.SubjectKeyId)
requireNotEncoded(t, parsed.AuthorityKeyId)
// Ensure the cert is valid now and expires within the correct limit.
now := time.Now()
require.True(parsed.NotAfter.Sub(now) < time.Hour)
require.True(parsed.NotBefore.Before(now))
require.True(t, parsed.NotAfter.Sub(now) < time.Hour)
require.True(t, parsed.NotBefore.Before(now))
}
})
}
@ -268,15 +264,15 @@ func TestConsulCAProvider_CrossSignCA(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
conf1 := testConsulCAConfig()
delegate1 := newMockDelegate(t, conf1)
provider1 := TestConsulProvider(t, delegate1)
conf1.Config["PrivateKeyType"] = tc.SigningKeyType
conf1.Config["PrivateKeyBits"] = tc.SigningKeyBits
require.NoError(provider1.Configure(testProviderConfig(conf1)))
require.NoError(provider1.GenerateRoot())
require.NoError(t, provider1.Configure(testProviderConfig(conf1)))
_, err := provider1.GenerateRoot()
require.NoError(t, err)
conf2 := testConsulCAConfig()
conf2.CreateIndex = 10
@ -284,8 +280,9 @@ func TestConsulCAProvider_CrossSignCA(t *testing.T) {
provider2 := TestConsulProvider(t, delegate2)
conf2.Config["PrivateKeyType"] = tc.CSRKeyType
conf2.Config["PrivateKeyBits"] = tc.CSRKeyBits
require.NoError(provider2.Configure(testProviderConfig(conf2)))
require.NoError(provider2.GenerateRoot())
require.NoError(t, provider2.Configure(testProviderConfig(conf2)))
_, err = provider2.GenerateRoot()
require.NoError(t, err)
testCrossSignProviders(t, provider1, provider2)
})
@ -293,52 +290,52 @@ func TestConsulCAProvider_CrossSignCA(t *testing.T) {
}
func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
require := require.New(t)
// Get the root from the new provider to be cross-signed.
newRootPEM, err := provider2.ActiveRoot()
require.NoError(err)
newRoot, err := connect.ParseCert(newRootPEM)
require.NoError(err)
root, err := provider2.GenerateRoot()
require.NoError(t, err)
newRoot, err := connect.ParseCert(root.PEM)
require.NoError(t, err)
oldSubject := newRoot.Subject.CommonName
requireNotEncoded(t, newRoot.SubjectKeyId)
requireNotEncoded(t, newRoot.AuthorityKeyId)
newInterPEM, err := provider2.ActiveIntermediate()
require.NoError(err)
require.NoError(t, err)
newIntermediate, err := connect.ParseCert(newInterPEM)
require.NoError(err)
require.NoError(t, err)
requireNotEncoded(t, newIntermediate.SubjectKeyId)
requireNotEncoded(t, newIntermediate.AuthorityKeyId)
// Have provider1 cross sign our new root cert.
xcPEM, err := provider1.CrossSignCA(newRoot)
require.NoError(err)
require.NoError(t, err)
xc, err := connect.ParseCert(xcPEM)
require.NoError(err)
require.NoError(t, err)
requireNotEncoded(t, xc.SubjectKeyId)
requireNotEncoded(t, xc.AuthorityKeyId)
oldRootPEM, err := provider1.ActiveRoot()
require.NoError(err)
oldRoot, err := connect.ParseCert(oldRootPEM)
require.NoError(err)
p1Root, err := provider1.GenerateRoot()
require.NoError(t, err)
oldRoot, err := connect.ParseCert(p1Root.PEM)
require.NoError(t, err)
requireNotEncoded(t, oldRoot.SubjectKeyId)
requireNotEncoded(t, oldRoot.AuthorityKeyId)
// AuthorityKeyID should now be the signing root's, SubjectKeyId should be kept.
require.Equal(oldRoot.SubjectKeyId, xc.AuthorityKeyId,
require.Equal(t, oldRoot.SubjectKeyId, xc.AuthorityKeyId,
"newSKID=%x\nnewAKID=%x\noldSKID=%x\noldAKID=%x\nxcSKID=%x\nxcAKID=%x",
newRoot.SubjectKeyId, newRoot.AuthorityKeyId,
oldRoot.SubjectKeyId, oldRoot.AuthorityKeyId,
xc.SubjectKeyId, xc.AuthorityKeyId)
require.Equal(newRoot.SubjectKeyId, xc.SubjectKeyId)
require.Equal(t, newRoot.SubjectKeyId, xc.SubjectKeyId)
// Subject name should not have changed.
require.Equal(oldSubject, xc.Subject.CommonName)
require.Equal(t, oldSubject, xc.Subject.CommonName)
// Issuer should be the signing root.
require.Equal(oldRoot.Issuer.CommonName, xc.Issuer.CommonName)
require.Equal(t, oldRoot.Issuer.CommonName, xc.Issuer.CommonName)
// Get a leaf cert so we can verify against the cross-signed cert.
spiffeService := &connect.SpiffeIDService{
@ -350,13 +347,13 @@ func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
raw, _ := connect.TestCSR(t, spiffeService)
leafCsr, err := connect.ParseCSR(raw)
require.NoError(err)
require.NoError(t, err)
leafPEM, err := provider2.Sign(leafCsr)
require.NoError(err)
require.NoError(t, err)
cert, err := connect.ParseCert(leafPEM)
require.NoError(err)
require.NoError(t, err)
requireNotEncoded(t, cert.SubjectKeyId)
requireNotEncoded(t, cert.AuthorityKeyId)
@ -374,7 +371,7 @@ func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
Intermediates: intermediatePool,
Roots: rootPool,
})
require.NoError(err)
require.NoError(t, err)
}
}
@ -390,15 +387,15 @@ func TestConsulProvider_SignIntermediate(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
conf1 := testConsulCAConfig()
delegate1 := newMockDelegate(t, conf1)
provider1 := TestConsulProvider(t, delegate1)
conf1.Config["PrivateKeyType"] = tc.SigningKeyType
conf1.Config["PrivateKeyBits"] = tc.SigningKeyBits
require.NoError(provider1.Configure(testProviderConfig(conf1)))
require.NoError(provider1.GenerateRoot())
require.NoError(t, provider1.Configure(testProviderConfig(conf1)))
_, err := provider1.GenerateRoot()
require.NoError(t, err)
conf2 := testConsulCAConfig()
conf2.CreateIndex = 10
@ -409,7 +406,7 @@ func TestConsulProvider_SignIntermediate(t *testing.T) {
cfg := testProviderConfig(conf2)
cfg.IsPrimary = false
cfg.Datacenter = "dc2"
require.NoError(provider2.Configure(cfg))
require.NoError(t, provider2.Configure(cfg))
testSignIntermediateCrossDC(t, provider1, provider2)
})
@ -418,22 +415,22 @@ func TestConsulProvider_SignIntermediate(t *testing.T) {
}
func testSignIntermediateCrossDC(t *testing.T, provider1, provider2 Provider) {
require := require.New(t)
// Get the intermediate CSR from provider2.
csrPEM, err := provider2.GenerateIntermediateCSR()
require.NoError(err)
require.NoError(t, err)
csr, err := connect.ParseCSR(csrPEM)
require.NoError(err)
require.NoError(t, err)
// Sign the CSR with provider1.
intermediatePEM, err := provider1.SignIntermediate(csr)
require.NoError(err)
rootPEM, err := provider1.ActiveRoot()
require.NoError(err)
require.NoError(t, err)
root, err := provider1.GenerateRoot()
require.NoError(t, err)
rootPEM := root.PEM
// Give the new intermediate to provider2 to use.
require.NoError(provider2.SetIntermediate(intermediatePEM, rootPEM))
require.NoError(t, provider2.SetIntermediate(intermediatePEM, rootPEM))
// Have provider2 sign a leaf cert and make sure the chain is correct.
spiffeService := &connect.SpiffeIDService{
@ -445,13 +442,13 @@ func testSignIntermediateCrossDC(t *testing.T, provider1, provider2 Provider) {
raw, _ := connect.TestCSR(t, spiffeService)
leafCsr, err := connect.ParseCSR(raw)
require.NoError(err)
require.NoError(t, err)
leafPEM, err := provider2.Sign(leafCsr)
require.NoError(err)
require.NoError(t, err)
cert, err := connect.ParseCert(leafPEM)
require.NoError(err)
require.NoError(t, err)
requireNotEncoded(t, cert.SubjectKeyId)
requireNotEncoded(t, cert.AuthorityKeyId)
@ -466,7 +463,7 @@ func testSignIntermediateCrossDC(t *testing.T, provider1, provider2 Provider) {
Intermediates: intermediatePool,
Roots: rootPool,
})
require.NoError(err)
require.NoError(t, err)
}
func TestConsulCAProvider_MigrateOldID(t *testing.T) {
@ -503,7 +500,8 @@ func TestConsulCAProvider_MigrateOldID(t *testing.T) {
provider := TestConsulProvider(t, delegate)
require.NoError(t, provider.Configure(testProviderConfig(conf)))
require.NoError(t, provider.GenerateRoot())
_, err = provider.GenerateRoot()
require.NoError(t, err)
// After running Configure, the old ID entry should be gone.
_, providerState, err = delegate.state.CAProviderState(tc.oldID)

View File

@ -12,13 +12,13 @@ import (
"strings"
"time"
"github.com/hashicorp/consul/lib/decode"
"github.com/hashicorp/go-hclog"
vaultapi "github.com/hashicorp/vault/api"
"github.com/mitchellh/mapstructure"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib/decode"
)
const (
@ -220,19 +220,14 @@ func (v *VaultProvider) State() (map[string]string, error) {
return nil, nil
}
// ActiveRoot returns the active root CA certificate.
func (v *VaultProvider) ActiveRoot() (string, error) {
return v.getCA(v.config.RootPKIPath)
}
// GenerateRoot mounts and initializes a new root PKI backend if needed.
func (v *VaultProvider) GenerateRoot() error {
func (v *VaultProvider) GenerateRoot() (RootResult, error) {
if !v.isPrimary {
return fmt.Errorf("provider is not the root certificate authority")
return RootResult{}, fmt.Errorf("provider is not the root certificate authority")
}
// Set up the root PKI backend if necessary.
rootPEM, err := v.ActiveRoot()
rootPEM, err := v.getCA(v.config.RootPKIPath)
switch err {
case ErrBackendNotMounted:
err := v.client.Sys().Mount(v.config.RootPKIPath, &vaultapi.MountInput{
@ -247,14 +242,14 @@ func (v *VaultProvider) GenerateRoot() error {
},
})
if err != nil {
return err
return RootResult{}, err
}
fallthrough
case ErrBackendNotInitialized:
uid, err := connect.CompactUID()
if err != nil {
return err
return RootResult{}, err
}
_, err = v.client.Logical().Write(v.config.RootPKIPath+"root/generate/internal", map[string]interface{}{
"common_name": connect.CACN("vault", uid, v.clusterID, v.isPrimary),
@ -263,17 +258,25 @@ func (v *VaultProvider) GenerateRoot() error {
"key_bits": v.config.PrivateKeyBits,
})
if err != nil {
return err
return RootResult{}, err
}
// retrieve the newly generated cert so that we can return it
// TODO: is this already available from the Local().Write() above?
rootPEM, err = v.getCA(v.config.RootPKIPath)
if err != nil {
return RootResult{}, err
}
default:
if err != nil {
return err
return RootResult{}, err
}
if rootPEM != "" {
rootCert, err := connect.ParseCert(rootPEM)
if err != nil {
return err
return RootResult{}, err
}
// Vault PKI doesn't allow in-place cert/key regeneration. That
@ -285,18 +288,18 @@ func (v *VaultProvider) GenerateRoot() error {
// ForceWithoutCrossSigning option when changing key types.
foundKeyType, foundKeyBits, err := connect.KeyInfoFromCert(rootCert)
if err != nil {
return err
return RootResult{}, err
}
if v.config.PrivateKeyType != foundKeyType {
return fmt.Errorf("cannot update the PrivateKeyType field without choosing a new PKI mount for the root CA")
return RootResult{}, fmt.Errorf("cannot update the PrivateKeyType field without choosing a new PKI mount for the root CA")
}
if v.config.PrivateKeyBits != foundKeyBits {
return fmt.Errorf("cannot update the PrivateKeyBits field without choosing a new PKI mount for the root CA")
return RootResult{}, fmt.Errorf("cannot update the PrivateKeyBits field without choosing a new PKI mount for the root CA")
}
}
}
return nil
return RootResult{PEM: rootPEM}, nil
}
// GenerateIntermediateCSR creates a private key and generates a CSR
@ -396,17 +399,14 @@ func (v *VaultProvider) SetIntermediate(intermediatePEM, rootPEM string) error {
return fmt.Errorf("cannot set an intermediate using another root in the primary datacenter")
}
err := validateSetIntermediate(
intermediatePEM, rootPEM,
"", // we don't have access to the private key directly
v.spiffeID,
)
// the private key is in vault, so we can't use it in this validation
err := validateSetIntermediate(intermediatePEM, rootPEM, "", v.spiffeID)
if err != nil {
return err
}
_, err = v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/set-signed", map[string]interface{}{
"certificate": fmt.Sprintf("%s\n%s", intermediatePEM, rootPEM),
"certificate": intermediatePEM,
})
if err != nil {
return err
@ -574,7 +574,7 @@ func (v *VaultProvider) SignIntermediate(csr *x509.CertificateRequest) (string,
// CrossSignCA takes a CA certificate and cross-signs it to form a trust chain
// back to our active root.
func (v *VaultProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
rootPEM, err := v.ActiveRoot()
rootPEM, err := v.getCA(v.config.RootPKIPath)
if err != nil {
return "", err
}

View File

@ -116,13 +116,12 @@ func TestVaultCAProvider_VaultTLSConfig(t *testing.T) {
TLSSkipVerify: true,
}
tlsConfig := vaultTLSConfig(config)
require := require.New(t)
require.Equal(config.CAFile, tlsConfig.CACert)
require.Equal(config.CAPath, tlsConfig.CAPath)
require.Equal(config.CertFile, tlsConfig.ClientCert)
require.Equal(config.KeyFile, tlsConfig.ClientKey)
require.Equal(config.TLSServerName, tlsConfig.TLSServerName)
require.Equal(config.TLSSkipVerify, tlsConfig.Insecure)
require.Equal(t, config.CAFile, tlsConfig.CACert)
require.Equal(t, config.CAPath, tlsConfig.CAPath)
require.Equal(t, config.CertFile, tlsConfig.ClientCert)
require.Equal(t, config.KeyFile, tlsConfig.ClientKey)
require.Equal(t, config.TLSServerName, tlsConfig.TLSServerName)
require.Equal(t, config.TLSSkipVerify, tlsConfig.Insecure)
}
func TestVaultCAProvider_Configure(t *testing.T) {
@ -171,11 +170,10 @@ func TestVaultCAProvider_SecondaryActiveIntermediate(t *testing.T) {
provider, testVault := testVaultProviderWithConfig(t, false, nil)
defer testVault.Stop()
require := require.New(t)
cert, err := provider.ActiveIntermediate()
require.Empty(cert)
require.NoError(err)
require.Empty(t, cert)
require.NoError(t, err)
}
func TestVaultCAProvider_RenewToken(t *testing.T) {
@ -231,8 +229,6 @@ func TestVaultCAProvider_Bootstrap(t *testing.T) {
defer testvault2.Stop()
client2 := testvault2.client
require := require.New(t)
cases := []struct {
certFunc func() (string, error)
backendPath string
@ -242,7 +238,10 @@ func TestVaultCAProvider_Bootstrap(t *testing.T) {
expectedRootCertTTL string
}{
{
certFunc: providerWDefaultRootCertTtl.ActiveRoot,
certFunc: func() (string, error) {
root, err := providerWDefaultRootCertTtl.GenerateRoot()
return root.PEM, err
},
backendPath: "pki-root/",
rootCaCreation: true,
client: client1,
@ -264,28 +263,28 @@ func TestVaultCAProvider_Bootstrap(t *testing.T) {
provider := tc.provider
client := tc.client
cert, err := tc.certFunc()
require.NoError(err)
require.NoError(t, err)
req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem")
resp, err := client.RawRequest(req)
require.NoError(err)
require.NoError(t, err)
bytes, err := ioutil.ReadAll(resp.Body)
require.NoError(err)
require.Equal(cert, string(bytes)+"\n")
require.NoError(t, err)
require.Equal(t, cert, string(bytes)+"\n")
// Should be a valid CA cert
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.True(parsed.IsCA)
require.Len(parsed.URIs, 1)
require.Equal(fmt.Sprintf("spiffe://%s.consul", provider.clusterID), parsed.URIs[0].String())
require.NoError(t, err)
require.True(t, parsed.IsCA)
require.Len(t, parsed.URIs, 1)
require.Equal(t, fmt.Sprintf("spiffe://%s.consul", provider.clusterID), parsed.URIs[0].String())
// test that the root cert ttl as applied
if tc.rootCaCreation {
rootCertTTL, err := time.ParseDuration(tc.expectedRootCertTTL)
require.NoError(err)
require.NoError(t, err)
expectedNotAfter := time.Now().Add(rootCertTTL).UTC()
require.WithinDuration(expectedNotAfter, parsed.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
require.WithinDuration(t, expectedNotAfter, parsed.NotAfter, 10*time.Minute, "expected parsed cert ttl to be the same as the value configured")
}
}
}
@ -313,7 +312,6 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
for _, tc := range KeyTestCases {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
provider, testVault := testVaultProviderWithConfig(t, true, map[string]interface{}{
"LeafCertTTL": "1h",
"PrivateKeyType": tc.KeyType,
@ -328,12 +326,13 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
Service: "foo",
}
rootPEM, err := provider.ActiveRoot()
require.NoError(err)
root, err := provider.GenerateRoot()
require.NoError(t, err)
rootPEM := root.PEM
assertCorrectKeyType(t, tc.KeyType, rootPEM)
intPEM, err := provider.ActiveIntermediate()
require.NoError(err)
require.NoError(t, err)
assertCorrectKeyType(t, tc.KeyType, intPEM)
// Generate a leaf cert for the service.
@ -342,23 +341,23 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
require.NoError(t, err)
cert, err := provider.Sign(csr)
require.NoError(err)
require.NoError(t, err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(parsed.URIs[0], spiffeService.URI())
require.NoError(t, err)
require.Equal(t, parsed.URIs[0], spiffeService.URI())
firstSerial = parsed.SerialNumber.Uint64()
// Ensure the cert is valid now and expires within the correct limit.
now := time.Now()
require.True(parsed.NotAfter.Sub(now) < time.Hour)
require.True(parsed.NotBefore.Before(now))
require.True(t, parsed.NotAfter.Sub(now) < time.Hour)
require.True(t, parsed.NotBefore.Before(now))
// Make sure we can validate the cert as expected.
require.NoError(connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
require.NoError(t, connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
requireTrailingNewline(t, cert)
}
@ -369,22 +368,22 @@ func TestVaultCAProvider_SignLeaf(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
require.NoError(t, err)
cert, err := provider.Sign(csr)
require.NoError(err)
require.NoError(t, err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(parsed.URIs[0], spiffeService.URI())
require.NotEqual(firstSerial, parsed.SerialNumber.Uint64())
require.NoError(t, err)
require.Equal(t, parsed.URIs[0], spiffeService.URI())
require.NotEqual(t, firstSerial, parsed.SerialNumber.Uint64())
// Ensure the cert is valid now and expires within the correct limit.
require.True(time.Until(parsed.NotAfter) < time.Hour)
require.True(parsed.NotBefore.Before(time.Now()))
require.True(t, time.Until(parsed.NotAfter) < time.Hour)
require.True(t, parsed.NotBefore.Before(time.Now()))
// Make sure we can validate the cert as expected.
require.NoError(connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
require.NoError(t, connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
}
})
}
@ -399,7 +398,6 @@ func TestVaultCAProvider_CrossSignCA(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.Desc, func(t *testing.T) {
require := require.New(t)
if tc.SigningKeyType != tc.CSRKeyType {
// See https://github.com/hashicorp/vault/issues/7709
@ -413,12 +411,12 @@ func TestVaultCAProvider_CrossSignCA(t *testing.T) {
defer testVault1.Stop()
{
rootPEM, err := provider1.ActiveRoot()
require.NoError(err)
assertCorrectKeyType(t, tc.SigningKeyType, rootPEM)
root, err := provider1.GenerateRoot()
require.NoError(t, err)
assertCorrectKeyType(t, tc.SigningKeyType, root.PEM)
intPEM, err := provider1.ActiveIntermediate()
require.NoError(err)
require.NoError(t, err)
assertCorrectKeyType(t, tc.SigningKeyType, intPEM)
}
@ -430,12 +428,12 @@ func TestVaultCAProvider_CrossSignCA(t *testing.T) {
defer testVault2.Stop()
{
rootPEM, err := provider2.ActiveRoot()
require.NoError(err)
assertCorrectKeyType(t, tc.CSRKeyType, rootPEM)
root, err := provider2.GenerateRoot()
require.NoError(t, err)
assertCorrectKeyType(t, tc.CSRKeyType, root.PEM)
intPEM, err := provider2.ActiveIntermediate()
require.NoError(err)
require.NoError(t, err)
assertCorrectKeyType(t, tc.CSRKeyType, intPEM)
}
@ -498,7 +496,8 @@ func TestVaultProvider_SignIntermediateConsul(t *testing.T) {
delegate := newMockDelegate(t, conf)
provider1 := TestConsulProvider(t, delegate)
require.NoError(t, provider1.Configure(testProviderConfig(conf)))
require.NoError(t, provider1.GenerateRoot())
_, err := provider1.GenerateRoot()
require.NoError(t, err)
// Ensure that we don't configure vault to try and mint leafs that
// outlive their CA during the test (which hard fails in vault).
@ -792,8 +791,9 @@ func createVaultProvider(t *testing.T, isPrimary bool, addr, token string, rawCo
t.Cleanup(provider.Stop)
require.NoError(t, provider.Configure(cfg))
if isPrimary {
require.NoError(t, provider.GenerateRoot())
_, err := provider.GenerateIntermediate()
_, err := provider.GenerateRoot()
require.NoError(t, err)
_, err = provider.GenerateIntermediate()
require.NoError(t, err)
}

View File

@ -8,8 +8,9 @@ import (
"crypto/x509"
"encoding/pem"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/structs"
)
type KeyConfig struct {
@ -47,32 +48,30 @@ func makeConfig(kc KeyConfig) structs.CommonCAProviderConfig {
}
func testGenerateRSAKey(t *testing.T, bits int) {
r := require.New(t)
_, rsaBlock, err := GeneratePrivateKeyWithConfig("rsa", bits)
r.NoError(err)
r.Contains(rsaBlock, "RSA PRIVATE KEY")
require.NoError(t, err)
require.Contains(t, rsaBlock, "RSA PRIVATE KEY")
rsaBytes, _ := pem.Decode([]byte(rsaBlock))
r.NotNil(rsaBytes)
require.NotNil(t, rsaBytes)
rsaKey, err := x509.ParsePKCS1PrivateKey(rsaBytes.Bytes)
r.NoError(err)
r.NoError(rsaKey.Validate())
r.Equal(bits/8, rsaKey.Size()) // note: returned size is in bytes. 2048/8==256
require.NoError(t, err)
require.NoError(t, rsaKey.Validate())
require.Equal(t, bits/8, rsaKey.Size()) // note: returned size is in bytes. 2048/8==256
}
func testGenerateECDSAKey(t *testing.T, bits int) {
r := require.New(t)
_, pemBlock, err := GeneratePrivateKeyWithConfig("ec", bits)
r.NoError(err)
r.Contains(pemBlock, "EC PRIVATE KEY")
require.NoError(t, err)
require.Contains(t, pemBlock, "EC PRIVATE KEY")
block, _ := pem.Decode([]byte(pemBlock))
r.NotNil(block)
require.NotNil(t, block)
pk, err := x509.ParseECPrivateKey(block.Bytes)
r.NoError(err)
r.Equal(bits, pk.Curve.Params().BitSize)
require.NoError(t, err)
require.Equal(t, bits, pk.Curve.Params().BitSize)
}
// Tests to make sure we are able to generate every type of private key supported by the x509 lib.
@ -104,7 +103,7 @@ func TestValidateGoodConfigs(t *testing.T) {
config := makeConfig(params)
t.Run(fmt.Sprintf("TestValidateGoodConfigs-%s-%d", params.keyType, params.keyBits),
func(t *testing.T) {
require.New(t).NoError(config.Validate(), "unexpected error: type=%s bits=%d",
require.NoError(t, config.Validate(), "unexpected error: type=%s bits=%d",
params.keyType, params.keyBits)
})
@ -117,7 +116,7 @@ func TestValidateBadConfigs(t *testing.T) {
for _, params := range badParams {
config := makeConfig(params)
t.Run(fmt.Sprintf("TestValidateBadConfigs-%s-%d", params.keyType, params.keyBits), func(t *testing.T) {
require.New(t).Error(config.Validate(), "expected error: type=%s bits=%d",
require.Error(t, config.Validate(), "expected error: type=%s bits=%d",
params.keyType, params.keyBits)
})
}
@ -131,7 +130,6 @@ func TestSignatureMismatches(t *testing.T) {
}
t.Parallel()
r := require.New(t)
for _, p1 := range goodParams {
for _, p2 := range goodParams {
if p1 == p2 {
@ -139,14 +137,14 @@ func TestSignatureMismatches(t *testing.T) {
}
t.Run(fmt.Sprintf("TestMismatches-%s%d-%s%d", p1.keyType, p1.keyBits, p2.keyType, p2.keyBits), func(t *testing.T) {
ca := TestCAWithKeyType(t, nil, p1.keyType, p1.keyBits)
r.Equal(p1.keyType, ca.PrivateKeyType)
r.Equal(p1.keyBits, ca.PrivateKeyBits)
require.Equal(t, p1.keyType, ca.PrivateKeyType)
require.Equal(t, p1.keyBits, ca.PrivateKeyBits)
certPEM, keyPEM, err := testLeaf(t, "foobar.service.consul", "default", ca, p2.keyType, p2.keyBits)
r.NoError(err)
require.NoError(t, err)
_, err = ParseCert(certPEM)
r.NoError(err)
require.NoError(t, err)
_, err = ParseSigner(keyPEM)
r.NoError(err)
require.NoError(t, err)
})
}
}

View File

@ -29,20 +29,18 @@ func skipIfMissingOpenSSL(t *testing.T) {
func testCAAndLeaf(t *testing.T, keyType string, keyBits int) {
skipIfMissingOpenSSL(t)
require := require.New(t)
// Create the certs
ca := TestCAWithKeyType(t, nil, keyType, keyBits)
leaf, _ := TestLeaf(t, "web", ca)
// Create a temporary directory for storing the certs
td, err := ioutil.TempDir("", "consul")
require.NoError(err)
require.NoError(t, err)
defer os.RemoveAll(td)
// Write the cert
require.NoError(ioutil.WriteFile(filepath.Join(td, "ca.pem"), []byte(ca.RootCert), 0644))
require.NoError(ioutil.WriteFile(filepath.Join(td, "leaf.pem"), []byte(leaf[:]), 0644))
require.NoError(t, ioutil.WriteFile(filepath.Join(td, "ca.pem"), []byte(ca.RootCert), 0644))
require.NoError(t, ioutil.WriteFile(filepath.Join(td, "leaf.pem"), []byte(leaf[:]), 0644))
// Use OpenSSL to verify so we have an external, known-working process
// that can verify this outside of our own implementations.
@ -54,15 +52,13 @@ func testCAAndLeaf(t *testing.T, keyType string, keyBits int) {
if ee, ok := err.(*exec.ExitError); ok {
t.Log("STDERR:", string(ee.Stderr))
}
require.NoError(err)
require.NoError(t, err)
}
// Test cross-signing.
func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) {
skipIfMissingOpenSSL(t)
assert := assert.New(t)
// Create the certs
ca1 := TestCAWithKeyType(t, nil, keyType, keyBits)
ca2 := TestCAWithKeyType(t, ca1, keyType, keyBits)
@ -71,16 +67,16 @@ func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) {
// Create a temporary directory for storing the certs
td, err := ioutil.TempDir("", "consul")
assert.Nil(err)
assert.Nil(t, err)
defer os.RemoveAll(td)
// Write the cert
xcbundle := []byte(ca1.RootCert)
xcbundle = append(xcbundle, '\n')
xcbundle = append(xcbundle, []byte(ca2.SigningCert)...)
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), xcbundle, 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf1.pem"), []byte(leaf1), 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf2.pem"), []byte(leaf2), 0644))
assert.Nil(t, ioutil.WriteFile(filepath.Join(td, "ca.pem"), xcbundle, 0644))
assert.Nil(t, ioutil.WriteFile(filepath.Join(td, "leaf1.pem"), []byte(leaf1), 0644))
assert.Nil(t, ioutil.WriteFile(filepath.Join(td, "leaf2.pem"), []byte(leaf2), 0644))
// OpenSSL verify the cross-signed leaf (leaf2)
{
@ -89,7 +85,7 @@ func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) {
cmd.Dir = td
output, err := cmd.Output()
t.Log(string(output))
assert.Nil(err)
assert.Nil(t, err)
}
// OpenSSL verify the old leaf (leaf1)
@ -99,7 +95,7 @@ func testCAAndLeaf_xc(t *testing.T, keyType string, keyBits int) {
cmd.Dir = td
output, err := cmd.Output()
t.Log(string(output))
assert.Nil(err)
assert.Nil(t, err)
}
}

View File

@ -43,7 +43,6 @@ func TestConnectCARoots_list(t *testing.T) {
t.Parallel()
assertion := assert.New(t)
a := NewTestAgent(t, "")
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
@ -56,16 +55,16 @@ func TestConnectCARoots_list(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCARoots(resp, req)
assertion.NoError(err)
assert.NoError(t, err)
value := obj.(structs.IndexedCARoots)
assertion.Equal(value.ActiveRootID, ca2.ID)
assertion.Len(value.Roots, 2)
assert.Equal(t, value.ActiveRootID, ca2.ID)
assert.Len(t, value.Roots, 2)
// We should never have the secret information
for _, r := range value.Roots {
assertion.Equal("", r.SigningCert)
assertion.Equal("", r.SigningKey)
assert.Equal(t, "", r.SigningCert)
assert.Equal(t, "", r.SigningKey)
}
}

View File

@ -34,10 +34,6 @@ var ACLSummaries = []prometheus.SummaryDefinition{
Name: []string{"acl", "ResolveToken"},
Help: "This measures the time it takes to resolve an ACL token.",
},
{
Name: []string{"acl", "ResolveTokenToIdentity"},
Help: "This measures the time it takes to resolve an ACL token to an Identity.",
},
}
// These must be kept in sync with the constants in command/agent/acl.go.
@ -133,11 +129,12 @@ func tokenSecretCacheID(token string) string {
return "token-secret:" + token
}
type ACLResolverDelegate interface {
type ACLResolverBackend interface {
ACLDatacenter() string
ResolveIdentityFromToken(token string) (bool, structs.ACLIdentity, error)
ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, error)
ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error)
// TODO: separate methods for each RPC call (there are 4)
RPC(method string, args interface{}, reply interface{}) error
EnterpriseACLResolverDelegate
}
@ -160,8 +157,9 @@ type ACLResolverConfig struct {
// CacheConfig is a pass through configuration for ACL cache limits
CacheConfig *structs.ACLCachesConfig
// Delegate that implements some helper functionality that is server/client specific
Delegate ACLResolverDelegate
// Backend is used to retrieve data from the state store, or perform RPCs
// to fetch data from other Datacenters.
Backend ACLResolverBackend
// DisableDuration is the length of time to leave ACLs disabled when an RPC
// request to a server indicates that the ACL system is disabled. If set to
@ -219,9 +217,9 @@ type ACLResolverSettings struct {
// ACLResolver is the type to handle all your token and policy resolution needs.
//
// Supports:
// - Resolving tokens locally via the ACLResolverDelegate
// - Resolving policies locally via the ACLResolverDelegate
// - Resolving roles locally via the ACLResolverDelegate
// - Resolving tokens locally via the ACLResolverBackend
// - Resolving policies locally via the ACLResolverBackend
// - Resolving roles locally via the ACLResolverBackend
// - Resolving legacy tokens remotely via an ACL.GetPolicy RPC
// - Resolving tokens remotely via an ACL.TokenRead RPC
// - Resolving policies remotely via an ACL.PolicyResolve RPC
@ -245,8 +243,8 @@ type ACLResolver struct {
config ACLResolverSettings
logger hclog.Logger
delegate ACLResolverDelegate
aclConf *acl.Config
backend ACLResolverBackend
aclConf *acl.Config
tokens *token.Store
@ -263,19 +261,19 @@ type ACLResolver struct {
// disabledLock synchronizes access to disabledUntil
disabledLock sync.RWMutex
agentMasterAuthz acl.Authorizer
agentRecoveryAuthz acl.Authorizer
}
func agentMasterAuthorizer(nodeName string, entMeta *structs.EnterpriseMeta, aclConf *acl.Config) (acl.Authorizer, error) {
func agentRecoveryAuthorizer(nodeName string, entMeta *structs.EnterpriseMeta, aclConf *acl.Config) (acl.Authorizer, error) {
var conf acl.Config
if aclConf != nil {
conf = *aclConf
}
setEnterpriseConf(entMeta, &conf)
// Build a policy for the agent master token.
// Build a policy for the agent recovery token.
//
// The builtin agent master policy allows reading any node information
// The builtin agent recovery policy allows reading any node information
// and allows writes to the agent with the node name of the running agent
// only. This used to allow a prefix match on agent names but that seems
// entirely unnecessary so it is now using an exact match.
@ -298,8 +296,8 @@ func NewACLResolver(config *ACLResolverConfig) (*ACLResolver, error) {
if config == nil {
return nil, fmt.Errorf("ACL Resolver must be initialized with a config")
}
if config.Delegate == nil {
return nil, fmt.Errorf("ACL Resolver must be initialized with a valid delegate")
if config.Backend == nil {
return nil, fmt.Errorf("ACL Resolver must be initialized with a valid backend")
}
if config.Logger == nil {
@ -323,21 +321,21 @@ func NewACLResolver(config *ACLResolverConfig) (*ACLResolver, error) {
return nil, fmt.Errorf("invalid ACL down policy %q", config.Config.ACLDownPolicy)
}
authz, err := agentMasterAuthorizer(config.Config.NodeName, &config.Config.EnterpriseMeta, config.ACLConfig)
authz, err := agentRecoveryAuthorizer(config.Config.NodeName, &config.Config.EnterpriseMeta, config.ACLConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize the agent master authorizer")
return nil, fmt.Errorf("failed to initialize the agent recovery authorizer")
}
return &ACLResolver{
config: config.Config,
logger: config.Logger.Named(logging.ACL),
delegate: config.Delegate,
aclConf: config.ACLConfig,
cache: cache,
disableDuration: config.DisableDuration,
down: down,
tokens: config.Tokens,
agentMasterAuthz: authz,
config: config.Config,
logger: config.Logger.Named(logging.ACL),
backend: config.Backend,
aclConf: config.ACLConfig,
cache: cache,
disableDuration: config.DisableDuration,
down: down,
tokens: config.Tokens,
agentRecoveryAuthz: authz,
}, nil
}
@ -349,7 +347,7 @@ func (r *ACLResolver) fetchAndCacheIdentityFromToken(token string, cached *struc
cacheID := tokenSecretCacheID(token)
req := structs.ACLTokenGetRequest{
Datacenter: r.delegate.ACLDatacenter(),
Datacenter: r.backend.ACLDatacenter(),
TokenID: token,
TokenIDType: structs.ACLTokenSecret,
QueryOptions: structs.QueryOptions{
@ -359,7 +357,7 @@ func (r *ACLResolver) fetchAndCacheIdentityFromToken(token string, cached *struc
}
var resp structs.ACLTokenResponse
err := r.delegate.RPC("ACL.TokenRead", &req, &resp)
err := r.backend.RPC("ACL.TokenRead", &req, &resp)
if err == nil {
if resp.Token == nil {
r.cache.PutIdentity(cacheID, nil)
@ -396,7 +394,7 @@ func (r *ACLResolver) fetchAndCacheIdentityFromToken(token string, cached *struc
// we initiate an RPC for the value.
func (r *ACLResolver) resolveIdentityFromToken(token string) (structs.ACLIdentity, error) {
// Attempt to resolve locally first (local results are not cached)
if done, identity, err := r.delegate.ResolveIdentityFromToken(token); done {
if done, identity, err := r.backend.ResolveIdentityFromToken(token); done {
return identity, err
}
@ -437,7 +435,7 @@ func (r *ACLResolver) resolveIdentityFromToken(token string) (structs.ACLIdentit
func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdentity, policyIDs []string, cached map[string]*structs.PolicyCacheEntry) (map[string]*structs.ACLPolicy, error) {
req := structs.ACLPolicyBatchGetRequest{
Datacenter: r.delegate.ACLDatacenter(),
Datacenter: r.backend.ACLDatacenter(),
PolicyIDs: policyIDs,
QueryOptions: structs.QueryOptions{
Token: identity.SecretToken(),
@ -446,7 +444,7 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent
}
var resp structs.ACLPolicyBatchResponse
err := r.delegate.RPC("ACL.PolicyResolve", &req, &resp)
err := r.backend.RPC("ACL.PolicyResolve", &req, &resp)
if err == nil {
out := make(map[string]*structs.ACLPolicy)
for _, policy := range resp.Policies {
@ -492,7 +490,7 @@ func (r *ACLResolver) fetchAndCachePoliciesForIdentity(identity structs.ACLIdent
func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity, roleIDs []string, cached map[string]*structs.RoleCacheEntry) (map[string]*structs.ACLRole, error) {
req := structs.ACLRoleBatchGetRequest{
Datacenter: r.delegate.ACLDatacenter(),
Datacenter: r.backend.ACLDatacenter(),
RoleIDs: roleIDs,
QueryOptions: structs.QueryOptions{
Token: identity.SecretToken(),
@ -501,7 +499,7 @@ func (r *ACLResolver) fetchAndCacheRolesForIdentity(identity structs.ACLIdentity
}
var resp structs.ACLRoleBatchResponse
err := r.delegate.RPC("ACL.RoleResolve", &req, &resp)
err := r.backend.RPC("ACL.RoleResolve", &req, &resp)
if err == nil {
out := make(map[string]*structs.ACLRole)
for _, role := range resp.Roles {
@ -774,7 +772,7 @@ func (r *ACLResolver) collectPoliciesForIdentity(identity structs.ACLIdentity, p
}
for _, policyID := range policyIDs {
if done, policy, err := r.delegate.ResolvePolicyFromID(policyID); done {
if done, policy, err := r.backend.ResolvePolicyFromID(policyID); done {
if err != nil && !acl.IsErrNotFound(err) {
return nil, err
}
@ -871,7 +869,7 @@ func (r *ACLResolver) collectRolesForIdentity(identity structs.ACLIdentity, role
expCacheMap := make(map[string]*structs.RoleCacheEntry)
for _, roleID := range roleIDs {
if done, role, err := r.delegate.ResolveRoleFromID(roleID); done {
if done, role, err := r.backend.ResolveRoleFromID(roleID); done {
if err != nil && !acl.IsErrNotFound(err) {
return nil, err
}
@ -1049,19 +1047,22 @@ func (r *ACLResolver) resolveLocallyManagedToken(token string) (structs.ACLIdent
}
if r.tokens.IsAgentRecoveryToken(token) {
return structs.NewAgentMasterTokenIdentity(r.config.NodeName, token), r.agentMasterAuthz, true
return structs.NewAgentRecoveryTokenIdentity(r.config.NodeName, token), r.agentRecoveryAuthz, true
}
return r.resolveLocallyManagedEnterpriseToken(token)
}
func (r *ACLResolver) ResolveTokenToIdentityAndAuthorizer(token string) (structs.ACLIdentity, acl.Authorizer, error) {
// ResolveToken to an acl.Authorizer and structs.ACLIdentity. The acl.Authorizer
// can be used to check permissions granted to the token, and the ACLIdentity
// describes the token and any defaults applied to it.
func (r *ACLResolver) ResolveToken(token string) (ACLResolveResult, error) {
if !r.ACLsEnabled() {
return nil, acl.ManageAll(), nil
return ACLResolveResult{Authorizer: acl.ManageAll()}, nil
}
if acl.RootAuthorizer(token) != nil {
return nil, nil, acl.ErrRootDenied
return ACLResolveResult{}, acl.ErrRootDenied
}
// handle the anonymous token
@ -1070,7 +1071,7 @@ func (r *ACLResolver) ResolveTokenToIdentityAndAuthorizer(token string) (structs
}
if ident, authz, ok := r.resolveLocallyManagedToken(token); ok {
return ident, authz, nil
return ACLResolveResult{Authorizer: authz, ACLIdentity: ident}, nil
}
defer metrics.MeasureSince([]string{"acl", "ResolveToken"}, time.Now())
@ -1080,10 +1081,11 @@ func (r *ACLResolver) ResolveTokenToIdentityAndAuthorizer(token string) (structs
r.handleACLDisabledError(err)
if IsACLRemoteError(err) {
r.logger.Error("Error resolving token", "error", err)
return &missingIdentity{reason: "primary-dc-down", token: token}, r.down, nil
ident := &missingIdentity{reason: "primary-dc-down", token: token}
return ACLResolveResult{Authorizer: r.down, ACLIdentity: ident}, nil
}
return nil, nil, err
return ACLResolveResult{}, err
}
// Build the Authorizer
@ -1096,7 +1098,7 @@ func (r *ACLResolver) ResolveTokenToIdentityAndAuthorizer(token string) (structs
authz, err := policies.Compile(r.cache, &conf)
if err != nil {
return nil, nil, err
return ACLResolveResult{}, err
}
chain = append(chain, authz)
@ -1104,42 +1106,32 @@ func (r *ACLResolver) ResolveTokenToIdentityAndAuthorizer(token string) (structs
if err != nil {
if IsACLRemoteError(err) {
r.logger.Error("Error resolving identity defaults", "error", err)
return identity, r.down, nil
return ACLResolveResult{Authorizer: r.down, ACLIdentity: identity}, nil
}
return nil, nil, err
return ACLResolveResult{}, err
} else if authz != nil {
chain = append(chain, authz)
}
chain = append(chain, acl.RootAuthorizer(r.config.ACLDefaultPolicy))
return identity, acl.NewChainedAuthorizer(chain), nil
return ACLResolveResult{Authorizer: acl.NewChainedAuthorizer(chain), ACLIdentity: identity}, nil
}
// TODO: rename to AccessorIDFromToken. This method is only used to retrieve the
// ACLIdentity.ID, so we don't need to return a full ACLIdentity. We could
// return a much smaller type (instad of just a string) to allow for changes
// in the future.
func (r *ACLResolver) ResolveTokenToIdentity(token string) (structs.ACLIdentity, error) {
if !r.ACLsEnabled() {
return nil, nil
type ACLResolveResult struct {
acl.Authorizer
// TODO: likely we can reduce this interface
ACLIdentity structs.ACLIdentity
}
func (a ACLResolveResult) AccessorID() string {
if a.ACLIdentity == nil {
return ""
}
return a.ACLIdentity.ID()
}
if acl.RootAuthorizer(token) != nil {
return nil, acl.ErrRootDenied
}
// handle the anonymous token
if token == "" {
token = anonymousToken
}
if ident, _, ok := r.resolveLocallyManagedToken(token); ok {
return ident, nil
}
defer metrics.MeasureSince([]string{"acl", "ResolveTokenToIdentity"}, time.Now())
return r.resolveIdentityFromToken(token)
func (a ACLResolveResult) Identity() structs.ACLIdentity {
return a.ACLIdentity
}
func (r *ACLResolver) ACLsEnabled() bool {
@ -1158,6 +1150,30 @@ func (r *ACLResolver) ACLsEnabled() bool {
return true
}
func (r *ACLResolver) ResolveTokenAndDefaultMeta(token string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (ACLResolveResult, error) {
result, err := r.ResolveToken(token)
if err != nil {
return ACLResolveResult{}, err
}
if entMeta == nil {
entMeta = &structs.EnterpriseMeta{}
}
// Default the EnterpriseMeta based on the Tokens meta or actual defaults
// in the case of unknown identity
if result.ACLIdentity != nil {
entMeta.Merge(result.ACLIdentity.EnterpriseMetadata())
} else {
entMeta.Merge(structs.DefaultEnterpriseMetaInDefaultPartition())
}
// Use the meta to fill in the ACL authorization context
entMeta.FillAuthzContext(authzContext)
return result, err
}
// aclFilter is used to filter results from our state store based on ACL rules
// configured for the provided token.
type aclFilter struct {
@ -1965,7 +1981,7 @@ func filterACLWithAuthorizer(logger hclog.Logger, authorizer acl.Authorizer, sub
// not authorized for read access will be removed from subj.
func filterACL(r *ACLResolver, token string, subj interface{}) error {
// Get the ACL from the token
_, authorizer, err := r.ResolveTokenToIdentityAndAuthorizer(token)
authorizer, err := r.ResolveToken(token)
if err != nil {
return err
}

View File

@ -1,11 +1,10 @@
package consul
import (
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
)
var clientACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{
var clientACLCacheConfig = &structs.ACLCachesConfig{
// The ACL cache configuration on client agents is more conservative than
// on the servers. It is assumed that individual client agents will have
// fewer distinct identities accessing the client than a server would
@ -23,55 +22,28 @@ var clientACLCacheConfig *structs.ACLCachesConfig = &structs.ACLCachesConfig{
Roles: 128,
}
func (c *Client) ACLDatacenter() string {
// For resolution running on clients, servers within the current datacenter
type clientACLResolverBackend struct {
// TODO: un-embed
*Client
}
func (c *clientACLResolverBackend) ACLDatacenter() string {
// For resolution running on clients servers within the current datacenter
// must be queried first to pick up local tokens.
return c.config.Datacenter
}
func (c *Client) ResolveIdentityFromToken(token string) (bool, structs.ACLIdentity, error) {
func (c *clientACLResolverBackend) ResolveIdentityFromToken(token string) (bool, structs.ACLIdentity, error) {
// clients do no local identity resolution at the moment
return false, nil, nil
}
func (c *Client) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, error) {
func (c *clientACLResolverBackend) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, error) {
// clients do no local policy resolution at the moment
return false, nil, nil
}
func (c *Client) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) {
func (c *clientACLResolverBackend) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) {
// clients do no local role resolution at the moment
return false, nil, nil
}
func (c *Client) ResolveTokenToIdentity(token string) (structs.ACLIdentity, error) {
// not using ResolveTokenToIdentityAndAuthorizer because in this case we don't
// need to resolve the roles, policies and namespace but just want the identity
// information such as accessor id.
return c.acls.ResolveTokenToIdentity(token)
}
// TODO: Server has an identical implementation, remove duplication
func (c *Client) ResolveTokenAndDefaultMeta(token string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (acl.Authorizer, error) {
identity, authz, err := c.acls.ResolveTokenToIdentityAndAuthorizer(token)
if err != nil {
return nil, err
}
if entMeta == nil {
entMeta = &structs.EnterpriseMeta{}
}
// Default the EnterpriseMeta based on the Tokens meta or actual defaults
// in the case of unknown identity
if identity != nil {
entMeta.Merge(identity.EnterpriseMetadata())
} else {
entMeta.Merge(structs.DefaultEnterpriseMetaInDefaultPartition())
}
// Use the meta to fill in the ACL authorization context
entMeta.FillAuthzContext(authzContext)
return authz, err
}

View File

@ -724,7 +724,7 @@ func (a *ACL) tokenSetInternal(args *structs.ACLTokenSetRequest, reply *structs.
}
// Purge the identity from the cache to prevent using the previous definition of the identity
a.srv.acls.cache.RemoveIdentity(tokenSecretCacheID(token.SecretID))
a.srv.ACLResolver.cache.RemoveIdentity(tokenSecretCacheID(token.SecretID))
// Don't check expiration times here as it doesn't really matter.
if _, updatedToken, err := a.srv.fsm.State().ACLTokenGetByAccessor(nil, token.AccessorID, nil); err == nil && updatedToken != nil {
@ -876,7 +876,7 @@ func (a *ACL) TokenDelete(args *structs.ACLTokenDeleteRequest, reply *string) er
}
// Purge the identity from the cache to prevent using the previous definition of the identity
a.srv.acls.cache.RemoveIdentity(tokenSecretCacheID(token.SecretID))
a.srv.ACLResolver.cache.RemoveIdentity(tokenSecretCacheID(token.SecretID))
if reply != nil {
*reply = token.AccessorID
@ -1198,7 +1198,7 @@ func (a *ACL) PolicySet(args *structs.ACLPolicySetRequest, reply *structs.ACLPol
}
// Remove from the cache to prevent stale cache usage
a.srv.acls.cache.RemovePolicy(policy.ID)
a.srv.ACLResolver.cache.RemovePolicy(policy.ID)
if _, policy, err := a.srv.fsm.State().ACLPolicyGetByID(nil, policy.ID, &policy.EnterpriseMeta); err == nil && policy != nil {
*reply = *policy
@ -1257,7 +1257,7 @@ func (a *ACL) PolicyDelete(args *structs.ACLPolicyDeleteRequest, reply *string)
return fmt.Errorf("Failed to apply policy delete request: %v", err)
}
a.srv.acls.cache.RemovePolicy(policy.ID)
a.srv.ACLResolver.cache.RemovePolicy(policy.ID)
*reply = policy.Name
@ -1318,12 +1318,12 @@ func (a *ACL) PolicyResolve(args *structs.ACLPolicyBatchGetRequest, reply *struc
}
// get full list of policies for this token
identity, policies, err := a.srv.acls.resolveTokenToIdentityAndPolicies(args.Token)
identity, policies, err := a.srv.ACLResolver.resolveTokenToIdentityAndPolicies(args.Token)
if err != nil {
return err
}
entIdentity, entPolicies, err := a.srv.acls.resolveEnterpriseIdentityAndPolicies(identity)
entIdentity, entPolicies, err := a.srv.ACLResolver.resolveEnterpriseIdentityAndPolicies(identity)
if err != nil {
return err
}
@ -1609,7 +1609,7 @@ func (a *ACL) RoleSet(args *structs.ACLRoleSetRequest, reply *structs.ACLRole) e
}
// Remove from the cache to prevent stale cache usage
a.srv.acls.cache.RemoveRole(role.ID)
a.srv.ACLResolver.cache.RemoveRole(role.ID)
if _, role, err := a.srv.fsm.State().ACLRoleGetByID(nil, role.ID, &role.EnterpriseMeta); err == nil && role != nil {
*reply = *role
@ -1664,7 +1664,7 @@ func (a *ACL) RoleDelete(args *structs.ACLRoleDeleteRequest, reply *string) erro
return fmt.Errorf("Failed to apply role delete request: %v", err)
}
a.srv.acls.cache.RemoveRole(role.ID)
a.srv.ACLResolver.cache.RemoveRole(role.ID)
*reply = role.Name
@ -1719,12 +1719,12 @@ func (a *ACL) RoleResolve(args *structs.ACLRoleBatchGetRequest, reply *structs.A
}
// get full list of roles for this token
identity, roles, err := a.srv.acls.resolveTokenToIdentityAndRoles(args.Token)
identity, roles, err := a.srv.ACLResolver.resolveTokenToIdentityAndRoles(args.Token)
if err != nil {
return err
}
entIdentity, entRoles, err := a.srv.acls.resolveEnterpriseIdentityAndRoles(identity)
entIdentity, entRoles, err := a.srv.ACLResolver.resolveEnterpriseIdentityAndRoles(identity)
if err != nil {
return err
}
@ -2481,7 +2481,7 @@ func (a *ACL) Logout(args *structs.ACLLogoutRequest, reply *bool) error {
}
// Purge the identity from the cache to prevent using the previous definition of the identity
a.srv.acls.cache.RemoveIdentity(tokenSecretCacheID(token.SecretID))
a.srv.ACLResolver.cache.RemoveIdentity(tokenSecretCacheID(token.SecretID))
*reply = true

File diff suppressed because it is too large Load Diff

View File

@ -100,9 +100,14 @@ func (s *Server) LocalTokensEnabled() bool {
return true
}
func (s *Server) ACLDatacenter() string {
// For resolution running on servers the only option
// is to contact the configured ACL Datacenter
type serverACLResolverBackend struct {
// TODO: un-embed
*Server
}
func (s *serverACLResolverBackend) ACLDatacenter() string {
// For resolution running on servers the only option is to contact the
// configured ACL Datacenter
if s.config.PrimaryDatacenter != "" {
return s.config.PrimaryDatacenter
}
@ -114,6 +119,7 @@ func (s *Server) ACLDatacenter() string {
}
// ResolveIdentityFromToken retrieves a token's full identity given its secretID.
// TODO: why does some code call this directly instead of using ACLResolver.ResolveTokenToIdentity ?
func (s *Server) ResolveIdentityFromToken(token string) (bool, structs.ACLIdentity, error) {
// only allow remote RPC resolution when token replication is off and
// when not in the ACL datacenter
@ -131,7 +137,7 @@ func (s *Server) ResolveIdentityFromToken(token string) (bool, structs.ACLIdenti
return s.InPrimaryDatacenter() || index > 0, nil, acl.ErrNotFound
}
func (s *Server) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, error) {
func (s *serverACLResolverBackend) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy, error) {
index, policy, err := s.fsm.State().ACLPolicyGetByID(nil, policyID, nil)
if err != nil {
return true, nil, err
@ -145,7 +151,7 @@ func (s *Server) ResolvePolicyFromID(policyID string) (bool, *structs.ACLPolicy,
return s.InPrimaryDatacenter() || index > 0, policy, acl.ErrNotFound
}
func (s *Server) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) {
func (s *serverACLResolverBackend) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error) {
index, role, err := s.fsm.State().ACLRoleGetByID(nil, roleID, nil)
if err != nil {
return true, nil, err
@ -159,47 +165,10 @@ func (s *Server) ResolveRoleFromID(roleID string) (bool, *structs.ACLRole, error
return s.InPrimaryDatacenter() || index > 0, role, acl.ErrNotFound
}
func (s *Server) ResolveToken(token string) (acl.Authorizer, error) {
_, authz, err := s.acls.ResolveTokenToIdentityAndAuthorizer(token)
return authz, err
}
func (s *Server) ResolveTokenToIdentity(token string) (structs.ACLIdentity, error) {
// not using ResolveTokenToIdentityAndAuthorizer because in this case we don't
// need to resolve the roles, policies and namespace but just want the identity
// information such as accessor id.
return s.acls.ResolveTokenToIdentity(token)
}
// TODO: Client has an identical implementation, remove duplication
func (s *Server) ResolveTokenAndDefaultMeta(token string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (acl.Authorizer, error) {
identity, authz, err := s.acls.ResolveTokenToIdentityAndAuthorizer(token)
if err != nil {
return nil, err
}
if entMeta == nil {
entMeta = &structs.EnterpriseMeta{}
}
// Default the EnterpriseMeta based on the Tokens meta or actual defaults
// in the case of unknown identity
if identity != nil {
entMeta.Merge(identity.EnterpriseMetadata())
} else {
entMeta.Merge(structs.DefaultEnterpriseMetaInDefaultPartition())
}
// Use the meta to fill in the ACL authorization context
entMeta.FillAuthzContext(authzContext)
return authz, err
}
func (s *Server) filterACL(token string, subj interface{}) error {
return filterACL(s.acls, token, subj)
return filterACL(s.ACLResolver, token, subj)
}
func (s *Server) filterACLWithAuthorizer(authorizer acl.Authorizer, subj interface{}) {
filterACLWithAuthorizer(s.acls.logger, authorizer, subj)
filterACLWithAuthorizer(s.ACLResolver.logger, authorizer, subj)
}

File diff suppressed because it is too large Load Diff

View File

@ -107,7 +107,7 @@ func (s *Server) reapExpiredACLTokens(local, global bool) (int, error) {
// Purge the identities from the cache
for _, secretID := range secretIDs {
s.acls.cache.RemoveIdentity(tokenSecretCacheID(secretID))
s.ACLResolver.cache.RemoveIdentity(tokenSecretCacheID(secretID))
}
return len(req.TokenIDs), nil

View File

@ -58,7 +58,7 @@ func testACLTokenReap_Primary(t *testing.T, local, global bool) {
acl := ACL{srv: s1}
masterTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root")
initialManagementTokenAccessorID, err := retrieveTestTokenAccessorForSecret(codec, "root", "dc1", "root")
require.NoError(t, err)
listTokens := func() (localTokens, globalTokens []string, err error) {
@ -88,9 +88,9 @@ func testACLTokenReap_Primary(t *testing.T, local, global bool) {
t.Helper()
var expectLocal, expectGlobal []string
// The master token and the anonymous token are always going to be
// present and global.
expectGlobal = append(expectGlobal, masterTokenAccessorID)
// The initial management token and the anonymous token are always
// going to be present and global.
expectGlobal = append(expectGlobal, initialManagementTokenAccessorID)
expectGlobal = append(expectGlobal, structs.ACLTokenAnonymousID)
if local {

View File

@ -41,7 +41,7 @@ func TestAutoConfigBackend_CreateACLToken(t *testing.T) {
waitForLeaderEstablishment(t, srv)
r1, err := upsertTestRole(codec, TestDefaultMasterToken, "dc1")
r1, err := upsertTestRole(codec, TestDefaultInitialManagementToken, "dc1")
require.NoError(t, err)
t.Run("predefined-ids", func(t *testing.T) {

View File

@ -77,7 +77,6 @@ func TestAutopilot_CleanupDeadServer(t *testing.T) {
retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s, 5)) })
}
require := require.New(t)
testrpc.WaitForLeader(t, s1.RPC, "dc1")
leaderIndex := -1
for i, s := range servers {
@ -86,7 +85,7 @@ func TestAutopilot_CleanupDeadServer(t *testing.T) {
break
}
}
require.NotEqual(leaderIndex, -1)
require.NotEqual(t, leaderIndex, -1)
// Shutdown two non-leader servers
killed := make(map[string]struct{})

View File

@ -388,7 +388,6 @@ func TestCatalog_Register_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -399,7 +398,7 @@ func TestCatalog_Register_ConnectProxy(t *testing.T) {
// Register
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// List
req := structs.ServiceSpecificRequest{
@ -407,11 +406,11 @@ func TestCatalog_Register_ConnectProxy(t *testing.T) {
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
assert.Equal(t, structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
}
// Test an invalid ConnectProxy. We don't need to exhaustively test because
@ -423,7 +422,6 @@ func TestCatalog_Register_ConnectProxy_invalid(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -436,8 +434,8 @@ func TestCatalog_Register_ConnectProxy_invalid(t *testing.T) {
// Register
var out struct{}
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.NotNil(err)
assert.Contains(err.Error(), "DestinationServiceName")
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "DestinationServiceName")
}
// Test that write is required for the proxy destination to register a proxy.
@ -448,7 +446,6 @@ func TestCatalog_Register_ConnectProxy_ACLDestinationServiceName(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -479,7 +476,7 @@ node "foo" {
args.WriteRequest.Token = token
var out struct{}
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.True(acl.IsErrPermissionDenied(err))
assert.True(t, acl.IsErrPermissionDenied(err))
// Register should fail with the right destination but wrong name
args = structs.TestRegisterRequestProxy(t)
@ -487,14 +484,14 @@ node "foo" {
args.Service.Proxy.DestinationServiceName = "foo"
args.WriteRequest.Token = token
err = msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.True(acl.IsErrPermissionDenied(err))
assert.True(t, acl.IsErrPermissionDenied(err))
// Register should work with the right destination
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo"
args.Service.Proxy.DestinationServiceName = "foo"
args.WriteRequest.Token = token
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
}
func TestCatalog_Register_ConnectNative(t *testing.T) {
@ -504,7 +501,6 @@ func TestCatalog_Register_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -516,7 +512,7 @@ func TestCatalog_Register_ConnectNative(t *testing.T) {
// Register
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// List
req := structs.ServiceSpecificRequest{
@ -524,11 +520,11 @@ func TestCatalog_Register_ConnectNative(t *testing.T) {
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindTypical, v.ServiceKind)
assert.True(v.ServiceConnect.Native)
assert.Equal(t, structs.ServiceKindTypical, v.ServiceKind)
assert.True(t, v.ServiceConnect.Native)
}
func TestCatalog_Deregister(t *testing.T) {
@ -2149,7 +2145,6 @@ func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -2161,7 +2156,7 @@ func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
// Register the service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
@ -2170,11 +2165,11 @@ func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
TagFilter: false,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
assert.Equal(t, structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
}
func TestCatalog_ServiceNodes_Gateway(t *testing.T) {
@ -2304,7 +2299,6 @@ func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -2316,7 +2310,7 @@ func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
// Register the proxy service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// Register the service
{
@ -2324,7 +2318,7 @@ func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
args := structs.TestRegisterRequest(t)
args.Service.Service = dst
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
}
// List
@ -2334,22 +2328,22 @@ func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
ServiceName: args.Service.Proxy.DestinationServiceName,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
assert.Equal(t, structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.ServiceProxy.DestinationServiceName)
// List by non-Connect
req = structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Proxy.DestinationServiceName,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(t, resp.ServiceNodes, 1)
v = resp.ServiceNodes[0]
assert.Equal(args.Service.Proxy.DestinationServiceName, v.ServiceName)
assert.Equal("", v.ServiceProxy.DestinationServiceName)
assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.ServiceName)
assert.Equal(t, "", v.ServiceProxy.DestinationServiceName)
}
// Test that calling ServiceNodes with Connect: true will return
@ -2361,7 +2355,6 @@ func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -2374,7 +2367,7 @@ func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true
var out struct{}
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
@ -2383,20 +2376,20 @@ func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(resp.ServiceNodes, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
require.Equal(args.Service.Service, v.ServiceName)
require.Equal(t, args.Service.Service, v.ServiceName)
// List by non-Connect
req = structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(resp.ServiceNodes, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(t, resp.ServiceNodes, 1)
v = resp.ServiceNodes[0]
require.Equal(args.Service.Service, v.ServiceName)
require.Equal(t, args.Service.Service, v.ServiceName)
}
func TestCatalog_ListServiceNodes_ConnectProxy_ACL(t *testing.T) {
@ -2491,7 +2484,6 @@ func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -2504,7 +2496,7 @@ func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
@ -2513,10 +2505,10 @@ func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
TagFilter: false,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(t, resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(args.Service.Connect.Native, v.ServiceConnect.Native)
assert.Equal(t, args.Service.Connect.Native, v.ServiceConnect.Native)
}
func TestCatalog_NodeServices(t *testing.T) {
@ -2581,7 +2573,6 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -2593,7 +2584,7 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
// Register the service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.NodeSpecificRequest{
@ -2601,12 +2592,12 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
Node: args.Node,
}
var resp structs.IndexedNodeServices
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Len(resp.NodeServices.Services, 1)
assert.Len(t, resp.NodeServices.Services, 1)
v := resp.NodeServices.Services[args.Service.Service]
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
assert.Equal(args.Service.Proxy.DestinationServiceName, v.Proxy.DestinationServiceName)
assert.Equal(t, structs.ServiceKindConnectProxy, v.Kind)
assert.Equal(t, args.Service.Proxy.DestinationServiceName, v.Proxy.DestinationServiceName)
}
func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
@ -2616,7 +2607,6 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -2628,7 +2618,7 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
// Register the service
args := structs.TestRegisterRequest(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.NodeSpecificRequest{
@ -2636,11 +2626,11 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
Node: args.Node,
}
var resp structs.IndexedNodeServices
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Len(resp.NodeServices.Services, 1)
assert.Len(t, resp.NodeServices.Services, 1)
v := resp.NodeServices.Services[args.Service.Service]
assert.Equal(args.Service.Connect.Native, v.Connect.Native)
assert.Equal(t, args.Service.Connect.Native, v.Connect.Native)
}
// Used to check for a regression against a known bug
@ -2883,27 +2873,25 @@ func TestCatalog_NodeServices_ACL(t *testing.T) {
}
t.Run("deny", func(t *testing.T) {
require := require.New(t)
args.Token = token("deny")
var reply structs.IndexedNodeServices
err := msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &args, &reply)
require.NoError(err)
require.Nil(reply.NodeServices)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Nil(t, reply.NodeServices)
require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
})
t.Run("allow", func(t *testing.T) {
require := require.New(t)
args.Token = token("read")
var reply structs.IndexedNodeServices
err := msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &args, &reply)
require.NoError(err)
require.NotNil(reply.NodeServices)
require.False(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
require.NoError(t, err)
require.NotNil(t, reply.NodeServices)
require.False(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
})
}

View File

@ -56,7 +56,7 @@ type Client struct {
config *Config
// acls is used to resolve tokens to effective policies
acls *ACLResolver
*ACLResolver
// Connection pool to consul servers
connPool *pool.ConnPool
@ -119,7 +119,7 @@ func NewClient(config *Config, deps Deps) (*Client, error) {
aclConfig := ACLResolverConfig{
Config: config.ACLResolverSettings,
Delegate: c,
Backend: &clientACLResolverBackend{Client: c},
Logger: c.logger,
DisableDuration: aclClientDisabledTTL,
CacheConfig: clientACLCacheConfig,
@ -127,7 +127,7 @@ func NewClient(config *Config, deps Deps) (*Client, error) {
Tokens: deps.Tokens,
}
var err error
if c.acls, err = NewACLResolver(&aclConfig); err != nil {
if c.ACLResolver, err = NewACLResolver(&aclConfig); err != nil {
c.Shutdown()
return nil, fmt.Errorf("Failed to create ACL resolver: %v", err)
}
@ -172,7 +172,7 @@ func (c *Client) Shutdown() error {
// Close the connection pool
c.connPool.Shutdown()
c.acls.Close()
c.ACLResolver.Close()
return nil
}

View File

@ -150,8 +150,6 @@ func TestConfigEntry_Apply_ACLDeny(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -191,16 +189,16 @@ operator = "write"
Name: "foo",
}
err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &args, &out)
require.NoError(err)
require.NoError(t, err)
state := s1.fsm.State()
_, entry, err := state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err)
require.NoError(t, err)
serviceConf, ok := entry.(*structs.ServiceConfigEntry)
require.True(ok)
require.Equal("foo", serviceConf.Name)
require.Equal(structs.ServiceDefaults, serviceConf.Kind)
require.True(t, ok)
require.Equal(t, "foo", serviceConf.Name)
require.Equal(t, structs.ServiceDefaults, serviceConf.Kind)
// Try to update the global proxy args with the anonymous token - this should fail.
proxyArgs := structs.ConfigEntryRequest{
@ -219,7 +217,7 @@ operator = "write"
// Now with the privileged token.
proxyArgs.WriteRequest.Token = id
err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &proxyArgs, &out)
require.NoError(err)
require.NoError(t, err)
}
func TestConfigEntry_Get(t *testing.T) {
@ -229,8 +227,6 @@ func TestConfigEntry_Get(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -243,7 +239,7 @@ func TestConfigEntry_Get(t *testing.T) {
Name: "foo",
}
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, entry))
require.NoError(t, state.EnsureConfigEntry(1, entry))
args := structs.ConfigEntryQuery{
Kind: structs.ServiceDefaults,
@ -251,12 +247,12 @@ func TestConfigEntry_Get(t *testing.T) {
Datacenter: s1.config.Datacenter,
}
var out structs.ConfigEntryResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out))
serviceConf, ok := out.Entry.(*structs.ServiceConfigEntry)
require.True(ok)
require.Equal("foo", serviceConf.Name)
require.Equal(structs.ServiceDefaults, serviceConf.Kind)
require.True(t, ok)
require.Equal(t, "foo", serviceConf.Name)
require.Equal(t, structs.ServiceDefaults, serviceConf.Kind)
}
func TestConfigEntry_Get_ACLDeny(t *testing.T) {
@ -266,8 +262,6 @@ func TestConfigEntry_Get_ACLDeny(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -290,11 +284,11 @@ operator = "read"
// Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
}))
@ -314,12 +308,12 @@ operator = "read"
// The "foo" service should work.
args.Name = "foo"
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Get", &args, &out))
serviceConf, ok := out.Entry.(*structs.ServiceConfigEntry)
require.True(ok)
require.Equal("foo", serviceConf.Name)
require.Equal(structs.ServiceDefaults, serviceConf.Kind)
require.True(t, ok)
require.Equal(t, "foo", serviceConf.Name)
require.Equal(t, structs.ServiceDefaults, serviceConf.Kind)
}
func TestConfigEntry_List(t *testing.T) {
@ -329,8 +323,6 @@ func TestConfigEntry_List(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -351,19 +343,19 @@ func TestConfigEntry_List(t *testing.T) {
},
},
}
require.NoError(state.EnsureConfigEntry(1, expected.Entries[0]))
require.NoError(state.EnsureConfigEntry(2, expected.Entries[1]))
require.NoError(t, state.EnsureConfigEntry(1, expected.Entries[0]))
require.NoError(t, state.EnsureConfigEntry(2, expected.Entries[1]))
args := structs.ConfigEntryQuery{
Kind: structs.ServiceDefaults,
Datacenter: "dc1",
}
var out structs.IndexedConfigEntries
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out))
expected.Kind = structs.ServiceDefaults
expected.QueryMeta = out.QueryMeta
require.Equal(expected, out)
require.Equal(t, expected, out)
}
func TestConfigEntry_ListAll(t *testing.T) {
@ -466,8 +458,6 @@ func TestConfigEntry_List_ACLDeny(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -490,15 +480,15 @@ operator = "read"
// Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
}))
require.NoError(state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "db",
}))
@ -511,26 +501,26 @@ operator = "read"
}
var out structs.IndexedConfigEntries
err := msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out)
require.NoError(err)
require.NoError(t, err)
serviceConf, ok := out.Entries[0].(*structs.ServiceConfigEntry)
require.Len(out.Entries, 1)
require.True(ok)
require.Equal("foo", serviceConf.Name)
require.Equal(structs.ServiceDefaults, serviceConf.Kind)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.Len(t, out.Entries, 1)
require.True(t, ok)
require.Equal(t, "foo", serviceConf.Name)
require.Equal(t, structs.ServiceDefaults, serviceConf.Kind)
require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// Get the global proxy config.
args.Kind = structs.ProxyDefaults
err = msgpackrpc.CallWithCodec(codec, "ConfigEntry.List", &args, &out)
require.NoError(err)
require.NoError(t, err)
proxyConf, ok := out.Entries[0].(*structs.ProxyConfigEntry)
require.Len(out.Entries, 1)
require.True(ok)
require.Equal(structs.ProxyConfigGlobal, proxyConf.Name)
require.Equal(structs.ProxyDefaults, proxyConf.Kind)
require.False(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
require.Len(t, out.Entries, 1)
require.True(t, ok)
require.Equal(t, structs.ProxyConfigGlobal, proxyConf.Name)
require.Equal(t, structs.ProxyDefaults, proxyConf.Kind)
require.False(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
}
func TestConfigEntry_ListAll_ACLDeny(t *testing.T) {
@ -540,8 +530,6 @@ func TestConfigEntry_ListAll_ACLDeny(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -564,15 +552,15 @@ operator = "read"
// Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
}))
require.NoError(state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "db",
}))
@ -585,8 +573,8 @@ operator = "read"
}
var out structs.IndexedGenericConfigEntries
err := msgpackrpc.CallWithCodec(codec, "ConfigEntry.ListAll", &args, &out)
require.NoError(err)
require.Len(out.Entries, 2)
require.NoError(t, err)
require.Len(t, out.Entries, 2)
svcIndex := 0
proxyIndex := 1
if out.Entries[0].GetKind() == structs.ProxyDefaults {
@ -595,15 +583,15 @@ operator = "read"
}
svcConf, ok := out.Entries[svcIndex].(*structs.ServiceConfigEntry)
require.True(ok)
require.True(t, ok)
proxyConf, ok := out.Entries[proxyIndex].(*structs.ProxyConfigEntry)
require.True(ok)
require.True(t, ok)
require.Equal("foo", svcConf.Name)
require.Equal(structs.ServiceDefaults, svcConf.Kind)
require.Equal(structs.ProxyConfigGlobal, proxyConf.Name)
require.Equal(structs.ProxyDefaults, proxyConf.Kind)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.Equal(t, "foo", svcConf.Name)
require.Equal(t, structs.ServiceDefaults, svcConf.Kind)
require.Equal(t, structs.ProxyConfigGlobal, proxyConf.Name)
require.Equal(t, structs.ProxyDefaults, proxyConf.Kind)
require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
}
func TestConfigEntry_Delete(t *testing.T) {
@ -686,8 +674,6 @@ func TestConfigEntry_DeleteCAS(t *testing.T) {
}
t.Parallel()
require := require.New(t)
dir, s := testServer(t)
defer os.RemoveAll(dir)
defer s.Shutdown()
@ -703,11 +689,11 @@ func TestConfigEntry_DeleteCAS(t *testing.T) {
Name: "foo",
}
state := s.fsm.State()
require.NoError(state.EnsureConfigEntry(1, entry))
require.NoError(t, state.EnsureConfigEntry(1, entry))
// Verify it's there.
_, existing, err := state.ConfigEntry(nil, entry.Kind, entry.Name, nil)
require.NoError(err)
require.NoError(t, err)
// Send a delete CAS request with an invalid index.
args := structs.ConfigEntryRequest{
@ -718,24 +704,24 @@ func TestConfigEntry_DeleteCAS(t *testing.T) {
args.Entry.GetRaftIndex().ModifyIndex = existing.GetRaftIndex().ModifyIndex - 1
var rsp structs.ConfigEntryDeleteResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &rsp))
require.False(rsp.Deleted)
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &rsp))
require.False(t, rsp.Deleted)
// Verify the entry was not deleted.
_, existing, err = s.fsm.State().ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err)
require.NotNil(existing)
require.NoError(t, err)
require.NotNil(t, existing)
// Restore the valid index and try again.
args.Entry.GetRaftIndex().ModifyIndex = existing.GetRaftIndex().ModifyIndex
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &rsp))
require.True(rsp.Deleted)
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &rsp))
require.True(t, rsp.Deleted)
// Verify the entry was deleted.
_, existing, err = s.fsm.State().ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err)
require.Nil(existing)
require.NoError(t, err)
require.Nil(t, existing)
}
func TestConfigEntry_Delete_ACLDeny(t *testing.T) {
@ -745,8 +731,6 @@ func TestConfigEntry_Delete_ACLDeny(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -769,11 +753,11 @@ operator = "write"
// Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
}))
@ -796,12 +780,12 @@ operator = "write"
args.Entry = &structs.ServiceConfigEntry{
Name: "foo",
}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &out))
// Verify the entry was deleted.
_, existing, err := state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err)
require.Nil(existing)
require.NoError(t, err)
require.Nil(t, existing)
// Try to delete the global proxy config without a token.
args = structs.ConfigEntryRequest{
@ -817,11 +801,11 @@ operator = "write"
// Now delete with a valid token.
args.WriteRequest.Token = id
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Delete", &args, &out))
_, existing, err = state.ConfigEntry(nil, structs.ServiceDefaults, "foo", nil)
require.NoError(err)
require.Nil(existing)
require.NoError(t, err)
require.Nil(t, existing)
}
func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
@ -831,8 +815,6 @@ func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -841,19 +823,19 @@ func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
// Create a dummy proxy/service config in the state store to look up.
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
Config: map[string]interface{}{
"foo": 1,
},
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
Protocol: "http",
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "bar",
Protocol: "grpc",
@ -865,7 +847,7 @@ func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
Upstreams: []string{"bar", "baz"},
}
var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{
@ -880,14 +862,14 @@ func TestConfigEntry_ResolveServiceConfig(t *testing.T) {
// Don't know what this is deterministically
QueryMeta: out.QueryMeta,
}
require.Equal(expected, out)
require.Equal(t, expected, out)
_, entry, err := s1.fsm.State().ConfigEntry(nil, structs.ProxyDefaults, structs.ProxyConfigGlobal, nil)
require.NoError(err)
require.NotNil(entry)
require.NoError(t, err)
require.NotNil(t, entry)
proxyConf, ok := entry.(*structs.ProxyConfigEntry)
require.True(ok)
require.Equal(map[string]interface{}{"foo": 1}, proxyConf.Config)
require.True(t, ok)
require.Equal(t, map[string]interface{}{"foo": 1}, proxyConf.Config)
}
func TestConfigEntry_ResolveServiceConfig_TransparentProxy(t *testing.T) {
@ -1426,8 +1408,6 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -1443,19 +1423,19 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
// TestConfigEntry_ResolveServiceConfig_Upstreams_Blocking
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
Config: map[string]interface{}{
"global": 1,
},
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
Protocol: "grpc",
}))
require.NoError(state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "bar",
Protocol: "http",
@ -1465,7 +1445,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
{ // Verify that we get the results of proxy-defaults and service-defaults for 'foo'.
var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
&structs.ServiceConfigRequest{
Name: "foo",
Datacenter: "dc1",
@ -1480,7 +1460,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
},
QueryMeta: out.QueryMeta,
}
require.Equal(expected, out)
require.Equal(t, expected, out)
index = out.Index
}
@ -1490,7 +1470,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
start := time.Now()
go func() {
time.Sleep(100 * time.Millisecond)
require.NoError(state.DeleteConfigEntry(index+1,
require.NoError(t, state.DeleteConfigEntry(index+1,
structs.ServiceDefaults,
"foo",
nil,
@ -1499,7 +1479,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
// Re-run the query
var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
&structs.ServiceConfigRequest{
Name: "foo",
Datacenter: "dc1",
@ -1512,10 +1492,10 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
))
// Should block at least 100ms
require.True(time.Since(start) >= 100*time.Millisecond, "too fast")
require.True(t, time.Since(start) >= 100*time.Millisecond, "too fast")
// Check the indexes
require.Equal(out.Index, index+1)
require.Equal(t, out.Index, index+1)
expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{
@ -1523,14 +1503,14 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
},
QueryMeta: out.QueryMeta,
}
require.Equal(expected, out)
require.Equal(t, expected, out)
index = out.Index
}
{ // Verify that we get the results of proxy-defaults and service-defaults for 'bar'.
var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
&structs.ServiceConfigRequest{
Name: "bar",
Datacenter: "dc1",
@ -1545,7 +1525,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
},
QueryMeta: out.QueryMeta,
}
require.Equal(expected, out)
require.Equal(t, expected, out)
index = out.Index
}
@ -1555,7 +1535,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
start := time.Now()
go func() {
time.Sleep(100 * time.Millisecond)
require.NoError(state.DeleteConfigEntry(index+1,
require.NoError(t, state.DeleteConfigEntry(index+1,
structs.ProxyDefaults,
structs.ProxyConfigGlobal,
nil,
@ -1564,7 +1544,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
// Re-run the query
var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig",
&structs.ServiceConfigRequest{
Name: "bar",
Datacenter: "dc1",
@ -1577,10 +1557,10 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
))
// Should block at least 100ms
require.True(time.Since(start) >= 100*time.Millisecond, "too fast")
require.True(t, time.Since(start) >= 100*time.Millisecond, "too fast")
// Check the indexes
require.Equal(out.Index, index+1)
require.Equal(t, out.Index, index+1)
expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{
@ -1588,7 +1568,7 @@ func TestConfigEntry_ResolveServiceConfig_Blocking(t *testing.T) {
},
QueryMeta: out.QueryMeta,
}
require.Equal(expected, out)
require.Equal(t, expected, out)
}
}
@ -1798,8 +1778,6 @@ func TestConfigEntry_ResolveServiceConfig_UpstreamProxyDefaultsProtocol(t *testi
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -1808,26 +1786,26 @@ func TestConfigEntry_ResolveServiceConfig_UpstreamProxyDefaultsProtocol(t *testi
// Create a dummy proxy/service config in the state store to look up.
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
Config: map[string]interface{}{
"protocol": "http",
},
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "bar",
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "other",
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "alreadyprotocol",
Protocol: "grpc",
@ -1839,7 +1817,7 @@ func TestConfigEntry_ResolveServiceConfig_UpstreamProxyDefaultsProtocol(t *testi
Upstreams: []string{"bar", "other", "alreadyprotocol", "dne"},
}
var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{
@ -1862,7 +1840,7 @@ func TestConfigEntry_ResolveServiceConfig_UpstreamProxyDefaultsProtocol(t *testi
// Don't know what this is deterministically
QueryMeta: out.QueryMeta,
}
require.Equal(expected, out)
require.Equal(t, expected, out)
}
func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstreams(t *testing.T) {
@ -1872,8 +1850,6 @@ func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstre
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -1882,7 +1858,7 @@ func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstre
// Create a dummy proxy/service config in the state store to look up.
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
Config: map[string]interface{}{
@ -1896,7 +1872,7 @@ func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstre
Upstreams: []string{"bar"},
}
var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
expected := structs.ServiceConfigResponse{
ProxyConfig: map[string]interface{}{
@ -1910,7 +1886,7 @@ func TestConfigEntry_ResolveServiceConfig_ProxyDefaultsProtocol_UsedForAllUpstre
// Don't know what this is deterministically
QueryMeta: out.QueryMeta,
}
require.Equal(expected, out)
require.Equal(t, expected, out)
}
func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) {
@ -1920,8 +1896,6 @@ func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -1936,7 +1910,7 @@ func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) {
Upstreams: []string{"bar", "baz"},
}
var out structs.ServiceConfigResponse
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
expected := structs.ServiceConfigResponse{
ProxyConfig: nil,
@ -1944,7 +1918,7 @@ func TestConfigEntry_ResolveServiceConfigNoConfig(t *testing.T) {
// Don't know what this is deterministically
QueryMeta: out.QueryMeta,
}
require.Equal(expected, out)
require.Equal(t, expected, out)
}
func TestConfigEntry_ResolveServiceConfig_ACLDeny(t *testing.T) {
@ -1954,8 +1928,6 @@ func TestConfigEntry_ResolveServiceConfig_ACLDeny(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -1978,15 +1950,15 @@ operator = "write"
// Create some dummy service/proxy configs to be looked up.
state := s1.fsm.State()
require.NoError(state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
require.NoError(t, state.EnsureConfigEntry(1, &structs.ProxyConfigEntry{
Kind: structs.ProxyDefaults,
Name: structs.ProxyConfigGlobal,
}))
require.NoError(state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(2, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
}))
require.NoError(state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
require.NoError(t, state.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "db",
}))
@ -2005,7 +1977,7 @@ operator = "write"
// The "foo" service should work.
args.Name = "foo"
require.NoError(msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.ResolveServiceConfig", &args, &out))
}

View File

@ -38,8 +38,6 @@ func TestConnectCARoots(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -54,29 +52,29 @@ func TestConnectCARoots(t *testing.T) {
ca2 := connect.TestCA(t, nil)
ca2.Active = false
idx, _, err := state.CARoots(nil)
require.NoError(err)
require.NoError(t, err)
ok, err := state.CARootSetCAS(idx, idx, []*structs.CARoot{ca1, ca2})
assert.True(ok)
require.NoError(err)
assert.True(t, ok)
require.NoError(t, err)
_, caCfg, err := state.CAConfig(nil)
require.NoError(err)
require.NoError(t, err)
// Request
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.IndexedCARoots
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
// Verify
assert.Equal(ca1.ID, reply.ActiveRootID)
assert.Len(reply.Roots, 2)
assert.Equal(t, ca1.ID, reply.ActiveRootID)
assert.Len(t, reply.Roots, 2)
for _, r := range reply.Roots {
// These must never be set, for security
assert.Equal("", r.SigningCert)
assert.Equal("", r.SigningKey)
assert.Equal(t, "", r.SigningCert)
assert.Equal(t, "", r.SigningKey)
}
assert.Equal(fmt.Sprintf("%s.consul", caCfg.ClusterID), reply.TrustDomain)
assert.Equal(t, fmt.Sprintf("%s.consul", caCfg.ClusterID), reply.TrustDomain)
}
func TestConnectCAConfig_GetSet(t *testing.T) {
@ -86,7 +84,6 @@ func TestConnectCAConfig_GetSet(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -101,14 +98,14 @@ func TestConnectCAConfig_GetSet(t *testing.T) {
Datacenter: "dc1",
}
var reply structs.CAConfiguration
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
assert.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
assert.NoError(err)
assert.NoError(t, err)
expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config)
assert.NoError(err)
assert.Equal(reply.Provider, s1.config.CAConfig.Provider)
assert.Equal(actual, expected)
assert.NoError(t, err)
assert.Equal(t, reply.Provider, s1.config.CAConfig.Provider)
assert.Equal(t, actual, expected)
}
testState := map[string]string{"foo": "bar"}
@ -141,15 +138,15 @@ func TestConnectCAConfig_GetSet(t *testing.T) {
Datacenter: "dc1",
}
var reply structs.CAConfiguration
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
assert.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
assert.NoError(err)
assert.NoError(t, err)
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
assert.NoError(err)
assert.Equal(reply.Provider, newConfig.Provider)
assert.Equal(actual, expected)
assert.Equal(testState, reply.State)
assert.NoError(t, err)
assert.Equal(t, reply.Provider, newConfig.Provider)
assert.Equal(t, actual, expected)
assert.Equal(t, testState, reply.State)
}
}
@ -163,7 +160,7 @@ func TestConnectCAConfig_GetSet_ACLDeny(t *testing.T) {
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLInitialManagementToken = TestDefaultMasterToken
c.ACLInitialManagementToken = TestDefaultInitialManagementToken
c.ACLResolverSettings.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
@ -175,11 +172,11 @@ func TestConnectCAConfig_GetSet_ACLDeny(t *testing.T) {
testrpc.WaitForLeader(t, s1.RPC, "dc1")
opReadToken, err := upsertTestTokenWithPolicyRules(
codec, TestDefaultMasterToken, "dc1", `operator = "read"`)
codec, TestDefaultInitialManagementToken, "dc1", `operator = "read"`)
require.NoError(t, err)
opWriteToken, err := upsertTestTokenWithPolicyRules(
codec, TestDefaultMasterToken, "dc1", `operator = "write"`)
codec, TestDefaultInitialManagementToken, "dc1", `operator = "write"`)
require.NoError(t, err)
// Update a config value
@ -215,7 +212,7 @@ pY0heYeK9A6iOLrzqxSerkXXQyj5e9bE4VgUnxgPU6g=
args := &structs.CARequest{
Datacenter: "dc1",
Config: newConfig,
WriteRequest: structs.WriteRequest{Token: TestDefaultMasterToken},
WriteRequest: structs.WriteRequest{Token: TestDefaultInitialManagementToken},
}
var reply interface{}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
@ -254,7 +251,6 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
t.Parallel()
require := require.New(t)
// Setup a server with a built-in CA that as artificially disabled cross
// signing. This is simpler than running tests with external CA dependencies.
dir1, s1 := testServerWithConfig(t, func(c *Config) {
@ -272,8 +268,8 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
Datacenter: "dc1",
}
var rootList structs.IndexedCARoots
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1)
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(t, rootList.Roots, 1)
oldRoot := rootList.Roots[0]
// Get the starting config
@ -282,20 +278,20 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
Datacenter: "dc1",
}
var reply structs.CAConfiguration
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
require.NoError(err)
require.NoError(t, err)
expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config)
require.NoError(err)
require.Equal(reply.Provider, s1.config.CAConfig.Provider)
require.Equal(actual, expected)
require.NoError(t, err)
require.Equal(t, reply.Provider, s1.config.CAConfig.Provider)
require.Equal(t, actual, expected)
}
// Update to a new CA with different key. This should fail since the existing
// CA doesn't support cross signing so can't rotate safely.
_, newKey, err := connect.GeneratePrivateKey()
require.NoError(err)
require.NoError(t, err)
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
@ -309,7 +305,7 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
}
var reply interface{}
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)
require.EqualError(err, "The current CA Provider does not support cross-signing. "+
require.EqualError(t, err, "The current CA Provider does not support cross-signing. "+
"You can try again with ForceWithoutCrossSigningSet but this may cause disruption"+
" - see documentation for more.")
}
@ -323,7 +319,7 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
}
var reply interface{}
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply)
require.NoError(err)
require.NoError(t, err)
}
// Make sure the new root has been added but with no cross-signed intermediate
@ -332,23 +328,23 @@ func TestConnectCAConfig_GetSetForceNoCrossSigning(t *testing.T) {
Datacenter: "dc1",
}
var reply structs.IndexedCARoots
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
require.Len(reply.Roots, 2)
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
require.Len(t, reply.Roots, 2)
for _, r := range reply.Roots {
if r.ID == oldRoot.ID {
// The old root should no longer be marked as the active root,
// and none of its other fields should have changed.
require.False(r.Active)
require.Equal(r.Name, oldRoot.Name)
require.Equal(r.RootCert, oldRoot.RootCert)
require.Equal(r.SigningCert, oldRoot.SigningCert)
require.Equal(r.IntermediateCerts, oldRoot.IntermediateCerts)
require.False(t, r.Active)
require.Equal(t, r.Name, oldRoot.Name)
require.Equal(t, r.RootCert, oldRoot.RootCert)
require.Equal(t, r.SigningCert, oldRoot.SigningCert)
require.Equal(t, r.IntermediateCerts, oldRoot.IntermediateCerts)
} else {
// The new root should NOT have a valid cross-signed cert from the old
// root as an intermediate.
require.True(r.Active)
require.Empty(r.IntermediateCerts)
require.True(t, r.Active)
require.Empty(t, r.IntermediateCerts)
}
}
}
@ -664,9 +660,6 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
// Initialize primary as the primary DC
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "primary"
@ -693,8 +686,8 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
// Capture the current root
rootList, activeRoot, err := getTestRoots(s1, "primary")
require.NoError(err)
require.Len(rootList.Roots, 1)
require.NoError(t, err)
require.Len(t, rootList.Roots, 1)
rootCert := activeRoot
testrpc.WaitForActiveCARoot(t, s1.RPC, "primary", rootCert)
@ -702,15 +695,15 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
// Capture the current intermediate
rootList, activeRoot, err = getTestRoots(s2, "secondary")
require.NoError(err)
require.Len(rootList.Roots, 1)
require.Len(activeRoot.IntermediateCerts, 1)
require.NoError(t, err)
require.Len(t, rootList.Roots, 1)
require.Len(t, activeRoot.IntermediateCerts, 1)
oldIntermediatePEM := activeRoot.IntermediateCerts[0]
// Update the secondary CA config to use a new private key, which should
// cause a re-signing with a new intermediate.
_, newKey, err := connect.GeneratePrivateKey()
assert.NoError(err)
assert.NoError(t, err)
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
@ -725,7 +718,7 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
}
var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Make sure the new intermediate has replaced the old one in the active root,
@ -736,12 +729,12 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
Datacenter: "secondary",
}
var reply structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
require.Len(reply.Roots, 1)
require.Len(reply.Roots[0].IntermediateCerts, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
require.Len(t, reply.Roots, 1)
require.Len(t, reply.Roots[0].IntermediateCerts, 1)
newIntermediatePEM = reply.Roots[0].IntermediateCerts[0]
require.NotEqual(oldIntermediatePEM, newIntermediatePEM)
require.Equal(reply.Roots[0].RootCert, rootCert.RootCert)
require.NotEqual(t, oldIntermediatePEM, newIntermediatePEM)
require.Equal(t, reply.Roots[0].RootCert, rootCert.RootCert)
}
// Verify the new config was set.
@ -750,14 +743,14 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
Datacenter: "secondary",
}
var reply structs.CAConfiguration
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
require.NoError(err)
require.NoError(t, err)
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
require.NoError(err)
assert.Equal(reply.Provider, newConfig.Provider)
assert.Equal(actual, expected)
require.NoError(t, err)
assert.Equal(t, reply.Provider, newConfig.Provider)
assert.Equal(t, actual, expected)
}
// Verify that new leaf certs get the new intermediate bundled
@ -770,28 +763,28 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
CSR: csr,
}
var reply structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
// Verify the leaf cert has the new intermediate.
{
roots := x509.NewCertPool()
assert.True(roots.AppendCertsFromPEM([]byte(rootCert.RootCert)))
assert.True(t, roots.AppendCertsFromPEM([]byte(rootCert.RootCert)))
leaf, err := connect.ParseCert(reply.CertPEM)
require.NoError(err)
require.NoError(t, err)
intermediates := x509.NewCertPool()
require.True(intermediates.AppendCertsFromPEM([]byte(newIntermediatePEM)))
require.True(t, intermediates.AppendCertsFromPEM([]byte(newIntermediatePEM)))
_, err = leaf.Verify(x509.VerifyOptions{
Roots: roots,
Intermediates: intermediates,
})
require.NoError(err)
require.NoError(t, err)
}
// Verify other fields
assert.Equal("web", reply.Service)
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
assert.Equal(t, "web", reply.Service)
assert.Equal(t, spiffeId.URI().String(), reply.ServiceURI)
}
// Update a minor field in the config that doesn't trigger an intermediate refresh.
@ -810,7 +803,7 @@ func TestConnectCAConfig_UpdateSecondary(t *testing.T) {
}
var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
}
}
@ -840,8 +833,6 @@ func TestConnectCASign(t *testing.T) {
for _, tt := range tests {
t.Run(fmt.Sprintf("%s-%d", tt.caKeyType, tt.caKeyBits), func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(cfg *Config) {
cfg.PrimaryDatacenter = "dc1"
cfg.CAConfig.Config["PrivateKeyType"] = tt.caKeyType
@ -864,7 +855,7 @@ func TestConnectCASign(t *testing.T) {
CSR: csr,
}
var reply structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
// Generate a second CSR and request signing
spiffeId2 := connect.TestSpiffeIDService(t, "web2")
@ -875,20 +866,20 @@ func TestConnectCASign(t *testing.T) {
}
var reply2 structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply2))
require.True(reply2.ModifyIndex > reply.ModifyIndex)
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply2))
require.True(t, reply2.ModifyIndex > reply.ModifyIndex)
// Get the current CA
state := s1.fsm.State()
_, ca, err := state.CARootActive(nil)
require.NoError(err)
require.NoError(t, err)
// Verify that the cert is signed by the CA
require.NoError(connect.ValidateLeaf(ca.RootCert, reply.CertPEM, nil))
require.NoError(t, connect.ValidateLeaf(ca.RootCert, reply.CertPEM, nil))
// Verify other fields
assert.Equal("web", reply.Service)
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
assert.Equal(t, "web", reply.Service)
assert.Equal(t, spiffeId.URI().String(), reply.ServiceURI)
})
}
}
@ -899,7 +890,6 @@ func TestConnectCASign(t *testing.T) {
func BenchmarkConnectCASign(b *testing.B) {
t := &testing.T{}
require := require.New(b)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -919,7 +909,9 @@ func BenchmarkConnectCASign(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
if err := msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply); err != nil {
b.Fatalf("err: %v", err)
}
}
}
@ -930,7 +922,6 @@ func TestConnectCASign_rateLimit(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc1"
c.PrimaryDatacenter = "dc1"
@ -975,7 +966,7 @@ func TestConnectCASign_rateLimit(t *testing.T) {
} else if err.Error() == ErrRateLimited.Error() {
limitedCount++
} else {
require.NoError(err)
require.NoError(t, err)
}
}
// I've only ever seen this as 1/9 however if the test runs slowly on an
@ -985,8 +976,8 @@ func TestConnectCASign_rateLimit(t *testing.T) {
// check that some limiting is being applied. Note that we can't just measure
// the time it took to send them all and infer how many should have succeeded
// without some complex modeling of the token bucket algorithm.
require.Truef(successCount >= 1, "at least 1 CSRs should have succeeded, got %d", successCount)
require.Truef(limitedCount >= 7, "at least 7 CSRs should have been rate limited, got %d", limitedCount)
require.Truef(t, successCount >= 1, "at least 1 CSRs should have succeeded, got %d", successCount)
require.Truef(t, limitedCount >= 7, "at least 7 CSRs should have been rate limited, got %d", limitedCount)
}
func TestConnectCASign_concurrencyLimit(t *testing.T) {
@ -996,7 +987,6 @@ func TestConnectCASign_concurrencyLimit(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc1"
c.PrimaryDatacenter = "dc1"
@ -1056,7 +1046,7 @@ func TestConnectCASign_concurrencyLimit(t *testing.T) {
} else if err.Error() == ErrRateLimited.Error() {
limitedCount++
} else {
require.NoError(err)
require.NoError(t, err)
}
}
@ -1095,7 +1085,7 @@ func TestConnectCASign_concurrencyLimit(t *testing.T) {
// requests were serialized.
t.Logf("min=%s, max=%s", minTime, maxTime)
//t.Fail() // Uncomment to see the time spread logged
require.Truef(successCount >= 1, "at least 1 CSRs should have succeeded, got %d", successCount)
require.Truef(t, successCount >= 1, "at least 1 CSRs should have succeeded, got %d", successCount)
}
func TestConnectCASignValidation(t *testing.T) {

View File

@ -541,7 +541,7 @@ func TestFederationState_List_ACLDeny(t *testing.T) {
gwListEmpty: true,
gwFilteredByACLs: true,
},
"master token": {
"initial management token": {
token: "root",
},
}

View File

@ -105,7 +105,7 @@ func TestFSM_RegisterNode_Service(t *testing.T) {
Service: &structs.NodeService{
ID: "db",
Service: "db",
Tags: []string{"master"},
Tags: []string{"primary"},
Port: 8000,
},
Check: &structs.HealthCheck{
@ -170,7 +170,7 @@ func TestFSM_DeregisterService(t *testing.T) {
Service: &structs.NodeService{
ID: "db",
Service: "db",
Tags: []string{"master"},
Tags: []string{"primary"},
Port: 8000,
},
}
@ -296,7 +296,7 @@ func TestFSM_DeregisterNode(t *testing.T) {
Service: &structs.NodeService{
ID: "db",
Service: "db",
Tags: []string{"master"},
Tags: []string{"primary"},
Port: 8000,
},
Check: &structs.HealthCheck{
@ -1101,10 +1101,9 @@ func TestFSM_Autopilot(t *testing.T) {
func TestFSM_Intention_CRUD(t *testing.T) {
t.Parallel()
assert := assert.New(t)
logger := testutil.Logger(t)
fsm, err := New(nil, logger)
assert.Nil(err)
assert.Nil(t, err)
// Create a new intention.
ixn := structs.IntentionRequest{
@ -1118,19 +1117,19 @@ func TestFSM_Intention_CRUD(t *testing.T) {
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
assert.Nil(t, err)
assert.Nil(t, fsm.Apply(makeLog(buf)))
}
// Verify it's in the state store.
{
_, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
assert.Nil(t, err)
actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt
actual.UpdatedAt = ixn.Intention.UpdatedAt
assert.Equal(ixn.Intention, actual)
assert.Equal(t, ixn.Intention, actual)
}
// Make an update
@ -1138,44 +1137,43 @@ func TestFSM_Intention_CRUD(t *testing.T) {
ixn.Intention.SourceName = "api"
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
assert.Nil(t, err)
assert.Nil(t, fsm.Apply(makeLog(buf)))
}
// Verify the update.
{
_, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
assert.Nil(t, err)
actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt
actual.UpdatedAt = ixn.Intention.UpdatedAt
assert.Equal(ixn.Intention, actual)
assert.Equal(t, ixn.Intention, actual)
}
// Delete
ixn.Op = structs.IntentionOpDelete
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
assert.Nil(t, err)
assert.Nil(t, fsm.Apply(makeLog(buf)))
}
// Make sure it's gone.
{
_, _, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
assert.Nil(actual)
assert.Nil(t, err)
assert.Nil(t, actual)
}
}
func TestFSM_CAConfig(t *testing.T) {
t.Parallel()
assert := assert.New(t)
logger := testutil.Logger(t)
fsm, err := New(nil, logger)
assert.Nil(err)
assert.Nil(t, err)
// Set the autopilot config using a request.
req := structs.CARequest{
@ -1190,7 +1188,7 @@ func TestFSM_CAConfig(t *testing.T) {
},
}
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
assert.Nil(t, err)
resp := fsm.Apply(makeLog(buf))
if _, ok := resp.(error); ok {
t.Fatalf("bad: %v", resp)
@ -1231,7 +1229,7 @@ func TestFSM_CAConfig(t *testing.T) {
}
_, config, err = fsm.state.CAConfig(nil)
assert.Nil(err)
assert.Nil(t, err)
if config.Provider != "static" {
t.Fatalf("bad: %v", config.Provider)
}
@ -1240,10 +1238,9 @@ func TestFSM_CAConfig(t *testing.T) {
func TestFSM_CARoots(t *testing.T) {
t.Parallel()
assert := assert.New(t)
logger := testutil.Logger(t)
fsm, err := New(nil, logger)
assert.Nil(err)
assert.Nil(t, err)
// Roots
ca1 := connect.TestCA(t, nil)
@ -1258,25 +1255,24 @@ func TestFSM_CARoots(t *testing.T) {
{
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
assert.True(fsm.Apply(makeLog(buf)).(bool))
assert.Nil(t, err)
assert.True(t, fsm.Apply(makeLog(buf)).(bool))
}
// Verify it's in the state store.
{
_, roots, err := fsm.state.CARoots(nil)
assert.Nil(err)
assert.Len(roots, 2)
assert.Nil(t, err)
assert.Len(t, roots, 2)
}
}
func TestFSM_CABuiltinProvider(t *testing.T) {
t.Parallel()
assert := assert.New(t)
logger := testutil.Logger(t)
fsm, err := New(nil, logger)
assert.Nil(err)
assert.Nil(t, err)
// Provider state.
expected := &structs.CAConsulProviderState{
@ -1297,25 +1293,24 @@ func TestFSM_CABuiltinProvider(t *testing.T) {
{
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
assert.True(fsm.Apply(makeLog(buf)).(bool))
assert.Nil(t, err)
assert.True(t, fsm.Apply(makeLog(buf)).(bool))
}
// Verify it's in the state store.
{
_, state, err := fsm.state.CAProviderState("foo")
assert.Nil(err)
assert.Equal(expected, state)
assert.Nil(t, err)
assert.Equal(t, expected, state)
}
}
func TestFSM_ConfigEntry(t *testing.T) {
t.Parallel()
require := require.New(t)
logger := testutil.Logger(t)
fsm, err := New(nil, logger)
require.NoError(err)
require.NoError(t, err)
// Create a simple config entry
entry := &structs.ProxyConfigEntry{
@ -1335,7 +1330,7 @@ func TestFSM_ConfigEntry(t *testing.T) {
{
buf, err := structs.Encode(structs.ConfigEntryRequestType, req)
require.NoError(err)
require.NoError(t, err)
resp := fsm.Apply(makeLog(buf))
if _, ok := resp.(error); ok {
t.Fatalf("bad: %v", resp)
@ -1345,33 +1340,31 @@ func TestFSM_ConfigEntry(t *testing.T) {
// Verify it's in the state store.
{
_, config, err := fsm.state.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err)
require.NoError(t, err)
entry.RaftIndex.CreateIndex = 1
entry.RaftIndex.ModifyIndex = 1
require.Equal(entry, config)
require.Equal(t, entry, config)
}
}
func TestFSM_ConfigEntry_DeleteCAS(t *testing.T) {
t.Parallel()
require := require.New(t)
logger := testutil.Logger(t)
fsm, err := New(nil, logger)
require.NoError(err)
require.NoError(t, err)
// Create a simple config entry and write it to the state store.
entry := &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "global",
}
require.NoError(fsm.state.EnsureConfigEntry(1, entry))
require.NoError(t, fsm.state.EnsureConfigEntry(1, entry))
// Raft index is populated by EnsureConfigEntry, hold on to it so that we can
// restore it later.
raftIndex := entry.RaftIndex
require.NotZero(raftIndex.ModifyIndex)
require.NotZero(t, raftIndex.ModifyIndex)
// Attempt a CAS delete with an invalid index.
entry = entry.Clone()
@ -1383,24 +1376,24 @@ func TestFSM_ConfigEntry_DeleteCAS(t *testing.T) {
Entry: entry,
}
buf, err := structs.Encode(structs.ConfigEntryRequestType, req)
require.NoError(err)
require.NoError(t, err)
// Expect to get boolean false back.
rsp := fsm.Apply(makeLog(buf))
didDelete, isBool := rsp.(bool)
require.True(isBool)
require.False(didDelete)
require.True(t, isBool)
require.False(t, didDelete)
// Attempt a CAS delete with a valid index.
entry.RaftIndex = raftIndex
buf, err = structs.Encode(structs.ConfigEntryRequestType, req)
require.NoError(err)
require.NoError(t, err)
// Expect to get boolean true back.
rsp = fsm.Apply(makeLog(buf))
didDelete, isBool = rsp.(bool)
require.True(isBool)
require.True(didDelete)
require.True(t, isBool)
require.True(t, didDelete)
}
// This adapts another test by chunking the encoded data and then performing
@ -1413,12 +1406,10 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
}
t.Parallel()
require := require.New(t)
assert := assert.New(t)
logger := testutil.Logger(t)
fsm, err := New(nil, logger)
require.NoError(err)
require.NoError(t, err)
var logOfLogs [][]*raft.Log
for i := 0; i < 10; i++ {
@ -1429,7 +1420,7 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
Service: &structs.NodeService{
ID: "db",
Service: "db",
Tags: []string{"master"},
Tags: []string{"primary"},
Port: 8000,
},
Check: &structs.HealthCheck{
@ -1442,7 +1433,7 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
}
buf, err := structs.Encode(structs.RegisterRequestType, req)
require.NoError(err)
require.NoError(t, err)
var logs []*raft.Log
@ -1453,7 +1444,7 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
NumChunks: uint32(len(buf)),
}
chunkBytes, err := proto.Marshal(chunkInfo)
require.NoError(err)
require.NoError(t, err)
logs = append(logs, &raft.Log{
Data: []byte{b},
@ -1468,41 +1459,41 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
// the full set, and out of order.
for _, logs := range logOfLogs {
resp := fsm.chunker.Apply(logs[8])
assert.Nil(resp)
assert.Nil(t, resp)
resp = fsm.chunker.Apply(logs[0])
assert.Nil(resp)
assert.Nil(t, resp)
resp = fsm.chunker.Apply(logs[3])
assert.Nil(resp)
assert.Nil(t, resp)
}
// Verify we are not registered
for i := 0; i < 10; i++ {
_, node, err := fsm.state.GetNode(fmt.Sprintf("foo%d", i), nil)
require.NoError(err)
assert.Nil(node)
require.NoError(t, err)
assert.Nil(t, node)
}
// Snapshot, restore elsewhere, apply the rest of the logs, make sure it
// looks right
snap, err := fsm.Snapshot()
require.NoError(err)
require.NoError(t, err)
defer snap.Release()
sinkBuf := bytes.NewBuffer(nil)
sink := &MockSink{sinkBuf, false}
err = snap.Persist(sink)
require.NoError(err)
require.NoError(t, err)
fsm2, err := New(nil, logger)
require.NoError(err)
require.NoError(t, err)
err = fsm2.Restore(sink)
require.NoError(err)
require.NoError(t, err)
// Verify we are still not registered
for i := 0; i < 10; i++ {
_, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil)
require.NoError(err)
assert.Nil(node)
require.NoError(t, err)
assert.Nil(t, node)
}
// Apply the rest of the logs
@ -1514,43 +1505,41 @@ func TestFSM_Chunking_Lifecycle(t *testing.T) {
default:
resp = fsm2.chunker.Apply(log)
if i != len(logs)-1 {
assert.Nil(resp)
assert.Nil(t, resp)
}
}
}
_, ok := resp.(raftchunking.ChunkingSuccess)
assert.True(ok)
assert.True(t, ok)
}
// Verify we are registered
for i := 0; i < 10; i++ {
_, node, err := fsm2.state.GetNode(fmt.Sprintf("foo%d", i), nil)
require.NoError(err)
assert.NotNil(node)
require.NoError(t, err)
assert.NotNil(t, node)
// Verify service registered
_, services, err := fsm2.state.NodeServices(nil, fmt.Sprintf("foo%d", i), structs.DefaultEnterpriseMetaInDefaultPartition())
require.NoError(err)
require.NotNil(services)
require.NoError(t, err)
require.NotNil(t, services)
_, ok := services.Services["db"]
assert.True(ok)
assert.True(t, ok)
// Verify check
_, checks, err := fsm2.state.NodeChecks(nil, fmt.Sprintf("foo%d", i), nil)
require.NoError(err)
require.NotNil(checks)
assert.Equal(string(checks[0].CheckID), "db")
require.NoError(t, err)
require.NotNil(t, checks)
assert.Equal(t, string(checks[0].CheckID), "db")
}
}
func TestFSM_Chunking_TermChange(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
logger := testutil.Logger(t)
fsm, err := New(nil, logger)
require.NoError(err)
require.NoError(t, err)
req := structs.RegisterRequest{
Datacenter: "dc1",
@ -1559,7 +1548,7 @@ func TestFSM_Chunking_TermChange(t *testing.T) {
Service: &structs.NodeService{
ID: "db",
Service: "db",
Tags: []string{"master"},
Tags: []string{"primary"},
Port: 8000,
},
Check: &structs.HealthCheck{
@ -1571,7 +1560,7 @@ func TestFSM_Chunking_TermChange(t *testing.T) {
},
}
buf, err := structs.Encode(structs.RegisterRequestType, req)
require.NoError(err)
require.NoError(t, err)
// Only need two chunks to test this
chunks := [][]byte{
@ -1599,7 +1588,7 @@ func TestFSM_Chunking_TermChange(t *testing.T) {
// We should see nil for both
for _, log := range logs {
resp := fsm.chunker.Apply(log)
assert.Nil(resp)
assert.Nil(t, resp)
}
// Now verify the other baseline, that when the term doesn't change we see
@ -1616,10 +1605,10 @@ func TestFSM_Chunking_TermChange(t *testing.T) {
for i, log := range logs {
resp := fsm.chunker.Apply(log)
if i == 0 {
assert.Nil(resp)
assert.Nil(t, resp)
}
if i == 1 {
assert.NotNil(resp)
assert.NotNil(t, resp)
}
}
}

View File

@ -979,7 +979,6 @@ func TestHealth_ServiceNodes_ConnectProxy_ACL(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -1020,7 +1019,7 @@ node "foo" {
Status: api.HealthPassing,
ServiceID: args.Service.ID,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a service
args = structs.TestRegisterRequestProxy(t)
@ -1032,7 +1031,7 @@ node "foo" {
Status: api.HealthPassing,
ServiceID: args.Service.Service,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a service
args = structs.TestRegisterRequestProxy(t)
@ -1044,7 +1043,7 @@ node "foo" {
Status: api.HealthPassing,
ServiceID: args.Service.Service,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
}
// List w/ token. This should disallow because we don't have permission
@ -1056,8 +1055,8 @@ node "foo" {
QueryOptions: structs.QueryOptions{Token: token},
}
var resp structs.IndexedCheckServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(resp.Nodes, 0)
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(t, resp.Nodes, 0)
// List w/ token. This should work since we're requesting "foo", but should
// also only contain the proxies with names that adhere to our ACL.
@ -1067,8 +1066,8 @@ node "foo" {
ServiceName: "foo",
QueryOptions: structs.QueryOptions{Token: token},
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(resp.Nodes, 1)
assert.Nil(t, msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(t, resp.Nodes, 1)
}
func TestHealth_ServiceNodes_Gateway(t *testing.T) {
@ -1432,8 +1431,6 @@ func TestHealth_NodeChecks_FilterACL(t *testing.T) {
t.Parallel()
require := require.New(t)
dir, token, srv, codec := testACLFilterServer(t)
defer os.RemoveAll(dir)
defer srv.Shutdown()
@ -1446,7 +1443,7 @@ func TestHealth_NodeChecks_FilterACL(t *testing.T) {
}
reply := structs.IndexedHealthChecks{}
err := msgpackrpc.CallWithCodec(codec, "Health.NodeChecks", &opt, &reply)
require.NoError(err)
require.NoError(t, err)
found := false
for _, chk := range reply.HealthChecks {
@ -1457,8 +1454,8 @@ func TestHealth_NodeChecks_FilterACL(t *testing.T) {
t.Fatalf("bad: %#v", reply.HealthChecks)
}
}
require.True(found, "bad: %#v", reply.HealthChecks)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.True(t, found, "bad: %#v", reply.HealthChecks)
require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// We've already proven that we call the ACL filtering function so we
// test node filtering down in acl.go for node cases. This also proves
@ -1474,8 +1471,6 @@ func TestHealth_ServiceChecks_FilterACL(t *testing.T) {
t.Parallel()
require := require.New(t)
dir, token, srv, codec := testACLFilterServer(t)
defer os.RemoveAll(dir)
defer srv.Shutdown()
@ -1488,7 +1483,7 @@ func TestHealth_ServiceChecks_FilterACL(t *testing.T) {
}
reply := structs.IndexedHealthChecks{}
err := msgpackrpc.CallWithCodec(codec, "Health.ServiceChecks", &opt, &reply)
require.NoError(err)
require.NoError(t, err)
found := false
for _, chk := range reply.HealthChecks {
@ -1497,14 +1492,14 @@ func TestHealth_ServiceChecks_FilterACL(t *testing.T) {
break
}
}
require.True(found, "bad: %#v", reply.HealthChecks)
require.True(t, found, "bad: %#v", reply.HealthChecks)
opt.ServiceName = "bar"
reply = structs.IndexedHealthChecks{}
err = msgpackrpc.CallWithCodec(codec, "Health.ServiceChecks", &opt, &reply)
require.NoError(err)
require.Empty(reply.HealthChecks)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, reply.HealthChecks)
require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// We've already proven that we call the ACL filtering function so we
// test node filtering down in acl.go for node cases. This also proves
@ -1520,8 +1515,6 @@ func TestHealth_ServiceNodes_FilterACL(t *testing.T) {
t.Parallel()
require := require.New(t)
dir, token, srv, codec := testACLFilterServer(t)
defer os.RemoveAll(dir)
defer srv.Shutdown()
@ -1534,15 +1527,15 @@ func TestHealth_ServiceNodes_FilterACL(t *testing.T) {
}
reply := structs.IndexedCheckServiceNodes{}
err := msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &opt, &reply)
require.NoError(err)
require.Len(reply.Nodes, 1)
require.NoError(t, err)
require.Len(t, reply.Nodes, 1)
opt.ServiceName = "bar"
reply = structs.IndexedCheckServiceNodes{}
err = msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &opt, &reply)
require.NoError(err)
require.Empty(reply.Nodes)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, reply.Nodes)
require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// We've already proven that we call the ACL filtering function so we
// test node filtering down in acl.go for node cases. This also proves
@ -1558,8 +1551,6 @@ func TestHealth_ChecksInState_FilterACL(t *testing.T) {
t.Parallel()
require := require.New(t)
dir, token, srv, codec := testACLFilterServer(t)
defer os.RemoveAll(dir)
defer srv.Shutdown()
@ -1572,7 +1563,7 @@ func TestHealth_ChecksInState_FilterACL(t *testing.T) {
}
reply := structs.IndexedHealthChecks{}
err := msgpackrpc.CallWithCodec(codec, "Health.ChecksInState", &opt, &reply)
require.NoError(err)
require.NoError(t, err)
found := false
for _, chk := range reply.HealthChecks {
@ -1583,8 +1574,8 @@ func TestHealth_ChecksInState_FilterACL(t *testing.T) {
t.Fatalf("bad service 'bar': %#v", reply.HealthChecks)
}
}
require.True(found, "missing service 'foo': %#v", reply.HealthChecks)
require.True(reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.True(t, found, "missing service 'foo': %#v", reply.HealthChecks)
require.True(t, reply.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// We've already proven that we call the ACL filtering function so we
// test node filtering down in acl.go for node cases. This also proves

View File

@ -100,20 +100,13 @@ func (s *Intention) Apply(args *structs.IntentionRequest, reply *string) error {
}
// Get the ACL token for the request for the checks below.
identity, authz, err := s.srv.acls.ResolveTokenToIdentityAndAuthorizer(args.Token)
var entMeta structs.EnterpriseMeta
authz, err := s.srv.ACLResolver.ResolveTokenAndDefaultMeta(args.Token, &entMeta, nil)
if err != nil {
return err
}
var accessorID string
var entMeta structs.EnterpriseMeta
if identity != nil {
entMeta.Merge(identity.EnterpriseMetadata())
accessorID = identity.ID()
} else {
entMeta.Merge(structs.DefaultEnterpriseMetaInDefaultPartition())
}
accessorID := authz.AccessorID()
var (
mut *structs.IntentionMutation
legacyWrite bool
@ -432,7 +425,8 @@ func (s *Intention) Get(args *structs.IntentionQueryRequest, reply *structs.Inde
// Get the ACL token for the request for the checks below.
var entMeta structs.EnterpriseMeta
if _, err := s.srv.ResolveTokenAndDefaultMeta(args.Token, &entMeta, nil); err != nil {
authz, err := s.srv.ResolveTokenAndDefaultMeta(args.Token, &entMeta, nil)
if err != nil {
return err
}
@ -479,13 +473,11 @@ func (s *Intention) Get(args *structs.IntentionQueryRequest, reply *structs.Inde
reply.Intentions = structs.Intentions{ixn}
// Filter
if err := s.srv.filterACL(args.Token, reply); err != nil {
return err
}
s.srv.filterACLWithAuthorizer(authz, reply)
// If ACLs prevented any responses, error
if len(reply.Intentions) == 0 {
accessorID := s.aclAccessorID(args.Token)
accessorID := authz.AccessorID()
// todo(kit) Migrate intention access denial logging over to audit logging when we implement it
s.logger.Warn("Request to get intention denied due to ACLs", "intention", args.IntentionID, "accessorID", accessorID)
return acl.ErrPermissionDenied
@ -618,7 +610,7 @@ func (s *Intention) Match(args *structs.IntentionQueryRequest, reply *structs.In
for _, entry := range args.Match.Entries {
entry.FillAuthzContext(&authzContext)
if prefix := entry.Name; prefix != "" && authz.IntentionRead(prefix, &authzContext) != acl.Allow {
accessorID := s.aclAccessorID(args.Token)
accessorID := authz.AccessorID()
// todo(kit) Migrate intention access denial logging over to audit logging when we implement it
s.logger.Warn("Operation on intention prefix denied due to ACLs", "prefix", prefix, "accessorID", accessorID)
return acl.ErrPermissionDenied
@ -708,7 +700,7 @@ func (s *Intention) Check(args *structs.IntentionQueryRequest, reply *structs.In
var authzContext acl.AuthorizerContext
query.FillAuthzContext(&authzContext)
if authz.ServiceRead(prefix, &authzContext) != acl.Allow {
accessorID := s.aclAccessorID(args.Token)
accessorID := authz.AccessorID()
// todo(kit) Migrate intention access denial logging over to audit logging when we implement it
s.logger.Warn("test on intention denied due to ACLs", "prefix", prefix, "accessorID", accessorID)
return acl.ErrPermissionDenied
@ -760,24 +752,6 @@ func (s *Intention) Check(args *structs.IntentionQueryRequest, reply *structs.In
return nil
}
// aclAccessorID is used to convert an ACLToken's secretID to its accessorID for non-
// critical purposes, such as logging. Therefore we interpret all errors as empty-string
// so we can safely log it without handling non-critical errors at the usage site.
func (s *Intention) aclAccessorID(secretID string) string {
_, ident, err := s.srv.ResolveIdentityFromToken(secretID)
if acl.IsErrNotFound(err) {
return ""
}
if err != nil {
s.logger.Debug("non-critical error resolving acl token accessor for logging", "error", err)
return ""
}
if ident == nil {
return ""
}
return ident.ID()
}
func (s *Intention) validateEnterpriseIntention(ixn *structs.Intention) error {
if err := s.srv.validateEnterpriseIntentionPartition(ixn.SourcePartition); err != nil {
return fmt.Errorf("Invalid source partition %q: %v", ixn.SourcePartition, err)

View File

@ -111,7 +111,6 @@ func TestIntentionApply_defaultSourceType(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -135,8 +134,8 @@ func TestIntentionApply_defaultSourceType(t *testing.T) {
var reply string
// Create
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.NotEmpty(reply)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.NotEmpty(t, reply)
// Read
ixn.Intention.ID = reply
@ -146,10 +145,10 @@ func TestIntentionApply_defaultSourceType(t *testing.T) {
IntentionID: ixn.Intention.ID,
}
var resp structs.IndexedIntentions
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp))
require.Len(resp.Intentions, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp))
require.Len(t, resp.Intentions, 1)
actual := resp.Intentions[0]
require.Equal(structs.IntentionSourceConsul, actual.SourceType)
require.Equal(t, structs.IntentionSourceConsul, actual.SourceType)
}
}
@ -161,7 +160,6 @@ func TestIntentionApply_createWithID(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -184,8 +182,8 @@ func TestIntentionApply_createWithID(t *testing.T) {
// Create
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.NotNil(err)
require.Contains(err, "ID must be empty")
require.NotNil(t, err)
require.Contains(t, err, "ID must be empty")
}
// Test basic updating
@ -282,7 +280,6 @@ func TestIntentionApply_updateNonExist(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -304,8 +301,8 @@ func TestIntentionApply_updateNonExist(t *testing.T) {
// Create
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.NotNil(err)
require.Contains(err, "Cannot modify non-existent intention")
require.NotNil(t, err)
require.Contains(t, err, "Cannot modify non-existent intention")
}
// Test basic deleting
@ -316,7 +313,6 @@ func TestIntentionApply_deleteGood(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -346,13 +342,13 @@ func TestIntentionApply_deleteGood(t *testing.T) {
}, &reply), "Cannot delete non-existent intention")
// Create
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.NotEmpty(reply)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.NotEmpty(t, reply)
// Delete
ixn.Op = structs.IntentionOpDelete
ixn.Intention.ID = reply
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Read
ixn.Intention.ID = reply
@ -363,8 +359,8 @@ func TestIntentionApply_deleteGood(t *testing.T) {
}
var resp structs.IndexedIntentions
err := msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp)
require.NotNil(err)
require.Contains(err, ErrIntentionNotFound.Error())
require.NotNil(t, err)
require.Contains(t, err, ErrIntentionNotFound.Error())
}
}
@ -863,7 +859,6 @@ func TestIntentionApply_aclDeny(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -895,11 +890,11 @@ service "foobar" {
// Create without a token should error since default deny
var reply string
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.True(acl.IsErrPermissionDenied(err))
require.True(t, acl.IsErrPermissionDenied(err))
// Now add the token and try again.
ixn.WriteRequest.Token = token
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Read
ixn.Intention.ID = reply
@ -910,10 +905,10 @@ service "foobar" {
QueryOptions: structs.QueryOptions{Token: "root"},
}
var resp structs.IndexedIntentions
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp))
require.Len(resp.Intentions, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp))
require.Len(t, resp.Intentions, 1)
actual := resp.Intentions[0]
require.Equal(resp.Index, actual.ModifyIndex)
require.Equal(t, resp.Index, actual.ModifyIndex)
actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt
@ -921,7 +916,7 @@ service "foobar" {
actual.Hash = ixn.Intention.Hash
//nolint:staticcheck
ixn.Intention.UpdatePrecedence()
require.Equal(ixn.Intention, actual)
require.Equal(t, ixn.Intention, actual)
}
}
@ -937,17 +932,17 @@ func TestIntention_WildcardACLEnforcement(t *testing.T) {
// create some test policies.
writeToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service_prefix "" { policy = "deny" intentions = "write" }`)
writeToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service_prefix "" { policy = "deny" intentions = "write" }`)
require.NoError(t, err)
readToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service_prefix "" { policy = "deny" intentions = "read" }`)
readToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service_prefix "" { policy = "deny" intentions = "read" }`)
require.NoError(t, err)
exactToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service "*" { policy = "deny" intentions = "write" }`)
exactToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service "*" { policy = "deny" intentions = "write" }`)
require.NoError(t, err)
wildcardPrefixToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service_prefix "*" { policy = "deny" intentions = "write" }`)
wildcardPrefixToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service_prefix "*" { policy = "deny" intentions = "write" }`)
require.NoError(t, err)
fooToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service "foo" { policy = "deny" intentions = "write" }`)
fooToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service "foo" { policy = "deny" intentions = "write" }`)
require.NoError(t, err)
denyToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service_prefix "" { policy = "deny" intentions = "deny" }`)
denyToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service_prefix "" { policy = "deny" intentions = "deny" }`)
require.NoError(t, err)
doIntentionCreate := func(t *testing.T, token string, dest string, deny bool) string {
@ -1253,7 +1248,6 @@ func TestIntentionApply_aclDelete(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -1285,18 +1279,18 @@ service "foobar" {
// Create
var reply string
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Try to do a delete with no token; this should get rejected.
ixn.Op = structs.IntentionOpDelete
ixn.Intention.ID = reply
ixn.WriteRequest.Token = ""
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.True(acl.IsErrPermissionDenied(err))
require.True(t, acl.IsErrPermissionDenied(err))
// Try again with the original token. This should go through.
ixn.WriteRequest.Token = token
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Verify it is gone
{
@ -1306,8 +1300,8 @@ service "foobar" {
}
var resp structs.IndexedIntentions
err := msgpackrpc.CallWithCodec(codec, "Intention.Get", req, &resp)
require.NotNil(err)
require.Contains(err.Error(), ErrIntentionNotFound.Error())
require.NotNil(t, err)
require.Contains(t, err.Error(), ErrIntentionNotFound.Error())
}
}
@ -1319,7 +1313,6 @@ func TestIntentionApply_aclUpdate(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -1351,18 +1344,18 @@ service "foobar" {
// Create
var reply string
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Try to do an update without a token; this should get rejected.
ixn.Op = structs.IntentionOpUpdate
ixn.Intention.ID = reply
ixn.WriteRequest.Token = ""
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.True(acl.IsErrPermissionDenied(err))
require.True(t, acl.IsErrPermissionDenied(err))
// Try again with the original token; this should go through.
ixn.WriteRequest.Token = token
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
}
// Test apply with a management token
@ -1373,7 +1366,6 @@ func TestIntentionApply_aclManagement(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -1398,16 +1390,16 @@ func TestIntentionApply_aclManagement(t *testing.T) {
// Create
var reply string
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
ixn.Intention.ID = reply
// Update
ixn.Op = structs.IntentionOpUpdate
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Delete
ixn.Op = structs.IntentionOpDelete
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
}
// Test update changing the name where an ACL won't allow it
@ -1418,7 +1410,6 @@ func TestIntentionApply_aclUpdateChange(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
@ -1450,7 +1441,7 @@ service "foobar" {
// Create
var reply string
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply))
// Try to do an update without a token; this should get rejected.
ixn.Op = structs.IntentionOpUpdate
@ -1458,7 +1449,7 @@ service "foobar" {
ixn.Intention.DestinationName = "foo"
ixn.WriteRequest.Token = token
err := msgpackrpc.CallWithCodec(codec, "Intention.Apply", &ixn, &reply)
require.True(acl.IsErrPermissionDenied(err))
require.True(t, acl.IsErrPermissionDenied(err))
}
// Test reading with ACLs
@ -1570,7 +1561,6 @@ func TestIntentionList(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -1585,9 +1575,9 @@ func TestIntentionList(t *testing.T) {
Datacenter: "dc1",
}
var resp structs.IndexedIntentions
require.Nil(msgpackrpc.CallWithCodec(codec, "Intention.List", req, &resp))
require.NotNil(resp.Intentions)
require.Len(resp.Intentions, 0)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "Intention.List", req, &resp))
require.NotNil(t, resp.Intentions)
require.Len(t, resp.Intentions, 0)
}
}
@ -1607,7 +1597,7 @@ func TestIntentionList_acl(t *testing.T) {
waitForLeaderEstablishment(t, s1)
token, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service_prefix "foo" { policy = "write" }`)
token, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service_prefix "foo" { policy = "write" }`)
require.NoError(t, err)
// Create a few records
@ -1620,7 +1610,7 @@ func TestIntentionList_acl(t *testing.T) {
ixn.Intention.SourceNS = "default"
ixn.Intention.DestinationNS = "default"
ixn.Intention.DestinationName = name
ixn.WriteRequest.Token = TestDefaultMasterToken
ixn.WriteRequest.Token = TestDefaultInitialManagementToken
// Create
var reply string
@ -1639,10 +1629,10 @@ func TestIntentionList_acl(t *testing.T) {
})
// Test with management token
t.Run("master-token", func(t *testing.T) {
t.Run("initial-management-token", func(t *testing.T) {
req := &structs.IntentionListRequest{
Datacenter: "dc1",
QueryOptions: structs.QueryOptions{Token: TestDefaultMasterToken},
QueryOptions: structs.QueryOptions{Token: TestDefaultInitialManagementToken},
}
var resp structs.IndexedIntentions
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Intention.List", req, &resp))
@ -1666,7 +1656,7 @@ func TestIntentionList_acl(t *testing.T) {
req := &structs.IntentionListRequest{
Datacenter: "dc1",
QueryOptions: structs.QueryOptions{
Token: TestDefaultMasterToken,
Token: TestDefaultInitialManagementToken,
Filter: "DestinationName == foobar",
},
}
@ -1763,7 +1753,7 @@ func TestIntentionMatch_acl(t *testing.T) {
_, srv, codec := testACLServerWithConfig(t, nil, false)
waitForLeaderEstablishment(t, srv)
token, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service "bar" { policy = "write" }`)
token, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service "bar" { policy = "write" }`)
require.NoError(t, err)
// Create some records
@ -1781,7 +1771,7 @@ func TestIntentionMatch_acl(t *testing.T) {
Intention: structs.TestIntention(t),
}
ixn.Intention.DestinationName = v
ixn.WriteRequest.Token = TestDefaultMasterToken
ixn.WriteRequest.Token = TestDefaultInitialManagementToken
// Create
var reply string
@ -1993,7 +1983,7 @@ func TestIntentionCheck_match(t *testing.T) {
_, srv, codec := testACLServerWithConfig(t, nil, false)
waitForLeaderEstablishment(t, srv)
token, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `service "api" { policy = "read" }`)
token, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `service "api" { policy = "read" }`)
require.NoError(t, err)
// Create some intentions
@ -2015,7 +2005,7 @@ func TestIntentionCheck_match(t *testing.T) {
DestinationName: v[1],
Action: structs.IntentionActionAllow,
},
WriteRequest: structs.WriteRequest{Token: TestDefaultMasterToken},
WriteRequest: structs.WriteRequest{Token: TestDefaultInitialManagementToken},
}
// Create
var reply string

View File

@ -401,13 +401,13 @@ func (m *Internal) EventFire(args *structs.EventFireRequest,
}
// Check ACLs
authz, err := m.srv.ResolveToken(args.Token)
authz, err := m.srv.ResolveTokenAndDefaultMeta(args.Token, nil, nil)
if err != nil {
return err
}
if authz.EventWrite(args.Name, nil) != acl.Allow {
accessorID := m.aclAccessorID(args.Token)
accessorID := authz.AccessorID()
m.logger.Warn("user event blocked by ACLs", "event", args.Name, "accessorID", accessorID)
return acl.ErrPermissionDenied
}
@ -433,11 +433,11 @@ func (m *Internal) KeyringOperation(
}
// Check ACLs
identity, authz, err := m.srv.acls.ResolveTokenToIdentityAndAuthorizer(args.Token)
authz, err := m.srv.ACLResolver.ResolveToken(args.Token)
if err != nil {
return err
}
if err := m.srv.validateEnterpriseToken(identity); err != nil {
if err := m.srv.validateEnterpriseToken(authz.Identity()); err != nil {
return err
}
switch args.Operation {
@ -545,21 +545,3 @@ func (m *Internal) executeKeyringOpMgr(
return serfResp, err
}
// aclAccessorID is used to convert an ACLToken's secretID to its accessorID for non-
// critical purposes, such as logging. Therefore we interpret all errors as empty-string
// so we can safely log it without handling non-critical errors at the usage site.
func (m *Internal) aclAccessorID(secretID string) string {
_, ident, err := m.srv.ResolveIdentityFromToken(secretID)
if acl.IsErrNotFound(err) {
return ""
}
if err != nil {
m.logger.Debug("non-critical error resolving acl token accessor for logging", "error", err)
return ""
}
if ident == nil {
return ""
}
return ident.ID()
}

View File

@ -853,7 +853,6 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
}
t.Run("can read all", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, `
node_prefix "" {
@ -870,14 +869,13 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
}
var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err)
require.NotEmpty(out.Nodes)
require.NotEmpty(out.Gateways)
require.False(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
require.NoError(t, err)
require.NotEmpty(t, out.Nodes)
require.NotEmpty(t, out.Gateways)
require.False(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
})
t.Run("cannot read service node", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, `
node "node1" {
@ -894,13 +892,12 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
}
var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err)
require.Empty(out.Nodes)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, out.Nodes)
require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
})
t.Run("cannot read service", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, `
node "node1" {
@ -917,13 +914,12 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
}
var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err)
require.Empty(out.Nodes)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, out.Nodes)
require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
})
t.Run("cannot read gateway node", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, `
node "node2" {
@ -940,13 +936,12 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
}
var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err)
require.Empty(out.Gateways)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, out.Gateways)
require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
})
t.Run("cannot read gateway", func(t *testing.T) {
require := require.New(t)
token := tokenWithRules(t, `
node "node2" {
@ -963,9 +958,9 @@ func TestInternal_ServiceDump_ACL(t *testing.T) {
}
var out structs.IndexedNodesWithGateways
err := msgpackrpc.CallWithCodec(codec, "Internal.ServiceDump", &args, &out)
require.NoError(err)
require.Empty(out.Gateways)
require.True(out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, out.Gateways)
require.True(t, out.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
})
}
@ -1790,7 +1785,7 @@ func TestInternal_GatewayIntentions_aclDeny(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1", testrpc.WithToken(TestDefaultMasterToken))
testrpc.WaitForTestAgent(t, s1.RPC, "dc1", testrpc.WithToken(TestDefaultInitialManagementToken))
// Register terminating gateway and config entry linking it to postgres + redis
{
@ -1809,7 +1804,7 @@ func TestInternal_GatewayIntentions_aclDeny(t *testing.T) {
Status: api.HealthPassing,
ServiceID: "terminating-gateway",
},
WriteRequest: structs.WriteRequest{Token: TestDefaultMasterToken},
WriteRequest: structs.WriteRequest{Token: TestDefaultInitialManagementToken},
}
var regOutput struct{}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &arg, &regOutput))
@ -1834,7 +1829,7 @@ func TestInternal_GatewayIntentions_aclDeny(t *testing.T) {
Op: structs.ConfigEntryUpsert,
Datacenter: "dc1",
Entry: args,
WriteRequest: structs.WriteRequest{Token: TestDefaultMasterToken},
WriteRequest: structs.WriteRequest{Token: TestDefaultInitialManagementToken},
}
var configOutput bool
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConfigEntry.Apply", &req, &configOutput))
@ -1848,7 +1843,7 @@ func TestInternal_GatewayIntentions_aclDeny(t *testing.T) {
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
WriteRequest: structs.WriteRequest{Token: TestDefaultMasterToken},
WriteRequest: structs.WriteRequest{Token: TestDefaultInitialManagementToken},
}
req.Intention.SourceName = "api"
req.Intention.DestinationName = v
@ -1860,7 +1855,7 @@ func TestInternal_GatewayIntentions_aclDeny(t *testing.T) {
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
WriteRequest: structs.WriteRequest{Token: TestDefaultMasterToken},
WriteRequest: structs.WriteRequest{Token: TestDefaultInitialManagementToken},
}
req.Intention.SourceName = v
req.Intention.DestinationName = "api"
@ -1868,7 +1863,7 @@ func TestInternal_GatewayIntentions_aclDeny(t *testing.T) {
}
}
userToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `
userToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `
service_prefix "redis" { policy = "read" }
service_prefix "terminating-gateway" { policy = "read" }
`)
@ -2192,7 +2187,7 @@ func TestInternal_ServiceTopology_ACL(t *testing.T) {
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLInitialManagementToken = TestDefaultMasterToken
c.ACLInitialManagementToken = TestDefaultInitialManagementToken
c.ACLResolverSettings.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
@ -2215,10 +2210,10 @@ func TestInternal_ServiceTopology_ACL(t *testing.T) {
// web -> redis exact intention
// redis and redis-proxy on node zip
registerTestTopologyEntries(t, codec, TestDefaultMasterToken)
registerTestTopologyEntries(t, codec, TestDefaultInitialManagementToken)
// Token grants read to: foo/api, foo/api-proxy, bar/web, baz/web
userToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `
userToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `
node_prefix "" { policy = "read" }
service_prefix "api" { policy = "read" }
service "web" { policy = "read" }
@ -2331,7 +2326,7 @@ func TestInternal_IntentionUpstreams_ACL(t *testing.T) {
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLInitialManagementToken = TestDefaultMasterToken
c.ACLInitialManagementToken = TestDefaultInitialManagementToken
c.ACLResolverSettings.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
@ -2349,11 +2344,11 @@ func TestInternal_IntentionUpstreams_ACL(t *testing.T) {
// Intentions
// * -> * (deny) intention
// web -> api (allow)
registerIntentionUpstreamEntries(t, codec, TestDefaultMasterToken)
registerIntentionUpstreamEntries(t, codec, TestDefaultInitialManagementToken)
t.Run("valid token", func(t *testing.T) {
// Token grants read to read api service
userToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `
userToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `
service_prefix "api" { policy = "read" }
`)
require.NoError(t, err)
@ -2379,7 +2374,7 @@ service_prefix "api" { policy = "read" }
t.Run("invalid token filters results", func(t *testing.T) {
// Token grants read to read an unrelated service, mongo
userToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultMasterToken, "dc1", `
userToken, err := upsertTestTokenWithPolicyRules(codec, TestDefaultInitialManagementToken, "dc1", `
service_prefix "mongo" { policy = "read" }
`)
require.NoError(t, err)

View File

@ -363,7 +363,7 @@ func (s *Server) initializeACLs(ctx context.Context) error {
// Purge the cache, since it could've changed while we were not the
// leader.
s.acls.cache.Purge()
s.ACLResolver.cache.Purge()
// Purge the auth method validators since they could've changed while we
// were not leader.

View File

@ -197,11 +197,18 @@ func (c *CAManager) secondarySetPrimaryRoots(newRoots structs.IndexedCARoots) {
c.primaryRoots = newRoots
}
func (c *CAManager) secondaryGetPrimaryRoots() structs.IndexedCARoots {
func (c *CAManager) secondaryGetActivePrimaryCARoot() (*structs.CARoot, error) {
// TODO: this could be a different lock, as long as its the same lock in secondarySetPrimaryRoots
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.primaryRoots
primaryRoots := c.primaryRoots
c.stateLock.Unlock()
for _, root := range primaryRoots.Roots {
if root.ID == primaryRoots.ActiveRootID && root.Active {
return root, nil
}
}
return nil, fmt.Errorf("primary datacenter does not have an active root CA for Connect")
}
// initializeCAConfig is used to initialize the CA config if necessary
@ -475,16 +482,12 @@ func (c *CAManager) primaryInitialize(provider ca.Provider, conf *structs.CAConf
if err := provider.Configure(pCfg); err != nil {
return fmt.Errorf("error configuring provider: %v", err)
}
if err := provider.GenerateRoot(); err != nil {
root, err := provider.GenerateRoot()
if err != nil {
return fmt.Errorf("error generating CA root certificate: %v", err)
}
// Get the active root cert from the CA
rootPEM, err := provider.ActiveRoot()
if err != nil {
return fmt.Errorf("error getting root cert: %v", err)
}
rootCA, err := parseCARoot(rootPEM, conf.Provider, conf.ClusterID)
rootCA, err := parseCARoot(root.PEM, conf.Provider, conf.ClusterID)
if err != nil {
return err
}
@ -602,79 +605,45 @@ func (c *CAManager) getLeafSigningCertFromRoot(root *structs.CARoot) string {
return root.IntermediateCerts[len(root.IntermediateCerts)-1]
}
// secondaryInitializeIntermediateCA runs the routine for generating an intermediate CA CSR and getting
// it signed by the primary DC if the root CA of the primary DC has changed since the last
// intermediate. It should only be called while the state lock is held by setting the state
// to non-ready.
// secondaryInitializeIntermediateCA generates a Certificate Signing Request (CSR)
// for the intermediate CA that is used to sign leaf certificates in the secondary.
// The CSR is signed by the primary DC and then persisted in the state store.
//
// This method should only be called while the state lock is held by setting the
// state to non-ready.
func (c *CAManager) secondaryInitializeIntermediateCA(provider ca.Provider, config *structs.CAConfiguration) error {
activeIntermediate, err := provider.ActiveIntermediate()
if err != nil {
return err
}
var (
storedRootID string
expectedSigningKeyID string
currentSigningKeyID string
activeSecondaryRoot *structs.CARoot
)
_, activeRoot, err := c.delegate.State().CARootActive(nil)
if err != nil {
return err
}
var currentSigningKeyID string
if activeRoot != nil {
currentSigningKeyID = activeRoot.SigningKeyID
}
var expectedSigningKeyID string
if activeIntermediate != "" {
// In the event that we already have an intermediate, we must have
// already replicated some primary root information locally, so check
// to see if we're up to date by fetching the rootID and the
// signingKeyID used in the secondary.
//
// Note that for the same rootID the primary representation of the root
// will have a different SigningKeyID field than the secondary
// representation of the same root. This is because it's derived from
// the intermediate which is different in all datacenters.
storedRoot, err := provider.ActiveRoot()
if err != nil {
return err
}
storedRootID, err = connect.CalculateCertFingerprint(storedRoot)
if err != nil {
return fmt.Errorf("error parsing root fingerprint: %v, %#v", err, storedRoot)
}
intermediateCert, err := connect.ParseCert(activeIntermediate)
if err != nil {
return fmt.Errorf("error parsing active intermediate cert: %v", err)
}
expectedSigningKeyID = connect.EncodeSigningKeyID(intermediateCert.SubjectKeyId)
// This will fetch the secondary's exact current representation of the
// active root. Note that this data should only be used if the IDs
// match, otherwise it's out of date and should be regenerated.
_, activeSecondaryRoot, err = c.delegate.State().CARootActive(nil)
if err != nil {
return err
}
if activeSecondaryRoot != nil {
currentSigningKeyID = activeSecondaryRoot.SigningKeyID
}
}
// Determine which of the provided PRIMARY representations of roots is the
// active one. We'll use this as a template to generate any new root
// representations meant for this secondary.
var newActiveRoot *structs.CARoot
primaryRoots := c.secondaryGetPrimaryRoots()
for _, root := range primaryRoots.Roots {
if root.ID == primaryRoots.ActiveRootID && root.Active {
newActiveRoot = root
break
}
}
if newActiveRoot == nil {
return fmt.Errorf("primary datacenter does not have an active root CA for Connect")
newActiveRoot, err := c.secondaryGetActivePrimaryCARoot()
if err != nil {
return err
}
// Get a signed intermediate from the primary DC if the provider
// hasn't been initialized yet or if the primary's root has changed.
needsNewIntermediate := false
if activeIntermediate == "" || storedRootID != primaryRoots.ActiveRootID {
needsNewIntermediate := activeIntermediate == ""
if activeRoot != nil && newActiveRoot.ID != activeRoot.ID {
needsNewIntermediate = true
}
@ -684,28 +653,19 @@ func (c *CAManager) secondaryInitializeIntermediateCA(provider ca.Provider, conf
needsNewIntermediate = true
}
newIntermediate := false
if needsNewIntermediate {
if err := c.secondaryRenewIntermediate(provider, newActiveRoot); err != nil {
if err := c.secondaryRequestNewSigningCert(provider, newActiveRoot); err != nil {
return err
}
newIntermediate = true
} else {
// Discard the primary's representation since our local one is
// sufficiently up to date.
newActiveRoot = activeSecondaryRoot
}
// Update the roots list in the state store if there's a new active root.
state := c.delegate.State()
_, activeRoot, err := state.CARootActive(nil)
if err != nil {
return err
newActiveRoot = activeRoot
}
// Determine whether a root update is needed, and persist the roots/config accordingly.
var newRoot *structs.CARoot
if activeRoot == nil || activeRoot.ID != newActiveRoot.ID || newIntermediate {
if activeRoot == nil || needsNewIntermediate {
newRoot = newActiveRoot
}
if err := c.persistNewRootAndConfig(provider, newRoot, config); err != nil {
@ -899,15 +859,12 @@ func (c *CAManager) UpdateConfiguration(args *structs.CARequest) (reterr error)
}
func (c *CAManager) primaryUpdateRootCA(newProvider ca.Provider, args *structs.CARequest, config *structs.CAConfiguration) error {
if err := newProvider.GenerateRoot(); err != nil {
providerRoot, err := newProvider.GenerateRoot()
if err != nil {
return fmt.Errorf("error generating CA root certificate: %v", err)
}
newRootPEM, err := newProvider.ActiveRoot()
if err != nil {
return err
}
newRootPEM := providerRoot.PEM
newActiveRoot, err := parseCARoot(newRootPEM, args.Config.Provider, args.Config.ClusterID)
if err != nil {
return err
@ -961,6 +918,7 @@ func (c *CAManager) primaryUpdateRootCA(newProvider ca.Provider, args *structs.C
// get a cross-signed certificate.
// 3. Take the active root for the new provider and append the intermediate from step 2
// to its list of intermediates.
// TODO: this cert is already parsed once in parseCARoot, could we remove the second parse?
newRoot, err := connect.ParseCert(newRootPEM)
if err != nil {
return err
@ -1070,9 +1028,11 @@ func (c *CAManager) primaryRenewIntermediate(provider ca.Provider, newActiveRoot
return nil
}
// secondaryRenewIntermediate should only be called while the state lock is held by
// setting the state to non-ready.
func (c *CAManager) secondaryRenewIntermediate(provider ca.Provider, newActiveRoot *structs.CARoot) error {
// secondaryRequestNewSigningCert creates a Certificate Signing Request, sends
// the request to the primary, and stores the received certificate in the
// provider.
// Should only be called while the state lock is held by setting the state to non-ready.
func (c *CAManager) secondaryRequestNewSigningCert(provider ca.Provider, newActiveRoot *structs.CARoot) error {
csr, err := provider.GenerateIntermediateCSR()
if err != nil {
return err
@ -1187,7 +1147,7 @@ func (c *CAManager) RenewIntermediate(ctx context.Context, isPrimary bool) error
// Enough time has passed, go ahead with getting a new intermediate.
renewalFunc := c.primaryRenewIntermediate
if !isPrimary {
renewalFunc = c.secondaryRenewIntermediate
renewalFunc = c.secondaryRequestNewSigningCert
}
errCh := make(chan error, 1)
go func() {

View File

@ -17,6 +17,7 @@ import (
"time"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
vaultapi "github.com/hashicorp/vault/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -227,9 +228,8 @@ type mockCAProvider struct {
func (m *mockCAProvider) Configure(cfg ca.ProviderConfig) error { return nil }
func (m *mockCAProvider) State() (map[string]string, error) { return nil, nil }
func (m *mockCAProvider) GenerateRoot() error { return nil }
func (m *mockCAProvider) ActiveRoot() (string, error) {
return m.rootPEM, nil
func (m *mockCAProvider) GenerateRoot() (ca.RootResult, error) {
return ca.RootResult{PEM: m.rootPEM}, nil
}
func (m *mockCAProvider) GenerateIntermediateCSR() (string, error) {
m.callbackCh <- "provider/GenerateIntermediateCSR"
@ -607,6 +607,88 @@ func TestCAManager_UpdateConfiguration_Vault_Primary(t *testing.T) {
require.Equal(t, connect.HexString(cert.SubjectKeyId), newRoot.SigningKeyID)
}
func TestCAManager_Initialize_Vault_WithIntermediateAsPrimaryCA(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
ca.SkipIfVaultNotPresent(t)
vault := ca.NewTestVaultServer(t)
vclient := vault.Client()
generateExternalRootCA(t, vclient)
meshRootPath := "pki-root"
primaryCert := setupPrimaryCA(t, vclient, meshRootPath)
_, s1 := testServerWithConfig(t, func(c *Config) {
c.CAConfig = &structs.CAConfiguration{
Provider: "vault",
Config: map[string]interface{}{
"Address": vault.Addr,
"Token": vault.RootToken,
"RootPKIPath": meshRootPath,
"IntermediatePKIPath": "pki-intermediate/",
// TODO: there are failures to init the CA system if these are not set
// to the values of the already initialized CA.
"PrivateKeyType": "ec",
"PrivateKeyBits": 256,
},
}
})
defer s1.Shutdown()
runStep(t, "check primary DC", func(t *testing.T) {
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
codec := rpcClient(t, s1)
roots := structs.IndexedCARoots{}
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", &structs.DCSpecificRequest{}, &roots)
require.NoError(t, err)
require.Len(t, roots.Roots, 1)
require.Equal(t, primaryCert, roots.Roots[0].RootCert)
leafCertPEM := getLeafCert(t, codec, roots.TrustDomain, "dc1")
verifyLeafCert(t, roots.Roots[0], leafCertPEM)
})
// TODO: renew primary leaf signing cert
// TODO: rotate root
runStep(t, "run secondary DC", func(t *testing.T) {
_, sDC2 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc1"
c.CAConfig = &structs.CAConfiguration{
Provider: "vault",
Config: map[string]interface{}{
"Address": vault.Addr,
"Token": vault.RootToken,
"RootPKIPath": meshRootPath,
"IntermediatePKIPath": "pki-secondary/",
// TODO: there are failures to init the CA system if these are not set
// to the values of the already initialized CA.
"PrivateKeyType": "ec",
"PrivateKeyBits": 256,
},
}
})
defer sDC2.Shutdown()
joinWAN(t, sDC2, s1)
testrpc.WaitForActiveCARoot(t, sDC2.RPC, "dc2", nil)
codec := rpcClient(t, sDC2)
roots := structs.IndexedCARoots{}
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", &structs.DCSpecificRequest{}, &roots)
require.NoError(t, err)
require.Len(t, roots.Roots, 1)
leafCertPEM := getLeafCert(t, codec, roots.TrustDomain, "dc2")
verifyLeafCert(t, roots.Roots[0], leafCertPEM)
// TODO: renew secondary leaf signing cert
})
}
func getLeafCert(t *testing.T, codec rpc.ClientCodec, trustDomain string, dc string) string {
pk, _, err := connect.GeneratePrivateKey()
require.NoError(t, err)
@ -625,3 +707,58 @@ func getLeafCert(t *testing.T, codec rpc.ClientCodec, trustDomain string, dc str
return cert.CertPEM
}
func generateExternalRootCA(t *testing.T, client *vaultapi.Client) string {
t.Helper()
err := client.Sys().Mount("corp", &vaultapi.MountInput{
Type: "pki",
Description: "External root, probably corporate CA",
Config: vaultapi.MountConfigInput{
MaxLeaseTTL: "2400h",
DefaultLeaseTTL: "1h",
},
})
require.NoError(t, err, "failed to mount")
resp, err := client.Logical().Write("corp/root/generate/internal", map[string]interface{}{
"common_name": "corporate CA",
"ttl": "2400h",
})
require.NoError(t, err, "failed to generate root")
return resp.Data["certificate"].(string)
}
func setupPrimaryCA(t *testing.T, client *vaultapi.Client, path string) string {
t.Helper()
err := client.Sys().Mount(path, &vaultapi.MountInput{
Type: "pki",
Description: "primary CA for Consul CA",
Config: vaultapi.MountConfigInput{
MaxLeaseTTL: "2200h",
DefaultLeaseTTL: "1h",
},
})
require.NoError(t, err, "failed to mount")
out, err := client.Logical().Write(path+"/intermediate/generate/internal", map[string]interface{}{
"common_name": "primary CA",
"ttl": "2200h",
"key_type": "ec",
"key_bits": 256,
})
require.NoError(t, err, "failed to generate root")
intermediate, err := client.Logical().Write("corp/root/sign-intermediate", map[string]interface{}{
"csr": out.Data["csr"],
"use_csr_values": true,
"format": "pem_bundle",
"ttl": "2200h",
})
require.NoError(t, err, "failed to sign intermediate")
_, err = client.Logical().Write(path+"/intermediate/set-signed", map[string]interface{}{
"certificate": intermediate.Data["certificate"],
})
require.NoError(t, err, "failed to set signed intermediate")
return ca.EnsureTrailingNewline(intermediate.Data["certificate"].(string))
}

View File

@ -196,7 +196,7 @@ func TestCAManager_Initialize_Secondary(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("%s-%d", tc.keyType, tc.keyBits), func(t *testing.T) {
masterToken := "8a85f086-dd95-4178-b128-e10902767c5c"
initialManagementToken := "8a85f086-dd95-4178-b128-e10902767c5c"
// Initialize primary as the primary DC
dir1, s1 := testServerWithConfig(t, func(c *Config) {
@ -204,7 +204,7 @@ func TestCAManager_Initialize_Secondary(t *testing.T) {
c.PrimaryDatacenter = "primary"
c.Build = "1.6.0"
c.ACLsEnabled = true
c.ACLInitialManagementToken = masterToken
c.ACLInitialManagementToken = initialManagementToken
c.ACLResolverSettings.ACLDefaultPolicy = "deny"
c.CAConfig.Config["PrivateKeyType"] = tc.keyType
c.CAConfig.Config["PrivateKeyBits"] = tc.keyBits
@ -213,7 +213,7 @@ func TestCAManager_Initialize_Secondary(t *testing.T) {
defer os.RemoveAll(dir1)
defer s1.Shutdown()
s1.tokens.UpdateAgentToken(masterToken, token.TokenSourceConfig)
s1.tokens.UpdateAgentToken(initialManagementToken, token.TokenSourceConfig)
testrpc.WaitForLeader(t, s1.RPC, "primary")
@ -232,8 +232,8 @@ func TestCAManager_Initialize_Secondary(t *testing.T) {
defer os.RemoveAll(dir2)
defer s2.Shutdown()
s2.tokens.UpdateAgentToken(masterToken, token.TokenSourceConfig)
s2.tokens.UpdateReplicationToken(masterToken, token.TokenSourceConfig)
s2.tokens.UpdateAgentToken(initialManagementToken, token.TokenSourceConfig)
s2.tokens.UpdateReplicationToken(initialManagementToken, token.TokenSourceConfig)
// Create the WAN link
joinWAN(t, s2, s1)
@ -327,7 +327,6 @@ func TestCAManager_RenewIntermediate_Vault_Primary(t *testing.T) {
// no parallel execution because we change globals
patchIntermediateCertRenewInterval(t)
require := require.New(t)
testVault := ca.NewTestVaultServer(t)
@ -354,15 +353,15 @@ func TestCAManager_RenewIntermediate_Vault_Primary(t *testing.T) {
store := s1.caManager.delegate.State()
_, activeRoot, err := store.CARootActive(nil)
require.NoError(err)
require.NoError(t, err)
t.Log("original SigningKeyID", activeRoot.SigningKeyID)
intermediatePEM := s1.caManager.getLeafSigningCertFromRoot(activeRoot)
intermediateCert, err := connect.ParseCert(intermediatePEM)
require.NoError(err)
require.NoError(t, err)
require.Equal(connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(intermediatePEM, s1.caManager.getLeafSigningCertFromRoot(activeRoot))
require.Equal(t, connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(t, intermediatePEM, s1.caManager.getLeafSigningCertFromRoot(activeRoot))
// Wait for dc1's intermediate to be refreshed.
retry.Run(t, func(r *retry.R) {
@ -382,12 +381,12 @@ func TestCAManager_RenewIntermediate_Vault_Primary(t *testing.T) {
codec := rpcClient(t, s1)
roots := structs.IndexedCARoots{}
err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", &structs.DCSpecificRequest{}, &roots)
require.NoError(err)
require.Len(roots.Roots, 1)
require.NoError(t, err)
require.Len(t, roots.Roots, 1)
activeRoot = roots.Active()
require.Equal(connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(intermediatePEM, s1.caManager.getLeafSigningCertFromRoot(activeRoot))
require.Equal(t, connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(t, intermediatePEM, s1.caManager.getLeafSigningCertFromRoot(activeRoot))
// Have the new intermediate sign a leaf cert and make sure the chain is correct.
spiffeService := &connect.SpiffeIDService{
@ -401,7 +400,7 @@ func TestCAManager_RenewIntermediate_Vault_Primary(t *testing.T) {
req := structs.CASignRequest{CSR: csr}
cert := structs.IssuedCert{}
err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", &req, &cert)
require.NoError(err)
require.NoError(t, err)
verifyLeafCert(t, activeRoot, cert.CertPEM)
}
@ -425,7 +424,6 @@ func TestCAManager_RenewIntermediate_Secondary(t *testing.T) {
// no parallel execution because we change globals
patchIntermediateCertRenewInterval(t)
require := require.New(t)
_, s1 := testServerWithConfig(t, func(c *Config) {
c.Build = "1.6.0"
@ -469,15 +467,15 @@ func TestCAManager_RenewIntermediate_Secondary(t *testing.T) {
store := s2.fsm.State()
_, activeRoot, err := store.CARootActive(nil)
require.NoError(err)
require.NoError(t, err)
t.Log("original SigningKeyID", activeRoot.SigningKeyID)
intermediatePEM := s2.caManager.getLeafSigningCertFromRoot(activeRoot)
intermediateCert, err := connect.ParseCert(intermediatePEM)
require.NoError(err)
require.NoError(t, err)
require.Equal(intermediatePEM, s2.caManager.getLeafSigningCertFromRoot(activeRoot))
require.Equal(connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(t, intermediatePEM, s2.caManager.getLeafSigningCertFromRoot(activeRoot))
require.Equal(t, connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
// Wait for dc2's intermediate to be refreshed.
retry.Run(t, func(r *retry.R) {
@ -497,13 +495,13 @@ func TestCAManager_RenewIntermediate_Secondary(t *testing.T) {
codec := rpcClient(t, s2)
roots := structs.IndexedCARoots{}
err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", &structs.DCSpecificRequest{}, &roots)
require.NoError(err)
require.Len(roots.Roots, 1)
require.NoError(t, err)
require.Len(t, roots.Roots, 1)
_, activeRoot, err = store.CARootActive(nil)
require.NoError(err)
require.Equal(connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(intermediatePEM, s2.caManager.getLeafSigningCertFromRoot(activeRoot))
require.NoError(t, err)
require.Equal(t, connect.HexString(intermediateCert.SubjectKeyId), activeRoot.SigningKeyID)
require.Equal(t, intermediatePEM, s2.caManager.getLeafSigningCertFromRoot(activeRoot))
// Have dc2 sign a leaf cert and make sure the chain is correct.
spiffeService := &connect.SpiffeIDService{
@ -517,7 +515,7 @@ func TestCAManager_RenewIntermediate_Secondary(t *testing.T) {
req := structs.CASignRequest{CSR: csr}
cert := structs.IssuedCert{}
err = msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", &req, &cert)
require.NoError(err)
require.NoError(t, err)
verifyLeafCert(t, activeRoot, cert.CertPEM)
}
@ -528,8 +526,6 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.Build = "1.6.0"
c.PrimaryDatacenter = "dc1"
@ -555,15 +551,15 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
// Get the original intermediate
secondaryProvider, _ := getCAProviderWithLock(s2)
oldIntermediatePEM, err := secondaryProvider.ActiveIntermediate()
require.NoError(err)
require.NotEmpty(oldIntermediatePEM)
require.NoError(t, err)
require.NotEmpty(t, oldIntermediatePEM)
// Capture the current root
var originalRoot *structs.CARoot
{
rootList, activeRoot, err := getTestRoots(s1, "dc1")
require.NoError(err)
require.Len(rootList.Roots, 1)
require.NoError(t, err)
require.Len(t, rootList.Roots, 1)
originalRoot = activeRoot
}
@ -574,7 +570,7 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
// Update the provider config to use a new private key, which should
// cause a rotation.
_, newKey, err := connect.GeneratePrivateKey()
require.NoError(err)
require.NoError(t, err)
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
@ -590,14 +586,14 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
}
var reply interface{}
require.NoError(s1.RPC("ConnectCA.ConfigurationSet", args, &reply))
require.NoError(t, s1.RPC("ConnectCA.ConfigurationSet", args, &reply))
}
var updatedRoot *structs.CARoot
{
rootList, activeRoot, err := getTestRoots(s1, "dc1")
require.NoError(err)
require.Len(rootList.Roots, 2)
require.NoError(t, err)
require.Len(t, rootList.Roots, 2)
updatedRoot = activeRoot
}
@ -613,17 +609,17 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
r.Fatal("not a new intermediate")
}
})
require.NoError(err)
require.NoError(t, err)
// Verify the root lists have been rotated in each DC's state store.
state1 := s1.fsm.State()
_, primaryRoot, err := state1.CARootActive(nil)
require.NoError(err)
require.NoError(t, err)
state2 := s2.fsm.State()
_, roots2, err := state2.CARoots(nil)
require.NoError(err)
require.Equal(2, len(roots2))
require.NoError(t, err)
require.Equal(t, 2, len(roots2))
newRoot := roots2[0]
oldRoot := roots2[1]
@ -631,10 +627,10 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
newRoot = roots2[1]
oldRoot = roots2[0]
}
require.False(oldRoot.Active)
require.True(newRoot.Active)
require.Equal(primaryRoot.ID, newRoot.ID)
require.Equal(primaryRoot.RootCert, newRoot.RootCert)
require.False(t, oldRoot.Active)
require.True(t, newRoot.Active)
require.Equal(t, primaryRoot.ID, newRoot.ID)
require.Equal(t, primaryRoot.RootCert, newRoot.RootCert)
// Get the new root from dc1 and validate a chain of:
// dc2 leaf -> dc2 intermediate -> dc1 root
@ -650,13 +646,13 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
raw, _ := connect.TestCSR(t, spiffeService)
leafCsr, err := connect.ParseCSR(raw)
require.NoError(err)
require.NoError(t, err)
leafPEM, err := secondaryProvider.Sign(leafCsr)
require.NoError(err)
require.NoError(t, err)
cert, err := connect.ParseCert(leafPEM)
require.NoError(err)
require.NoError(t, err)
// Check that the leaf signed by the new intermediate can be verified using the
// returned cert chain (signed intermediate + remote root).
@ -669,7 +665,7 @@ func TestConnectCA_ConfigurationSet_RootRotation_Secondary(t *testing.T) {
Intermediates: intermediatePool,
Roots: rootPool,
})
require.NoError(err)
require.NoError(t, err)
}
func TestCAManager_Initialize_Vault_FixesSigningKeyID_Primary(t *testing.T) {
@ -1113,7 +1109,6 @@ func TestLeader_CARootPruning(t *testing.T) {
caRootPruneInterval = origPruneInterval
})
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -1127,14 +1122,14 @@ func TestLeader_CARootPruning(t *testing.T) {
Datacenter: "dc1",
}
var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(t, rootList.Roots, 1)
oldRoot := rootList.Roots[0]
// Update the provider config to use a new private key, which should
// cause a rotation.
_, newKey, err := connect.GeneratePrivateKey()
require.NoError(err)
require.NoError(t, err)
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
@ -1151,22 +1146,22 @@ func TestLeader_CARootPruning(t *testing.T) {
}
var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Should have 2 roots now.
_, roots, err := s1.fsm.State().CARoots(nil)
require.NoError(err)
require.Len(roots, 2)
require.NoError(t, err)
require.Len(t, roots, 2)
time.Sleep(2 * time.Second)
// Now the old root should be pruned.
_, roots, err = s1.fsm.State().CARoots(nil)
require.NoError(err)
require.Len(roots, 1)
require.True(roots[0].Active)
require.NotEqual(roots[0].ID, oldRoot.ID)
require.NoError(t, err)
require.Len(t, roots, 1)
require.True(t, roots[0].Active)
require.NotEqual(t, roots[0].ID, oldRoot.ID)
}
func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) {
@ -1176,7 +1171,6 @@ func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -1201,13 +1195,13 @@ func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) {
Datacenter: "dc1",
}
var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(t, rootList.Roots, 1)
// Update the provider config to use a new private key, which should
// cause a rotation.
_, newKey, err := connect.GeneratePrivateKey()
require.NoError(err)
require.NoError(t, err)
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
@ -1222,12 +1216,12 @@ func TestConnectCA_ConfigurationSet_PersistsRoots(t *testing.T) {
}
var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Get the active root before leader change.
_, root := getCAProviderWithLock(s1)
require.Len(root.IntermediateCerts, 1)
require.Len(t, root.IntermediateCerts, 1)
// Force a leader change and make sure the root CA values are preserved.
s1.Leave()
@ -1310,17 +1304,16 @@ func TestParseCARoot(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
root, err := parseCARoot(tt.pem, "consul", "cluster")
if tt.wantErr {
require.Error(err)
require.Error(t, err)
return
}
require.NoError(err)
require.Equal(tt.wantSerial, root.SerialNumber)
require.Equal(strings.ToLower(tt.wantSigningKeyID), root.SigningKeyID)
require.Equal(tt.wantKeyType, root.PrivateKeyType)
require.Equal(tt.wantKeyBits, root.PrivateKeyBits)
require.NoError(t, err)
require.Equal(t, tt.wantSerial, root.SerialNumber)
require.Equal(t, strings.ToLower(tt.wantSigningKeyID), root.SigningKeyID)
require.Equal(t, tt.wantKeyType, root.PrivateKeyType)
require.Equal(t, tt.wantKeyBits, root.PrivateKeyBits)
})
}
}
@ -1491,7 +1484,6 @@ func TestCAManager_Initialize_BadCAConfigDoesNotPreventLeaderEstablishment(t *te
}
func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) {
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -1505,14 +1497,14 @@ func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) {
Datacenter: "dc1",
}
var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(t, rootList.Roots, 1)
oldRoot := rootList.Roots[0]
// Update the provider config to use a new private key, which should
// cause a rotation.
_, newKey, err := connect.GeneratePrivateKey()
require.NoError(err)
require.NoError(t, err)
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
@ -1530,18 +1522,18 @@ func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) {
}
var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Old root should no longer be active.
_, roots, err := s1.fsm.State().CARoots(nil)
require.NoError(err)
require.Len(roots, 2)
require.NoError(t, err)
require.Len(t, roots, 2)
for _, r := range roots {
if r.ID == oldRoot.ID {
require.False(r.Active)
require.False(t, r.Active)
} else {
require.True(r.Active)
require.True(t, r.Active)
}
}
}
@ -1549,7 +1541,6 @@ func TestConnectCA_ConfigurationSet_ForceWithoutCrossSigning(t *testing.T) {
func TestConnectCA_ConfigurationSet_Vault_ForceWithoutCrossSigning(t *testing.T) {
ca.SkipIfVaultNotPresent(t)
require := require.New(t)
testVault := ca.NewTestVaultServer(t)
defer testVault.Stop()
@ -1577,8 +1568,8 @@ func TestConnectCA_ConfigurationSet_Vault_ForceWithoutCrossSigning(t *testing.T)
Datacenter: "dc1",
}
var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(rootList.Roots, 1)
require.Nil(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
require.Len(t, rootList.Roots, 1)
oldRoot := rootList.Roots[0]
// Update the provider config to use a new PKI path, which should
@ -1600,18 +1591,18 @@ func TestConnectCA_ConfigurationSet_Vault_ForceWithoutCrossSigning(t *testing.T)
}
var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Old root should no longer be active.
_, roots, err := s1.fsm.State().CARoots(nil)
require.NoError(err)
require.Len(roots, 2)
require.NoError(t, err)
require.Len(t, roots, 2)
for _, r := range roots {
if r.ID == oldRoot.ID {
require.False(r.Active)
require.False(t, r.Active)
} else {
require.True(r.Active)
require.True(t, r.Active)
}
}
}

View File

@ -217,7 +217,6 @@ func TestLeader_ReplicateIntentions(t *testing.T) {
func TestLeader_batchLegacyIntentionUpdates(t *testing.T) {
t.Parallel()
assert := assert.New(t)
ixn1 := structs.TestIntention(t)
ixn1.ID = "ixn1"
ixn2 := structs.TestIntention(t)
@ -356,7 +355,7 @@ func TestLeader_batchLegacyIntentionUpdates(t *testing.T) {
for _, tc := range cases {
actual := batchLegacyIntentionUpdates(tc.deletes, tc.updates)
assert.Equal(tc.expected, actual)
assert.Equal(t, tc.expected, actual)
}
}

View File

@ -1162,15 +1162,15 @@ func TestLeader_ACL_Initialization(t *testing.T) {
t.Parallel()
tests := []struct {
name string
build string
master string
bootstrap bool
name string
build string
initialManagement string
bootstrap bool
}{
{"old version, no master", "0.8.0", "", true},
{"old version, master", "0.8.0", "root", false},
{"new version, no master", "0.9.1", "", true},
{"new version, master", "0.9.1", "root", false},
{"old version, no initial management", "0.8.0", "", true},
{"old version, initial management", "0.8.0", "root", false},
{"new version, no initial management", "0.9.1", "", true},
{"new version, initial management", "0.9.1", "root", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -1180,17 +1180,17 @@ func TestLeader_ACL_Initialization(t *testing.T) {
c.Datacenter = "dc1"
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLInitialManagementToken = tt.master
c.ACLInitialManagementToken = tt.initialManagement
}
dir1, s1 := testServerWithConfig(t, conf)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
if tt.master != "" {
_, master, err := s1.fsm.State().ACLTokenGetBySecret(nil, tt.master, nil)
if tt.initialManagement != "" {
_, initialManagement, err := s1.fsm.State().ACLTokenGetBySecret(nil, tt.initialManagement, nil)
require.NoError(t, err)
require.NotNil(t, master)
require.NotNil(t, initialManagement)
}
_, anon, err := s1.fsm.State().ACLTokenGetBySecret(nil, anonymousToken, nil)

View File

@ -3,22 +3,21 @@ package consul
import (
"testing"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/sdk/testutil"
)
func TestLoggerStore_Named(t *testing.T) {
t.Parallel()
require := require.New(t)
logger := testutil.Logger(t)
store := newLoggerStore(logger)
require.NotNil(store)
require.NotNil(t, store)
l1 := store.Named("test1")
l2 := store.Named("test2")
require.Truef(
l1 != l2,
require.Truef(t, l1 != l2,
"expected %p and %p to have a different memory address",
l1,
l2,
@ -27,16 +26,14 @@ func TestLoggerStore_Named(t *testing.T) {
func TestLoggerStore_NamedCache(t *testing.T) {
t.Parallel()
require := require.New(t)
logger := testutil.Logger(t)
store := newLoggerStore(logger)
require.NotNil(store)
require.NotNil(t, store)
l1 := store.Named("test")
l2 := store.Named("test")
require.Truef(
l1 == l2,
require.Truef(t, l1 == l2,
"expected %p and %p to have the same memory address",
l1,
l2,

View File

@ -17,11 +17,11 @@ func (op *Operator) AutopilotGetConfiguration(args *structs.DCSpecificRequest, r
}
// This action requires operator read access.
identity, authz, err := op.srv.acls.ResolveTokenToIdentityAndAuthorizer(args.Token)
authz, err := op.srv.ACLResolver.ResolveToken(args.Token)
if err != nil {
return err
}
if err := op.srv.validateEnterpriseToken(identity); err != nil {
if err := op.srv.validateEnterpriseToken(authz.Identity()); err != nil {
return err
}
if authz.OperatorRead(nil) != acl.Allow {
@ -49,11 +49,11 @@ func (op *Operator) AutopilotSetConfiguration(args *structs.AutopilotSetConfigRe
}
// This action requires operator write access.
identity, authz, err := op.srv.acls.ResolveTokenToIdentityAndAuthorizer(args.Token)
authz, err := op.srv.ACLResolver.ResolveToken(args.Token)
if err != nil {
return err
}
if err := op.srv.validateEnterpriseToken(identity); err != nil {
if err := op.srv.validateEnterpriseToken(authz.Identity()); err != nil {
return err
}
if authz.OperatorWrite(nil) != acl.Allow {
@ -84,11 +84,11 @@ func (op *Operator) ServerHealth(args *structs.DCSpecificRequest, reply *structs
}
// This action requires operator read access.
identity, authz, err := op.srv.acls.ResolveTokenToIdentityAndAuthorizer(args.Token)
authz, err := op.srv.ACLResolver.ResolveToken(args.Token)
if err != nil {
return err
}
if err := op.srv.validateEnterpriseToken(identity); err != nil {
if err := op.srv.validateEnterpriseToken(authz.Identity()); err != nil {
return err
}
if authz.OperatorRead(nil) != acl.Allow {
@ -151,11 +151,11 @@ func (op *Operator) AutopilotState(args *structs.DCSpecificRequest, reply *autop
}
// This action requires operator read access.
identity, authz, err := op.srv.acls.ResolveTokenToIdentityAndAuthorizer(args.Token)
authz, err := op.srv.ACLResolver.ResolveToken(args.Token)
if err != nil {
return err
}
if err := op.srv.validateEnterpriseToken(identity); err != nil {
if err := op.srv.validateEnterpriseToken(authz.Identity()); err != nil {
return err
}
if authz.OperatorRead(nil) != acl.Allow {

View File

@ -81,11 +81,11 @@ func (op *Operator) RaftRemovePeerByAddress(args *structs.RaftRemovePeerRequest,
// This is a super dangerous operation that requires operator write
// access.
identity, authz, err := op.srv.acls.ResolveTokenToIdentityAndAuthorizer(args.Token)
authz, err := op.srv.ACLResolver.ResolveToken(args.Token)
if err != nil {
return err
}
if err := op.srv.validateEnterpriseToken(identity); err != nil {
if err := op.srv.validateEnterpriseToken(authz.Identity()); err != nil {
return err
}
if authz.OperatorWrite(nil) != acl.Allow {
@ -134,11 +134,11 @@ func (op *Operator) RaftRemovePeerByID(args *structs.RaftRemovePeerRequest, repl
// This is a super dangerous operation that requires operator write
// access.
identity, authz, err := op.srv.acls.ResolveTokenToIdentityAndAuthorizer(args.Token)
authz, err := op.srv.ACLResolver.ResolveToken(args.Token)
if err != nil {
return err
}
if err := op.srv.validateEnterpriseToken(identity); err != nil {
if err := op.srv.validateEnterpriseToken(authz.Identity()); err != nil {
return err
}
if authz.OperatorWrite(nil) != acl.Allow {

View File

@ -222,7 +222,7 @@ func TestPreparedQuery_Apply_ACLDeny(t *testing.T) {
Datacenter: "dc1",
Op: structs.PreparedQueryCreate,
Query: &structs.PreparedQuery{
Name: "redis-master",
Name: "redis-primary",
Service: structs.ServiceQuery{
Service: "the-redis",
},
@ -503,7 +503,7 @@ func TestPreparedQuery_Apply_ForwardLeader(t *testing.T) {
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "redis",
Tags: []string{"master"},
Tags: []string{"primary"},
Port: 8000,
},
}
@ -853,7 +853,7 @@ func TestPreparedQuery_Get(t *testing.T) {
Datacenter: "dc1",
Op: structs.PreparedQueryCreate,
Query: &structs.PreparedQuery{
Name: "redis-master",
Name: "redis-primary",
Service: structs.ServiceQuery{
Service: "the-redis",
},
@ -1110,7 +1110,7 @@ func TestPreparedQuery_List(t *testing.T) {
Datacenter: "dc1",
Op: structs.PreparedQueryCreate,
Query: &structs.PreparedQuery{
Name: "redis-master",
Name: "redis-primary",
Token: "le-token",
Service: structs.ServiceQuery{
Service: "the-redis",
@ -2348,7 +2348,7 @@ func TestPreparedQuery_Execute_ForwardLeader(t *testing.T) {
Address: "127.0.0.1",
Service: &structs.NodeService{
Service: "redis",
Tags: []string{"master"},
Tags: []string{"primary"},
Port: 8000,
},
}
@ -2448,7 +2448,6 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
@ -2484,7 +2483,7 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
}
var reply struct{}
require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply))
require.NoError(t, msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply))
}
// The query, start with connect disabled
@ -2501,7 +2500,7 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
},
},
}
require.NoError(msgpackrpc.CallWithCodec(
require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
// In the future we'll run updates
@ -2515,15 +2514,15 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because it omits the proxy whose name
// doesn't match the query.
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
require.Len(t, reply.Nodes, 2)
require.Equal(t, query.Query.Service.Service, reply.Service)
require.Equal(t, query.Query.DNS, reply.DNS)
require.True(t, reply.QueryMeta.KnownLeader, "queried leader")
}
// Run with the Connect setting specified on the request
@ -2535,31 +2534,31 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because we should get the native AND
// the proxy (since the destination matches our service name).
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
require.Len(t, reply.Nodes, 2)
require.Equal(t, query.Query.Service.Service, reply.Service)
require.Equal(t, query.Query.DNS, reply.DNS)
require.True(t, reply.QueryMeta.KnownLeader, "queried leader")
// Make sure the native is the first one
if !reply.Nodes[0].Service.Connect.Native {
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
}
require.True(reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
require.True(t, reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(t, reply.Service, reply.Nodes[0].Service.Service)
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(reply.Service, reply.Nodes[1].Service.Proxy.DestinationServiceName)
require.Equal(t, structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(t, reply.Service, reply.Nodes[1].Service.Proxy.DestinationServiceName)
}
// Update the query
query.Query.Service.Connect = true
require.NoError(msgpackrpc.CallWithCodec(
require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
// Run the registered query.
@ -2570,31 +2569,31 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because we should get the native AND
// the proxy (since the destination matches our service name).
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
require.Len(t, reply.Nodes, 2)
require.Equal(t, query.Query.Service.Service, reply.Service)
require.Equal(t, query.Query.DNS, reply.DNS)
require.True(t, reply.QueryMeta.KnownLeader, "queried leader")
// Make sure the native is the first one
if !reply.Nodes[0].Service.Connect.Native {
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
}
require.True(reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
require.True(t, reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(t, reply.Service, reply.Nodes[0].Service.Service)
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(reply.Service, reply.Nodes[1].Service.Proxy.DestinationServiceName)
require.Equal(t, structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(t, reply.Service, reply.Nodes[1].Service.Proxy.DestinationServiceName)
}
// Unset the query
query.Query.Service.Connect = false
require.NoError(msgpackrpc.CallWithCodec(
require.NoError(t, msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
}

View File

@ -919,108 +919,74 @@ type queryFn func(memdb.WatchSet, *state.Store) error
// blockingQuery is used to process a potentially blocking query operation.
func (s *Server) blockingQuery(queryOpts structs.QueryOptionsCompat, queryMeta structs.QueryMetaCompat, fn queryFn) error {
var cancel func()
var ctx context.Context = &lib.StopChannelContext{StopCh: s.shutdownCh}
var queriesBlocking uint64
var queryTimeout time.Duration
// Instrument all queries run
metrics.IncrCounter([]string{"rpc", "query"}, 1)
minQueryIndex := queryOpts.GetMinQueryIndex()
// Fast path right to the non-blocking query.
// Perform a non-blocking query
if minQueryIndex == 0 {
goto RUN_QUERY
if queryOpts.GetRequireConsistent() {
if err := s.consistentRead(); err != nil {
return err
}
}
var ws memdb.WatchSet
err := fn(ws, s.fsm.State())
s.setQueryMeta(queryMeta, queryOpts.GetToken())
return err
}
queryTimeout = queryOpts.GetMaxQueryTime()
// Restrict the max query time, and ensure there is always one.
if queryTimeout > s.config.MaxQueryTime {
queryTimeout = s.config.MaxQueryTime
} else if queryTimeout <= 0 {
queryTimeout = s.config.DefaultQueryTime
}
// Apply a small amount of jitter to the request.
queryTimeout += lib.RandomStagger(queryTimeout / structs.JitterFraction)
// wrap the base context with a deadline
ctx, cancel = context.WithDeadline(ctx, time.Now().Add(queryTimeout))
timeout := s.rpcQueryTimeout(queryOpts.GetMaxQueryTime())
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// instrument blockingQueries
// atomic inc our server's count of in-flight blockingQueries and store the new value
queriesBlocking = atomic.AddUint64(&s.queriesBlocking, 1)
// atomic dec when we return from blockingQuery()
count := atomic.AddUint64(&s.queriesBlocking, 1)
metrics.SetGauge([]string{"rpc", "queries_blocking"}, float32(count))
// decrement the count when the function returns.
defer atomic.AddUint64(&s.queriesBlocking, ^uint64(0))
// set the gauge directly to the new value of s.blockingQueries
metrics.SetGauge([]string{"rpc", "queries_blocking"}, float32(queriesBlocking))
RUN_QUERY:
// Setup blocking loop
// Validate
// If the read must be consistent we verify that we are still the leader.
if queryOpts.GetRequireConsistent() {
if err := s.consistentRead(); err != nil {
return err
for {
if queryOpts.GetRequireConsistent() {
if err := s.consistentRead(); err != nil {
return err
}
}
}
// Run query
// Operate on a consistent set of state. This makes sure that the
// abandon channel goes with the state that the caller is using to
// build watches.
state := s.fsm.State()
// We can skip all watch tracking if this isn't a blocking query.
var ws memdb.WatchSet
if minQueryIndex > 0 {
ws = memdb.NewWatchSet()
// Operate on a consistent set of state. This makes sure that the
// abandon channel goes with the state that the caller is using to
// build watches.
state := s.fsm.State()
ws := memdb.NewWatchSet()
// This channel will be closed if a snapshot is restored and the
// whole state store is abandoned.
ws.Add(state.AbandonCh())
}
// Execute the queryFn
err := fn(ws, state)
err := fn(ws, state)
s.setQueryMeta(queryMeta, queryOpts.GetToken())
if err != nil {
return err
}
// Update the query metadata.
s.setQueryMeta(queryMeta, queryOpts.GetToken())
if queryMeta.GetIndex() > minQueryIndex {
return nil
}
// Note we check queryOpts.MinQueryIndex is greater than zero to determine if
// blocking was requested by client, NOT meta.Index since the state function
// might return zero if something is not initialized and care wasn't taken to
// handle that special case (in practice this happened a lot so fixing it
// systematically here beats trying to remember to add zero checks in every
// state method). We also need to ensure that unless there is an error, we
// return an index > 0 otherwise the client will never block and burn CPU and
// requests.
if err == nil && queryMeta.GetIndex() < 1 {
queryMeta.SetIndex(1)
}
// block up to the timeout if we don't see anything fresh.
if err == nil && minQueryIndex > 0 && queryMeta.GetIndex() <= minQueryIndex {
if err := ws.WatchCtx(ctx); err == nil {
// a non-nil error only occurs when the context is cancelled
// block until something changes, or the timeout
if err := ws.WatchCtx(ctx); err != nil {
// exit if we've reached the timeout, or other cancellation
return nil
}
// If a restore may have woken us up then bail out from
// the query immediately. This is slightly race-ey since
// this might have been interrupted for other reasons,
// but it's OK to kick it back to the caller in either
// case.
select {
case <-state.AbandonCh():
default:
// loop back and look for an update again
goto RUN_QUERY
}
// exit if the state store has been abandoned
select {
case <-state.AbandonCh():
return nil
default:
}
}
return err
}
// setQueryMeta is used to populate the QueryMeta data for an RPC call
@ -1035,6 +1001,17 @@ func (s *Server) setQueryMeta(m structs.QueryMetaCompat, token string) {
m.SetKnownLeader(s.raft.Leader() != "")
}
maskResultsFilteredByACLs(token, m)
// Always set a non-zero QueryMeta.Index. Generally we expect the
// QueryMeta.Index to be set to structs.RaftIndex.ModifyIndex. If the query
// returned no results we expect it to be set to the max index of the table,
// however we can't guarantee this always happens.
// To prevent a client from accidentally performing many non-blocking queries
// (which causes lots of unnecessary load), we always set a default value of 1.
// This is sufficient to prevent the unnecessary load in most cases.
if m.GetIndex() < 1 {
m.SetIndex(1)
}
}
// consistentRead is used to ensure we do not perform a stale
@ -1070,6 +1047,22 @@ func (s *Server) consistentRead() error {
return structs.ErrNotReadyForConsistentReads
}
// rpcQueryTimeout calculates the timeout for the query, ensures it is
// constrained to the configured limit, and adds jitter to prevent multiple
// blocking queries from all timing out at the same time.
func (s *Server) rpcQueryTimeout(queryTimeout time.Duration) time.Duration {
// Restrict the max query time, and ensure there is always one.
if queryTimeout > s.config.MaxQueryTime {
queryTimeout = s.config.MaxQueryTime
} else if queryTimeout <= 0 {
queryTimeout = s.config.DefaultQueryTime
}
// Apply a small amount of jitter to the request.
queryTimeout += lib.RandomStagger(queryTimeout / structs.JitterFraction)
return queryTimeout
}
// maskResultsFilteredByACLs blanks out the ResultsFilteredByACLs flag if the
// request is unauthenticated, to limit information leaking.
//

View File

@ -233,13 +233,10 @@ func TestRPC_blockingQuery(t *testing.T) {
defer os.RemoveAll(dir)
defer s.Shutdown()
require := require.New(t)
assert := assert.New(t)
// Perform a non-blocking query. Note that it's significant that the meta has
// a zero index in response - the implied opts.MinQueryIndex is also zero but
// this should not block still.
{
t.Run("non-blocking query", func(t *testing.T) {
var opts structs.QueryOptions
var meta structs.QueryMeta
var calls int
@ -247,16 +244,13 @@ func TestRPC_blockingQuery(t *testing.T) {
calls++
return nil
}
if err := s.blockingQuery(&opts, &meta, fn); err != nil {
t.Fatalf("err: %v", err)
}
if calls != 1 {
t.Fatalf("bad: %d", calls)
}
}
err := s.blockingQuery(&opts, &meta, fn)
require.NoError(t, err)
require.Equal(t, 1, calls)
})
// Perform a blocking query that gets woken up and loops around once.
{
t.Run("blocking query - single loop", func(t *testing.T) {
opts := structs.QueryOptions{
MinQueryIndex: 3,
}
@ -275,13 +269,10 @@ func TestRPC_blockingQuery(t *testing.T) {
calls++
return nil
}
if err := s.blockingQuery(&opts, &meta, fn); err != nil {
t.Fatalf("err: %v", err)
}
if calls != 2 {
t.Fatalf("bad: %d", calls)
}
}
err := s.blockingQuery(&opts, &meta, fn)
require.NoError(t, err)
require.Equal(t, 2, calls)
})
// Perform a blocking query that returns a zero index from blocking func (e.g.
// no state yet). This should still return an empty response immediately, but
@ -292,7 +283,7 @@ func TestRPC_blockingQuery(t *testing.T) {
// covered by tests but eventually when hit in the wild causes blocking
// clients to busy loop and burn CPU. This test ensure that blockingQuery
// systematically does the right thing to prevent future bugs like that.
{
t.Run("blocking query with 0 modifyIndex from state func", func(t *testing.T) {
opts := structs.QueryOptions{
MinQueryIndex: 0,
}
@ -311,9 +302,9 @@ func TestRPC_blockingQuery(t *testing.T) {
calls++
return nil
}
require.NoError(s.blockingQuery(&opts, &meta, fn))
assert.Equal(1, calls)
assert.Equal(uint64(1), meta.Index,
require.NoError(t, s.blockingQuery(&opts, &meta, fn))
assert.Equal(t, 1, calls)
assert.Equal(t, uint64(1), meta.Index,
"expect fake index of 1 to force client to block on next update")
// Simulate client making next request
@ -322,19 +313,19 @@ func TestRPC_blockingQuery(t *testing.T) {
// This time we should block even though the func returns index 0 still
t0 := time.Now()
require.NoError(s.blockingQuery(&opts, &meta, fn))
require.NoError(t, s.blockingQuery(&opts, &meta, fn))
t1 := time.Now()
assert.Equal(2, calls)
assert.Equal(uint64(1), meta.Index,
assert.Equal(t, 2, calls)
assert.Equal(t, uint64(1), meta.Index,
"expect fake index of 1 to force client to block on next update")
assert.True(t1.Sub(t0) > 20*time.Millisecond,
assert.True(t, t1.Sub(t0) > 20*time.Millisecond,
"should have actually blocked waiting for timeout")
}
})
// Perform a query that blocks and gets interrupted when the state store
// is abandoned.
{
t.Run("blocking query interrupted by abandonCh", func(t *testing.T) {
opts := structs.QueryOptions{
MinQueryIndex: 3,
}
@ -363,13 +354,10 @@ func TestRPC_blockingQuery(t *testing.T) {
calls++
return nil
}
if err := s.blockingQuery(&opts, &meta, fn); err != nil {
t.Fatalf("err: %v", err)
}
if calls != 1 {
t.Fatalf("bad: %d", calls)
}
}
err := s.blockingQuery(&opts, &meta, fn)
require.NoError(t, err)
require.Equal(t, 1, calls)
})
t.Run("ResultsFilteredByACLs is reset for unauthenticated calls", func(t *testing.T) {
opts := structs.QueryOptions{
@ -382,13 +370,13 @@ func TestRPC_blockingQuery(t *testing.T) {
}
err := s.blockingQuery(&opts, &meta, fn)
require.NoError(err)
require.False(meta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be reset for unauthenticated calls")
require.NoError(t, err)
require.False(t, meta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be reset for unauthenticated calls")
})
t.Run("ResultsFilteredByACLs is honored for authenticated calls", func(t *testing.T) {
token, err := lib.GenerateUUID(nil)
require.NoError(err)
require.NoError(t, err)
opts := structs.QueryOptions{
Token: token,
@ -400,8 +388,8 @@ func TestRPC_blockingQuery(t *testing.T) {
}
err = s.blockingQuery(&opts, &meta, fn)
require.NoError(err)
require.True(meta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be honored for authenticated calls")
require.NoError(t, err)
require.True(t, meta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be honored for authenticated calls")
})
}

View File

@ -141,7 +141,7 @@ type Server struct {
aclConfig *acl.Config
// acls is used to resolve tokens to effective policies
acls *ACLResolver
*ACLResolver
aclAuthMethodValidators authmethod.Cache
@ -450,14 +450,14 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
s.aclConfig = newACLConfig(partitionInfo, logger)
aclConfig := ACLResolverConfig{
Config: config.ACLResolverSettings,
Delegate: s,
Backend: &serverACLResolverBackend{Server: s},
CacheConfig: serverACLCacheConfig,
Logger: logger,
ACLConfig: s.aclConfig,
Tokens: flat.Tokens,
}
// Initialize the ACL resolver.
if s.acls, err = NewACLResolver(&aclConfig); err != nil {
if s.ACLResolver, err = NewACLResolver(&aclConfig); err != nil {
s.Shutdown()
return nil, fmt.Errorf("Failed to create ACL resolver: %v", err)
}
@ -994,8 +994,8 @@ func (s *Server) Shutdown() error {
s.connPool.Shutdown()
}
if s.acls != nil {
s.acls.Close()
if s.ACLResolver != nil {
s.ACLResolver.Close()
}
if s.fsm != nil {

View File

@ -121,7 +121,7 @@ func (s *Server) setupSerfConfig(opts setupSerfOptions) (*serf.Config, error) {
// TODO(ACL-Legacy-Compat): remove in phase 2. These are kept for now to
// allow for upgrades.
if s.acls.ACLsEnabled() {
if s.ACLResolver.ACLsEnabled() {
conf.Tags[metadata.TagACLs] = string(structs.ACLModeEnabled)
} else {
conf.Tags[metadata.TagACLs] = string(structs.ACLModeDisabled)

View File

@ -35,7 +35,7 @@ import (
)
const (
TestDefaultMasterToken = "d9f05e83-a7ae-47ce-839e-c0d53a68c00a"
TestDefaultInitialManagementToken = "d9f05e83-a7ae-47ce-839e-c0d53a68c00a"
)
// testTLSCertificates Generates a TLS CA and server key/cert and returns them
@ -70,7 +70,7 @@ func testTLSCertificates(serverName string) (cert string, key string, cacert str
func testServerACLConfig(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLInitialManagementToken = TestDefaultMasterToken
c.ACLInitialManagementToken = TestDefaultInitialManagementToken
c.ACLResolverSettings.ACLDefaultPolicy = "deny"
}
@ -245,7 +245,7 @@ func testACLServerWithConfig(t *testing.T, cb func(*Config), initReplicationToke
if initReplicationToken {
// setup some tokens here so we get less warnings in the logs
srv.tokens.UpdateReplicationToken(TestDefaultMasterToken, token.TokenSourceConfig)
srv.tokens.UpdateReplicationToken(TestDefaultInitialManagementToken, token.TokenSourceConfig)
}
codec := rpcClient(t, srv)

View File

@ -420,7 +420,6 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) {
require.NoError(t, err)
t.Run("Get", func(t *testing.T) {
require := require.New(t)
req := &structs.SessionSpecificRequest{
Datacenter: "dc1",
@ -432,30 +431,29 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) {
var sessions structs.IndexedSessions
err := msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions)
require.NoError(err)
require.Empty(sessions.Sessions)
require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, sessions.Sessions)
require.True(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// ACL-restricted results included.
req.Token = allowedToken
err = msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions)
require.NoError(err)
require.Len(sessions.Sessions, 1)
require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
require.NoError(t, err)
require.Len(t, sessions.Sessions, 1)
require.False(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
// Try to get a session that doesn't exist to make sure that's handled
// correctly by the filter (it will get passed a nil slice).
req.SessionID = "adf4238a-882b-9ddc-4a9d-5b6758e4159e"
err = msgpackrpc.CallWithCodec(codec, "Session.Get", req, &sessions)
require.NoError(err)
require.Empty(sessions.Sessions)
require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
require.NoError(t, err)
require.Empty(t, sessions.Sessions)
require.False(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
})
t.Run("List", func(t *testing.T) {
require := require.New(t)
req := &structs.DCSpecificRequest{
Datacenter: "dc1",
@ -466,21 +464,20 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) {
var sessions structs.IndexedSessions
err := msgpackrpc.CallWithCodec(codec, "Session.List", req, &sessions)
require.NoError(err)
require.Empty(sessions.Sessions)
require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, sessions.Sessions)
require.True(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// ACL-restricted results included.
req.Token = allowedToken
err = msgpackrpc.CallWithCodec(codec, "Session.List", req, &sessions)
require.NoError(err)
require.Len(sessions.Sessions, 1)
require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
require.NoError(t, err)
require.Len(t, sessions.Sessions, 1)
require.False(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
})
t.Run("NodeSessions", func(t *testing.T) {
require := require.New(t)
req := &structs.NodeSpecificRequest{
Datacenter: "dc1",
@ -492,17 +489,17 @@ func TestSession_Get_List_NodeSessions_ACLFilter(t *testing.T) {
var sessions structs.IndexedSessions
err := msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", req, &sessions)
require.NoError(err)
require.Empty(sessions.Sessions)
require.True(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
require.NoError(t, err)
require.Empty(t, sessions.Sessions)
require.True(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be true")
// ACL-restricted results included.
req.Token = allowedToken
err = msgpackrpc.CallWithCodec(codec, "Session.NodeSessions", req, &sessions)
require.NoError(err)
require.Len(sessions.Sessions, 1)
require.False(sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
require.NoError(t, err)
require.Len(t, sessions.Sessions, 1)
require.False(t, sessions.QueryMeta.ResultsFilteredByACLs, "ResultsFilteredByACLs should be false")
})
}

View File

@ -32,31 +32,26 @@ func (e EventPayloadCheckServiceNode) HasReadPermission(authz acl.Authorizer) bo
return e.Value.CanRead(authz) == acl.Allow
}
func (e EventPayloadCheckServiceNode) MatchesKey(key, namespace, partition string) bool {
if key == "" && namespace == "" && partition == "" {
return true
}
if e.Value.Service == nil {
return false
}
name := e.Value.Service.Service
if e.overrideKey != "" {
name = e.overrideKey
}
ns := e.Value.Service.EnterpriseMeta.NamespaceOrDefault()
if e.overrideNamespace != "" {
ns = e.overrideNamespace
}
ap := e.Value.Service.EnterpriseMeta.PartitionOrDefault()
func (e EventPayloadCheckServiceNode) Subject() stream.Subject {
partition := e.Value.Service.PartitionOrDefault()
if e.overridePartition != "" {
ap = e.overridePartition
partition = e.overridePartition
}
partition = strings.ToLower(partition)
return (key == "" || strings.EqualFold(key, name)) &&
(namespace == "" || strings.EqualFold(namespace, ns)) &&
(partition == "" || strings.EqualFold(partition, ap))
namespace := e.Value.Service.NamespaceOrDefault()
if e.overrideNamespace != "" {
namespace = e.overrideNamespace
}
namespace = strings.ToLower(namespace)
key := e.Value.Service.Service
if e.overrideKey != "" {
key = e.overrideKey
}
key = strings.ToLower(key)
return stream.Subject(partition + "/" + namespace + "/" + key)
}
// serviceHealthSnapshot returns a stream.SnapshotFunc that provides a snapshot
@ -67,8 +62,7 @@ func serviceHealthSnapshot(db ReadDB, topic stream.Topic) stream.SnapshotFunc {
defer tx.Abort()
connect := topic == topicServiceHealthConnect
entMeta := structs.NewEnterpriseMetaWithPartition(req.Partition, req.Namespace)
idx, nodes, err := checkServiceNodesTxn(tx, nil, req.Key, connect, &entMeta)
idx, nodes, err := checkServiceNodesTxn(tx, nil, req.Key, connect, &req.EnterpriseMeta)
if err != nil {
return 0, err
}

View File

@ -11,11 +11,106 @@ import (
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/pbcommon"
"github.com/hashicorp/consul/proto/pbsubscribe"
"github.com/hashicorp/consul/types"
)
func TestEventPayloadCheckServiceNode_SubjectMatchesRequests(t *testing.T) {
// Matches.
for desc, tc := range map[string]struct {
evt EventPayloadCheckServiceNode
req stream.SubscribeRequest
}{
"default partition and namespace": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "foo",
},
},
},
stream.SubscribeRequest{
Key: "foo",
EnterpriseMeta: structs.EnterpriseMeta{},
},
},
"mixed casing": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "FoO",
},
},
},
stream.SubscribeRequest{Key: "foo"},
},
"override key": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "foo",
},
},
overrideKey: "bar",
},
stream.SubscribeRequest{Key: "bar"},
},
} {
t.Run(desc, func(t *testing.T) {
require.Equal(t, tc.req.Subject(), tc.evt.Subject())
})
}
// Non-matches.
for desc, tc := range map[string]struct {
evt EventPayloadCheckServiceNode
req stream.SubscribeRequest
}{
"different key": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "foo",
},
},
},
stream.SubscribeRequest{
Key: "bar",
},
},
"different partition": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "foo",
},
},
overridePartition: "bar",
},
stream.SubscribeRequest{
Key: "foo",
},
},
"different namespace": {
EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{
Service: &structs.NodeService{
Service: "foo",
},
},
overrideNamespace: "bar",
},
stream.SubscribeRequest{
Key: "foo",
},
},
} {
t.Run(desc, func(t *testing.T) {
require.NotEqual(t, tc.req.Subject(), tc.evt.Subject())
})
}
}
func TestServiceHealthSnapshot(t *testing.T) {
store := NewStateStore(nil)
@ -1771,7 +1866,7 @@ func assertDeepEqual(t *testing.T, x, y interface{}, opts ...cmp.Option) {
// all events for a particular topic are grouped together. The sort is
// stable so events with the same key retain their relative order.
//
// This sort should match the logic in EventPayloadCheckServiceNode.MatchesKey
// This sort should match the logic in EventPayloadCheckServiceNode.Subject
// to avoid masking bugs.
var cmpPartialOrderEvents = cmp.Options{
cmpopts.SortSlices(func(i, j stream.Event) bool {
@ -2418,107 +2513,6 @@ func newTestEventServiceHealthDeregister(index uint64, nodeNum int, svc string)
}
}
func TestEventPayloadCheckServiceNode_FilterByKey(t *testing.T) {
type testCase struct {
name string
payload EventPayloadCheckServiceNode
key string
namespace string
partition string // TODO(partitions): create test cases for this being set
expected bool
}
fn := func(t *testing.T, tc testCase) {
if tc.namespace != "" && pbcommon.DefaultEnterpriseMeta.Namespace == "" {
t.Skip("cant test namespace matching without namespace support")
}
require.Equal(t, tc.expected, tc.payload.MatchesKey(tc.key, tc.namespace, tc.partition))
}
var testCases = []testCase{
{
name: "no key or namespace",
payload: newPayloadCheckServiceNode("srv1", "ns1"),
expected: true,
},
{
name: "no key, with namespace match",
payload: newPayloadCheckServiceNode("srv1", "ns1"),
namespace: "ns1",
expected: true,
},
{
name: "no namespace, with key match",
payload: newPayloadCheckServiceNode("srv1", "ns1"),
key: "srv1",
expected: true,
},
{
name: "key match, namespace mismatch",
payload: newPayloadCheckServiceNode("srv1", "ns1"),
key: "srv1",
namespace: "ns2",
expected: false,
},
{
name: "key mismatch, namespace match",
payload: newPayloadCheckServiceNode("srv1", "ns1"),
key: "srv2",
namespace: "ns1",
expected: false,
},
{
name: "override key match",
payload: newPayloadCheckServiceNodeWithOverride("proxy", "ns1", "srv1", ""),
key: "srv1",
namespace: "ns1",
expected: true,
},
{
name: "override key mismatch",
payload: newPayloadCheckServiceNodeWithOverride("proxy", "ns1", "srv2", ""),
key: "proxy",
namespace: "ns1",
expected: false,
},
{
name: "override namespace match",
payload: newPayloadCheckServiceNodeWithOverride("proxy", "ns1", "", "ns2"),
key: "proxy",
namespace: "ns2",
expected: true,
},
{
name: "override namespace mismatch",
payload: newPayloadCheckServiceNodeWithOverride("proxy", "ns1", "", "ns3"),
key: "proxy",
namespace: "ns1",
expected: false,
},
{
name: "override both key and namespace match",
payload: newPayloadCheckServiceNodeWithOverride("proxy", "ns1", "srv1", "ns2"),
key: "srv1",
namespace: "ns2",
expected: true,
},
{
name: "override both key and namespace mismatch namespace",
payload: newPayloadCheckServiceNodeWithOverride("proxy", "ns1", "srv2", "ns3"),
key: "proxy",
namespace: "ns1",
expected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fn(t, tc)
})
}
}
func newPayloadCheckServiceNode(service, namespace string) EventPayloadCheckServiceNode {
return EventPayloadCheckServiceNode{
Value: &structs.CheckServiceNode{

View File

@ -1515,7 +1515,6 @@ func TestStateStore_EnsureService(t *testing.T) {
}
func TestStateStore_EnsureService_connectProxy(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create the service registration.
@ -1535,21 +1534,20 @@ func TestStateStore_EnsureService_connectProxy(t *testing.T) {
// Service successfully registers into the state store.
testRegisterNode(t, s, 0, "node1")
assert.Nil(s.EnsureService(10, "node1", ns1))
assert.Nil(t, s.EnsureService(10, "node1", ns1))
// Retrieve and verify
_, out, err := s.NodeServices(nil, "node1", nil)
assert.Nil(err)
assert.NotNil(out)
assert.Len(out.Services, 1)
assert.Nil(t, err)
assert.NotNil(t, out)
assert.Len(t, out.Services, 1)
expect1 := *ns1
expect1.CreateIndex, expect1.ModifyIndex = 10, 10
assert.Equal(&expect1, out.Services["connect-proxy"])
assert.Equal(t, &expect1, out.Services["connect-proxy"])
}
func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
setVirtualIPFlags(t, s)
@ -1575,17 +1573,17 @@ func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
// Make sure there's a virtual IP for the foo service.
vip, err := s.VirtualIPForService(structs.ServiceName{Name: "foo"})
require.NoError(t, err)
assert.Equal("240.0.0.1", vip)
assert.Equal(t, "240.0.0.1", vip)
// Retrieve and verify
_, out, err := s.NodeServices(nil, "node1", nil)
require.NoError(t, err)
assert.NotNil(out)
assert.Len(out.Services, 1)
assert.NotNil(t, out)
assert.Len(t, out.Services, 1)
taggedAddress := out.Services["foo"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address)
assert.Equal(ns1.Port, taggedAddress.Port)
assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(t, ns1.Port, taggedAddress.Port)
// Create the service registration.
ns2 := &structs.NodeService{
@ -1606,23 +1604,23 @@ func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
// Make sure the virtual IP has been incremented for the redis service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"})
require.NoError(t, err)
assert.Equal("240.0.0.2", vip)
assert.Equal(t, "240.0.0.2", vip)
// Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil)
assert.Nil(err)
assert.NotNil(out)
assert.Len(out.Services, 2)
assert.Nil(t, err)
assert.NotNil(t, out)
assert.Len(t, out.Services, 2)
taggedAddress = out.Services["redis-proxy"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address)
assert.Equal(ns2.Port, taggedAddress.Port)
assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(t, ns2.Port, taggedAddress.Port)
// Delete the first service and make sure it no longer has a virtual IP assigned.
require.NoError(t, s.DeleteService(12, "node1", "foo", entMeta))
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "connect-proxy"})
require.NoError(t, err)
assert.Equal("", vip)
assert.Equal(t, "", vip)
// Register another instance of redis-proxy and make sure the virtual IP is unchanged.
ns3 := &structs.NodeService{
@ -1643,14 +1641,14 @@ func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
// Make sure the virtual IP is unchanged for the redis service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"})
require.NoError(t, err)
assert.Equal("240.0.0.2", vip)
assert.Equal(t, "240.0.0.2", vip)
// Make sure the new instance has the same virtual IP.
_, out, err = s.NodeServices(nil, "node1", nil)
require.NoError(t, err)
taggedAddress = out.Services["redis-proxy2"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address)
assert.Equal(ns3.Port, taggedAddress.Port)
assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(t, ns3.Port, taggedAddress.Port)
// Register another service to take its virtual IP.
ns4 := &structs.NodeService{
@ -1671,18 +1669,17 @@ func TestStateStore_EnsureService_VirtualIPAssign(t *testing.T) {
// Make sure the virtual IP has allocated from the previously freed service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "web"})
require.NoError(t, err)
assert.Equal("240.0.0.1", vip)
assert.Equal(t, "240.0.0.1", vip)
// Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil)
require.NoError(t, err)
taggedAddress = out.Services["web-proxy"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address)
assert.Equal(ns4.Port, taggedAddress.Port)
assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(t, ns4.Port, taggedAddress.Port)
}
func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
setVirtualIPFlags(t, s)
@ -1708,16 +1705,16 @@ func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
// Make sure there's a virtual IP for the foo service.
vip, err := s.VirtualIPForService(structs.ServiceName{Name: "foo"})
require.NoError(t, err)
assert.Equal("240.0.0.1", vip)
assert.Equal(t, "240.0.0.1", vip)
// Retrieve and verify
_, out, err := s.NodeServices(nil, "node1", nil)
require.NoError(t, err)
assert.NotNil(out)
assert.NotNil(t, out)
taggedAddress := out.Services["foo"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address)
assert.Equal(ns1.Port, taggedAddress.Port)
assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(t, ns1.Port, taggedAddress.Port)
// Create the service registration.
ns2 := &structs.NodeService{
@ -1738,22 +1735,22 @@ func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
// Make sure the virtual IP has been incremented for the redis service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"})
require.NoError(t, err)
assert.Equal("240.0.0.2", vip)
assert.Equal(t, "240.0.0.2", vip)
// Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil)
assert.Nil(err)
assert.NotNil(out)
assert.Nil(t, err)
assert.NotNil(t, out)
taggedAddress = out.Services["redis"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address)
assert.Equal(ns2.Port, taggedAddress.Port)
assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(t, ns2.Port, taggedAddress.Port)
// Delete the last service and make sure it no longer has a virtual IP assigned.
require.NoError(t, s.DeleteService(12, "node1", "redis", entMeta))
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "redis"})
require.NoError(t, err)
assert.Equal("", vip)
assert.Equal(t, "", vip)
// Register a new service, should end up with the freed 240.0.0.2 address.
ns3 := &structs.NodeService{
@ -1773,16 +1770,16 @@ func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "backend"})
require.NoError(t, err)
assert.Equal("240.0.0.2", vip)
assert.Equal(t, "240.0.0.2", vip)
// Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil)
assert.Nil(err)
assert.NotNil(out)
assert.Nil(t, err)
assert.NotNil(t, out)
taggedAddress = out.Services["backend"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address)
assert.Equal(ns3.Port, taggedAddress.Port)
assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(t, ns3.Port, taggedAddress.Port)
// Create a new service, no more freed VIPs so it should go back to using the counter.
ns4 := &structs.NodeService{
@ -1803,16 +1800,16 @@ func TestStateStore_EnsureService_ReassignFreedVIPs(t *testing.T) {
// Make sure the virtual IP has been incremented for the frontend service.
vip, err = s.VirtualIPForService(structs.ServiceName{Name: "frontend"})
require.NoError(t, err)
assert.Equal("240.0.0.3", vip)
assert.Equal(t, "240.0.0.3", vip)
// Retrieve and verify
_, out, err = s.NodeServices(nil, "node1", nil)
assert.Nil(err)
assert.NotNil(out)
assert.Nil(t, err)
assert.NotNil(t, out)
taggedAddress = out.Services["frontend"].TaggedAddresses[structs.TaggedAddressVirtualIP]
assert.Equal(vip, taggedAddress.Address)
assert.Equal(ns4.Port, taggedAddress.Port)
assert.Equal(t, vip, taggedAddress.Address)
assert.Equal(t, ns4.Port, taggedAddress.Port)
}
func TestStateStore_Services(t *testing.T) {
@ -2360,83 +2357,80 @@ func TestStateStore_DeleteService(t *testing.T) {
}
func TestStateStore_ConnectServiceNodes(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(0))
assert.Len(nodes, 0)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(0))
assert.Len(t, nodes, 0)
// Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "native-db", Service: "db", Connect: structs.ServiceConnect{Native: true}}))
assert.Nil(s.EnsureService(17, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.True(watchFired(ws))
assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(t, s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(t, s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(t, s.EnsureService(16, "bar", &structs.NodeService{ID: "native-db", Service: "db", Connect: structs.ServiceConnect{Native: true}}))
assert.Nil(t, s.EnsureService(17, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.True(t, watchFired(ws))
// Read everything back.
ws = memdb.NewWatchSet()
idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(17))
assert.Len(nodes, 3)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(17))
assert.Len(t, nodes, 3)
for _, n := range nodes {
assert.True(
n.ServiceKind == structs.ServiceKindConnectProxy ||
n.ServiceConnect.Native,
assert.True(t, n.ServiceKind == structs.ServiceKindConnectProxy ||
n.ServiceConnect.Native,
"either proxy or connect native")
}
// Registering some unrelated node should not fire the watch.
testRegisterNode(t, s, 17, "nope")
assert.False(watchFired(ws))
assert.False(t, watchFired(ws))
// But removing a node with the "db" service should fire the watch.
assert.Nil(s.DeleteNode(18, "bar", nil))
assert.True(watchFired(ws))
assert.Nil(t, s.DeleteNode(18, "bar", nil))
assert.True(t, watchFired(ws))
}
func TestStateStore_ConnectServiceNodes_Gateways(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(0))
assert.Len(nodes, 0)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(0))
assert.Len(t, nodes, 0)
// Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
// Typical services
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.False(watchFired(ws))
assert.Nil(t, s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.False(t, watchFired(ws))
// Register a sidecar for db
assert.Nil(s.EnsureService(15, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.True(watchFired(ws))
assert.Nil(t, s.EnsureService(15, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.True(t, watchFired(ws))
// Reset WatchSet to ensure watch fires when associating db with gateway
ws = memdb.NewWatchSet()
_, _, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Nil(t, err)
// Associate gateway with db
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443}))
assert.Nil(s.EnsureConfigEntry(17, &structs.TerminatingGatewayConfigEntry{
assert.Nil(t, s.EnsureService(16, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443}))
assert.Nil(t, s.EnsureConfigEntry(17, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway",
Name: "gateway",
Services: []structs.LinkedService{
@ -2445,71 +2439,71 @@ func TestStateStore_ConnectServiceNodes_Gateways(t *testing.T) {
},
},
}))
assert.True(watchFired(ws))
assert.True(t, watchFired(ws))
// Read everything back.
ws = memdb.NewWatchSet()
idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(17))
assert.Len(nodes, 2)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(17))
assert.Len(t, nodes, 2)
// Check sidecar
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal("foo", nodes[0].Node)
assert.Equal("proxy", nodes[0].ServiceName)
assert.Equal("proxy", nodes[0].ServiceID)
assert.Equal("db", nodes[0].ServiceProxy.DestinationServiceName)
assert.Equal(8000, nodes[0].ServicePort)
assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(t, "foo", nodes[0].Node)
assert.Equal(t, "proxy", nodes[0].ServiceName)
assert.Equal(t, "proxy", nodes[0].ServiceID)
assert.Equal(t, "db", nodes[0].ServiceProxy.DestinationServiceName)
assert.Equal(t, 8000, nodes[0].ServicePort)
// Check gateway
assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind)
assert.Equal("bar", nodes[1].Node)
assert.Equal("gateway", nodes[1].ServiceName)
assert.Equal("gateway", nodes[1].ServiceID)
assert.Equal(443, nodes[1].ServicePort)
assert.Equal(t, structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind)
assert.Equal(t, "bar", nodes[1].Node)
assert.Equal(t, "gateway", nodes[1].ServiceName)
assert.Equal(t, "gateway", nodes[1].ServiceID)
assert.Equal(t, 443, nodes[1].ServicePort)
// Watch should fire when another gateway instance is registered
assert.Nil(s.EnsureService(18, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443}))
assert.True(watchFired(ws))
assert.Nil(t, s.EnsureService(18, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443}))
assert.True(t, watchFired(ws))
// Reset WatchSet to ensure watch fires when deregistering gateway
ws = memdb.NewWatchSet()
_, _, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Nil(t, err)
// Watch should fire when a gateway instance is deregistered
assert.Nil(s.DeleteService(19, "bar", "gateway", nil))
assert.True(watchFired(ws))
assert.Nil(t, s.DeleteService(19, "bar", "gateway", nil))
assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(19))
assert.Len(nodes, 2)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(19))
assert.Len(t, nodes, 2)
// Check the new gateway
assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind)
assert.Equal("foo", nodes[1].Node)
assert.Equal("gateway", nodes[1].ServiceName)
assert.Equal("gateway-2", nodes[1].ServiceID)
assert.Equal(443, nodes[1].ServicePort)
assert.Equal(t, structs.ServiceKindTerminatingGateway, nodes[1].ServiceKind)
assert.Equal(t, "foo", nodes[1].Node)
assert.Equal(t, "gateway", nodes[1].ServiceName)
assert.Equal(t, "gateway-2", nodes[1].ServiceID)
assert.Equal(t, 443, nodes[1].ServicePort)
// Index should not slide back after deleting all instances of the gateway
assert.Nil(s.DeleteService(20, "foo", "gateway-2", nil))
assert.True(watchFired(ws))
assert.Nil(t, s.DeleteService(20, "foo", "gateway-2", nil))
assert.True(t, watchFired(ws))
idx, nodes, err = s.ConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(20))
assert.Len(nodes, 1)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(20))
assert.Len(t, nodes, 1)
// Ensure that remaining node is the proxy and not a gateway
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal("foo", nodes[0].Node)
assert.Equal("proxy", nodes[0].ServiceName)
assert.Equal("proxy", nodes[0].ServiceID)
assert.Equal(8000, nodes[0].ServicePort)
assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(t, "foo", nodes[0].Node)
assert.Equal(t, "proxy", nodes[0].ServiceName)
assert.Equal(t, "proxy", nodes[0].ServiceID)
assert.Equal(t, 8000, nodes[0].ServicePort)
}
func TestStateStore_Service_Snapshot(t *testing.T) {
@ -3680,14 +3674,12 @@ func TestStateStore_ConnectQueryBlocking(t *testing.T) {
tt.setupFn(s)
}
require := require.New(t)
// Run the query
ws := memdb.NewWatchSet()
_, res, err := s.CheckConnectServiceNodes(ws, tt.svc, nil)
require.NoError(err)
require.Len(res, tt.wantBeforeResLen)
require.Len(ws, tt.wantBeforeWatchSetSize)
require.NoError(t, err)
require.Len(t, res, tt.wantBeforeResLen)
require.Len(t, ws, tt.wantBeforeWatchSetSize)
// Mutate the state store
if tt.updateFn != nil {
@ -3696,18 +3688,18 @@ func TestStateStore_ConnectQueryBlocking(t *testing.T) {
fired := watchFired(ws)
if tt.shouldFire {
require.True(fired, "WatchSet should have fired")
require.True(t, fired, "WatchSet should have fired")
} else {
require.False(fired, "WatchSet should not have fired")
require.False(t, fired, "WatchSet should not have fired")
}
// Re-query the same result. Should return the desired index and len
ws = memdb.NewWatchSet()
idx, res, err := s.CheckConnectServiceNodes(ws, tt.svc, nil)
require.NoError(err)
require.Len(res, tt.wantAfterResLen)
require.Equal(tt.wantAfterIndex, idx)
require.Len(ws, tt.wantAfterWatchSetSize)
require.NoError(t, err)
require.Len(t, res, tt.wantAfterResLen)
require.Equal(t, tt.wantAfterIndex, idx)
require.Len(t, ws, tt.wantAfterWatchSetSize)
})
}
}
@ -3829,25 +3821,24 @@ func TestStateStore_CheckServiceNodes(t *testing.T) {
}
func TestStateStore_CheckConnectServiceNodes(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(0))
assert.Len(nodes, 0)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(0))
assert.Len(t, nodes, 0)
// Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.True(watchFired(ws))
assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(t, s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(t, s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.Nil(t, s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.True(t, watchFired(ws))
// Register node checks
testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing)
@ -3860,13 +3851,13 @@ func TestStateStore_CheckConnectServiceNodes(t *testing.T) {
// Read everything back.
ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(20))
assert.Len(nodes, 2)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(20))
assert.Len(t, nodes, 2)
for _, n := range nodes {
assert.Equal(structs.ServiceKindConnectProxy, n.Service.Kind)
assert.Equal("db", n.Service.Proxy.DestinationServiceName)
assert.Equal(t, structs.ServiceKindConnectProxy, n.Service.Kind)
assert.Equal(t, "db", n.Service.Proxy.DestinationServiceName)
}
}
@ -3875,34 +3866,33 @@ func TestStateStore_CheckConnectServiceNodes_Gateways(t *testing.T) {
t.Skip("too slow for testing.Short")
}
assert := assert.New(t)
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(0))
assert.Len(nodes, 0)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(0))
assert.Len(t, nodes, 0)
// Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(t, s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(t, s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
// Typical services
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.False(watchFired(ws))
assert.Nil(t, s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(t, s.EnsureService(14, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"replica"}, Address: "", Port: 8001}))
assert.False(t, watchFired(ws))
// Register node and service checks
testRegisterCheck(t, s, 15, "foo", "", "check1", api.HealthPassing)
testRegisterCheck(t, s, 16, "bar", "", "check2", api.HealthPassing)
testRegisterCheck(t, s, 17, "foo", "db", "check3", api.HealthPassing)
assert.False(watchFired(ws))
assert.False(t, watchFired(ws))
// Watch should fire when a gateway is associated with the service, even if the gateway doesn't exist yet
assert.Nil(s.EnsureConfigEntry(18, &structs.TerminatingGatewayConfigEntry{
assert.Nil(t, s.EnsureConfigEntry(18, &structs.TerminatingGatewayConfigEntry{
Kind: "terminating-gateway",
Name: "gateway",
Services: []structs.LinkedService{
@ -3911,90 +3901,90 @@ func TestStateStore_CheckConnectServiceNodes_Gateways(t *testing.T) {
},
},
}))
assert.True(watchFired(ws))
assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(18))
assert.Len(nodes, 0)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(18))
assert.Len(t, nodes, 0)
// Watch should fire when a gateway is added
assert.Nil(s.EnsureService(19, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443}))
assert.True(watchFired(ws))
assert.Nil(t, s.EnsureService(19, "bar", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443}))
assert.True(t, watchFired(ws))
// Watch should fire when a check is added to the gateway
testRegisterCheck(t, s, 20, "bar", "gateway", "check4", api.HealthPassing)
assert.True(watchFired(ws))
assert.True(t, watchFired(ws))
// Watch should fire when a different connect service is registered for db
assert.Nil(s.EnsureService(21, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.True(watchFired(ws))
assert.Nil(t, s.EnsureService(21, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", Proxy: structs.ConnectProxyConfig{DestinationServiceName: "db"}, Port: 8000}))
assert.True(t, watchFired(ws))
// Read everything back.
ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(21))
assert.Len(nodes, 2)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(21))
assert.Len(t, nodes, 2)
// Check sidecar
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].Service.Kind)
assert.Equal("foo", nodes[0].Node.Node)
assert.Equal("proxy", nodes[0].Service.Service)
assert.Equal("proxy", nodes[0].Service.ID)
assert.Equal("db", nodes[0].Service.Proxy.DestinationServiceName)
assert.Equal(8000, nodes[0].Service.Port)
assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].Service.Kind)
assert.Equal(t, "foo", nodes[0].Node.Node)
assert.Equal(t, "proxy", nodes[0].Service.Service)
assert.Equal(t, "proxy", nodes[0].Service.ID)
assert.Equal(t, "db", nodes[0].Service.Proxy.DestinationServiceName)
assert.Equal(t, 8000, nodes[0].Service.Port)
// Check gateway
assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind)
assert.Equal("bar", nodes[1].Node.Node)
assert.Equal("gateway", nodes[1].Service.Service)
assert.Equal("gateway", nodes[1].Service.ID)
assert.Equal(443, nodes[1].Service.Port)
assert.Equal(t, structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind)
assert.Equal(t, "bar", nodes[1].Node.Node)
assert.Equal(t, "gateway", nodes[1].Service.Service)
assert.Equal(t, "gateway", nodes[1].Service.ID)
assert.Equal(t, 443, nodes[1].Service.Port)
// Watch should fire when another gateway instance is registered
assert.Nil(s.EnsureService(22, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443}))
assert.True(watchFired(ws))
assert.Nil(t, s.EnsureService(22, "foo", &structs.NodeService{Kind: structs.ServiceKindTerminatingGateway, ID: "gateway-2", Service: "gateway", Port: 443}))
assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(22))
assert.Len(nodes, 3)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(22))
assert.Len(t, nodes, 3)
// Watch should fire when a gateway instance is deregistered
assert.Nil(s.DeleteService(23, "bar", "gateway", nil))
assert.True(watchFired(ws))
assert.Nil(t, s.DeleteService(23, "bar", "gateway", nil))
assert.True(t, watchFired(ws))
ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(23))
assert.Len(nodes, 2)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(23))
assert.Len(t, nodes, 2)
// Check new gateway
assert.Equal(structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind)
assert.Equal("foo", nodes[1].Node.Node)
assert.Equal("gateway", nodes[1].Service.Service)
assert.Equal("gateway-2", nodes[1].Service.ID)
assert.Equal(443, nodes[1].Service.Port)
assert.Equal(t, structs.ServiceKindTerminatingGateway, nodes[1].Service.Kind)
assert.Equal(t, "foo", nodes[1].Node.Node)
assert.Equal(t, "gateway", nodes[1].Service.Service)
assert.Equal(t, "gateway-2", nodes[1].Service.ID)
assert.Equal(t, 443, nodes[1].Service.Port)
// Index should not slide back after deleting all instances of the gateway
assert.Nil(s.DeleteService(24, "foo", "gateway-2", nil))
assert.True(watchFired(ws))
assert.Nil(t, s.DeleteService(24, "foo", "gateway-2", nil))
assert.True(t, watchFired(ws))
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db", nil)
assert.Nil(err)
assert.Equal(idx, uint64(24))
assert.Len(nodes, 1)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(24))
assert.Len(t, nodes, 1)
// Ensure that remaining node is the proxy and not a gateway
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].Service.Kind)
assert.Equal("foo", nodes[0].Node.Node)
assert.Equal("proxy", nodes[0].Service.Service)
assert.Equal("proxy", nodes[0].Service.ID)
assert.Equal(8000, nodes[0].Service.Port)
assert.Equal(t, structs.ServiceKindConnectProxy, nodes[0].Service.Kind)
assert.Equal(t, "foo", nodes[0].Node.Node)
assert.Equal(t, "proxy", nodes[0].Service.Service)
assert.Equal(t, "proxy", nodes[0].Service.ID)
assert.Equal(t, 8000, nodes[0].Service.Port)
}
func BenchmarkCheckServiceNodes(b *testing.B) {
@ -5255,14 +5245,13 @@ func TestStateStore_GatewayServices_ServiceDeletion(t *testing.T) {
func TestStateStore_CheckIngressServiceNodes(t *testing.T) {
s := testStateStore(t)
ws := setupIngressState(t, s)
require := require.New(t)
t.Run("check service1 ingress gateway", func(t *testing.T) {
idx, results, err := s.CheckIngressServiceNodes(ws, "service1", nil)
require.NoError(err)
require.Equal(uint64(15), idx)
require.NoError(t, err)
require.Equal(t, uint64(15), idx)
// Multiple instances of the ingress2 service
require.Len(results, 4)
require.Len(t, results, 4)
ids := make(map[string]struct{})
for _, n := range results {
@ -5273,14 +5262,14 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) {
"ingress2": {},
"wildcardIngress": {},
}
require.Equal(expectedIds, ids)
require.Equal(t, expectedIds, ids)
})
t.Run("check service2 ingress gateway", func(t *testing.T) {
idx, results, err := s.CheckIngressServiceNodes(ws, "service2", nil)
require.NoError(err)
require.Equal(uint64(15), idx)
require.Len(results, 2)
require.NoError(t, err)
require.Equal(t, uint64(15), idx)
require.Len(t, results, 2)
ids := make(map[string]struct{})
for _, n := range results {
@ -5290,38 +5279,38 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) {
"ingress1": {},
"wildcardIngress": {},
}
require.Equal(expectedIds, ids)
require.Equal(t, expectedIds, ids)
})
t.Run("check service3 ingress gateway", func(t *testing.T) {
ws := memdb.NewWatchSet()
idx, results, err := s.CheckIngressServiceNodes(ws, "service3", nil)
require.NoError(err)
require.Equal(uint64(15), idx)
require.Len(results, 1)
require.Equal("wildcardIngress", results[0].Service.ID)
require.NoError(t, err)
require.Equal(t, uint64(15), idx)
require.Len(t, results, 1)
require.Equal(t, "wildcardIngress", results[0].Service.ID)
})
t.Run("delete a wildcard entry", func(t *testing.T) {
require.Nil(s.DeleteConfigEntry(19, "ingress-gateway", "wildcardIngress", nil))
require.True(watchFired(ws))
require.Nil(t, s.DeleteConfigEntry(19, "ingress-gateway", "wildcardIngress", nil))
require.True(t, watchFired(ws))
idx, results, err := s.CheckIngressServiceNodes(ws, "service1", nil)
require.NoError(err)
require.Equal(uint64(15), idx)
require.Len(results, 3)
require.NoError(t, err)
require.Equal(t, uint64(15), idx)
require.Len(t, results, 3)
idx, results, err = s.CheckIngressServiceNodes(ws, "service2", nil)
require.NoError(err)
require.Equal(uint64(15), idx)
require.Len(results, 1)
require.NoError(t, err)
require.Equal(t, uint64(15), idx)
require.Len(t, results, 1)
idx, results, err = s.CheckIngressServiceNodes(ws, "service3", nil)
require.NoError(err)
require.Equal(uint64(15), idx)
require.NoError(t, err)
require.Equal(t, uint64(15), idx)
// TODO(ingress): index goes backward when deleting last config entry
// require.Equal(uint64(11), idx)
require.Len(results, 0)
// require.Equal(t,uint64(11), idx)
require.Len(t, results, 0)
})
}
@ -5629,56 +5618,55 @@ func TestStateStore_GatewayServices_WildcardAssociation(t *testing.T) {
s := testStateStore(t)
setupIngressState(t, s)
require := require.New(t)
ws := memdb.NewWatchSet()
t.Run("base case for wildcard", func(t *testing.T) {
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err)
require.Equal(uint64(16), idx)
require.Len(results, 3)
require.NoError(t, err)
require.Equal(t, uint64(16), idx)
require.Len(t, results, 3)
})
t.Run("do not associate ingress services with gateway", func(t *testing.T) {
testRegisterIngressService(t, s, 17, "node1", "testIngress")
require.False(watchFired(ws))
require.False(t, watchFired(ws))
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err)
require.Equal(uint64(16), idx)
require.Len(results, 3)
require.NoError(t, err)
require.Equal(t, uint64(16), idx)
require.Len(t, results, 3)
})
t.Run("do not associate terminating-gateway services with gateway", func(t *testing.T) {
require.Nil(s.EnsureService(18, "node1",
require.Nil(t, s.EnsureService(18, "node1",
&structs.NodeService{
Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443,
},
))
require.False(watchFired(ws))
require.False(t, watchFired(ws))
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err)
require.Equal(uint64(16), idx)
require.Len(results, 3)
require.NoError(t, err)
require.Equal(t, uint64(16), idx)
require.Len(t, results, 3)
})
t.Run("do not associate connect-proxy services with gateway", func(t *testing.T) {
testRegisterSidecarProxy(t, s, 19, "node1", "web")
require.False(watchFired(ws))
require.False(t, watchFired(ws))
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err)
require.Equal(uint64(16), idx)
require.Len(results, 3)
require.NoError(t, err)
require.Equal(t, uint64(16), idx)
require.Len(t, results, 3)
})
t.Run("do not associate consul services with gateway", func(t *testing.T) {
require.Nil(s.EnsureService(20, "node1",
require.Nil(t, s.EnsureService(20, "node1",
&structs.NodeService{ID: "consul", Service: "consul", Tags: nil},
))
require.False(watchFired(ws))
require.False(t, watchFired(ws))
idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil)
require.NoError(err)
require.Equal(uint64(16), idx)
require.Len(results, 3)
require.NoError(t, err)
require.Equal(t, uint64(16), idx)
require.Len(t, results, 3)
})
}
@ -5709,15 +5697,13 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
})
t.Run("no services from default tcp protocol", func(t *testing.T) {
require := require.New(t)
idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err)
require.Equal(uint64(4), idx)
require.Len(results, 0)
require.NoError(t, err)
require.Equal(t, uint64(4), idx)
require.Len(t, results, 0)
})
t.Run("service-defaults", func(t *testing.T) {
require := require.New(t)
expected := structs.GatewayServices{
{
Gateway: structs.NewServiceName("ingress1", nil),
@ -5740,13 +5726,12 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
}
assert.NoError(t, s.EnsureConfigEntry(5, svcDefaults))
idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err)
require.Equal(uint64(5), idx)
require.ElementsMatch(results, expected)
require.NoError(t, err)
require.Equal(t, uint64(5), idx)
require.ElementsMatch(t, results, expected)
})
t.Run("proxy-defaults", func(t *testing.T) {
require := require.New(t)
expected := structs.GatewayServices{
{
Gateway: structs.NewServiceName("ingress1", nil),
@ -5784,13 +5769,12 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
assert.NoError(t, s.EnsureConfigEntry(6, proxyDefaults))
idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err)
require.Equal(uint64(6), idx)
require.ElementsMatch(results, expected)
require.NoError(t, err)
require.Equal(t, uint64(6), idx)
require.ElementsMatch(t, results, expected)
})
t.Run("service-defaults overrides proxy-defaults", func(t *testing.T) {
require := require.New(t)
expected := structs.GatewayServices{
{
Gateway: structs.NewServiceName("ingress1", nil),
@ -5814,13 +5798,12 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
assert.NoError(t, s.EnsureConfigEntry(7, svcDefaults))
idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err)
require.Equal(uint64(7), idx)
require.ElementsMatch(results, expected)
require.NoError(t, err)
require.Equal(t, uint64(7), idx)
require.ElementsMatch(t, results, expected)
})
t.Run("change listener protocol and expect different filter", func(t *testing.T) {
require := require.New(t)
expected := structs.GatewayServices{
{
Gateway: structs.NewServiceName("ingress1", nil),
@ -5854,9 +5837,9 @@ func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) {
assert.NoError(t, s.EnsureConfigEntry(8, ingress1))
idx, results, err := s.GatewayServices(nil, "ingress1", nil)
require.NoError(err)
require.Equal(uint64(8), idx)
require.ElementsMatch(results, expected)
require.NoError(t, err)
require.Equal(t, uint64(8), idx)
require.ElementsMatch(t, results, expected)
})
}

View File

@ -12,7 +12,6 @@ import (
)
func TestStore_ConfigEntry(t *testing.T) {
require := require.New(t)
s := testConfigStateStore(t)
expected := &structs.ProxyConfigEntry{
@ -24,12 +23,12 @@ func TestStore_ConfigEntry(t *testing.T) {
}
// Create
require.NoError(s.EnsureConfigEntry(0, expected))
require.NoError(t, s.EnsureConfigEntry(0, expected))
idx, config, err := s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err)
require.Equal(uint64(0), idx)
require.Equal(expected, config)
require.NoError(t, err)
require.Equal(t, uint64(0), idx)
require.Equal(t, expected, config)
// Update
updated := &structs.ProxyConfigEntry{
@ -39,44 +38,43 @@ func TestStore_ConfigEntry(t *testing.T) {
"DestinationServiceName": "bar",
},
}
require.NoError(s.EnsureConfigEntry(1, updated))
require.NoError(t, s.EnsureConfigEntry(1, updated))
idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err)
require.Equal(uint64(1), idx)
require.Equal(updated, config)
require.NoError(t, err)
require.Equal(t, uint64(1), idx)
require.Equal(t, updated, config)
// Delete
require.NoError(s.DeleteConfigEntry(2, structs.ProxyDefaults, "global", nil))
require.NoError(t, s.DeleteConfigEntry(2, structs.ProxyDefaults, "global", nil))
idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err)
require.Equal(uint64(2), idx)
require.Nil(config)
require.NoError(t, err)
require.Equal(t, uint64(2), idx)
require.Nil(t, config)
// Set up a watch.
serviceConf := &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "foo",
}
require.NoError(s.EnsureConfigEntry(3, serviceConf))
require.NoError(t, s.EnsureConfigEntry(3, serviceConf))
ws := memdb.NewWatchSet()
_, _, err = s.ConfigEntry(ws, structs.ServiceDefaults, "foo", nil)
require.NoError(err)
require.NoError(t, err)
// Make an unrelated modification and make sure the watch doesn't fire.
require.NoError(s.EnsureConfigEntry(4, updated))
require.False(watchFired(ws))
require.NoError(t, s.EnsureConfigEntry(4, updated))
require.False(t, watchFired(ws))
// Update the watched config and make sure it fires.
serviceConf.Protocol = "http"
require.NoError(s.EnsureConfigEntry(5, serviceConf))
require.True(watchFired(ws))
require.NoError(t, s.EnsureConfigEntry(5, serviceConf))
require.True(t, watchFired(ws))
}
func TestStore_ConfigEntryCAS(t *testing.T) {
require := require.New(t)
s := testConfigStateStore(t)
expected := &structs.ProxyConfigEntry{
@ -88,12 +86,12 @@ func TestStore_ConfigEntryCAS(t *testing.T) {
}
// Create
require.NoError(s.EnsureConfigEntry(1, expected))
require.NoError(t, s.EnsureConfigEntry(1, expected))
idx, config, err := s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err)
require.Equal(uint64(1), idx)
require.Equal(expected, config)
require.NoError(t, err)
require.Equal(t, uint64(1), idx)
require.Equal(t, expected, config)
// Update with invalid index
updated := &structs.ProxyConfigEntry{
@ -104,29 +102,28 @@ func TestStore_ConfigEntryCAS(t *testing.T) {
},
}
ok, err := s.EnsureConfigEntryCAS(2, 99, updated)
require.False(ok)
require.NoError(err)
require.False(t, ok)
require.NoError(t, err)
// Entry should not be changed
idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err)
require.Equal(uint64(1), idx)
require.Equal(expected, config)
require.NoError(t, err)
require.Equal(t, uint64(1), idx)
require.Equal(t, expected, config)
// Update with a valid index
ok, err = s.EnsureConfigEntryCAS(2, 1, updated)
require.True(ok)
require.NoError(err)
require.True(t, ok)
require.NoError(t, err)
// Entry should be updated
idx, config, err = s.ConfigEntry(nil, structs.ProxyDefaults, "global", nil)
require.NoError(err)
require.Equal(uint64(2), idx)
require.Equal(updated, config)
require.NoError(t, err)
require.Equal(t, uint64(2), idx)
require.Equal(t, updated, config)
}
func TestStore_ConfigEntry_DeleteCAS(t *testing.T) {
require := require.New(t)
s := testConfigStateStore(t)
entry := &structs.ProxyConfigEntry{
@ -139,31 +136,31 @@ func TestStore_ConfigEntry_DeleteCAS(t *testing.T) {
// Attempt to delete the entry before it exists.
ok, err := s.DeleteConfigEntryCAS(1, 0, entry)
require.NoError(err)
require.False(ok)
require.NoError(t, err)
require.False(t, ok)
// Create the entry.
require.NoError(s.EnsureConfigEntry(1, entry))
require.NoError(t, s.EnsureConfigEntry(1, entry))
// Attempt to delete with an invalid index.
ok, err = s.DeleteConfigEntryCAS(2, 99, entry)
require.NoError(err)
require.False(ok)
require.NoError(t, err)
require.False(t, ok)
// Entry should not be deleted.
_, config, err := s.ConfigEntry(nil, entry.Kind, entry.Name, nil)
require.NoError(err)
require.NotNil(config)
require.NoError(t, err)
require.NotNil(t, config)
// Attempt to delete with a valid index.
ok, err = s.DeleteConfigEntryCAS(2, 1, entry)
require.NoError(err)
require.True(ok)
require.NoError(t, err)
require.True(t, ok)
// Entry should be deleted.
_, config, err = s.ConfigEntry(nil, entry.Kind, entry.Name, nil)
require.NoError(err)
require.Nil(config)
require.NoError(t, err)
require.Nil(t, config)
}
func TestStore_ConfigEntry_UpdateOver(t *testing.T) {
@ -263,7 +260,6 @@ func TestStore_ConfigEntry_UpdateOver(t *testing.T) {
}
func TestStore_ConfigEntries(t *testing.T) {
require := require.New(t)
s := testConfigStateStore(t)
// Create some config entries.
@ -280,39 +276,39 @@ func TestStore_ConfigEntries(t *testing.T) {
Name: "test3",
}
require.NoError(s.EnsureConfigEntry(0, entry1))
require.NoError(s.EnsureConfigEntry(1, entry2))
require.NoError(s.EnsureConfigEntry(2, entry3))
require.NoError(t, s.EnsureConfigEntry(0, entry1))
require.NoError(t, s.EnsureConfigEntry(1, entry2))
require.NoError(t, s.EnsureConfigEntry(2, entry3))
// Get all entries
idx, entries, err := s.ConfigEntries(nil, nil)
require.NoError(err)
require.Equal(uint64(2), idx)
require.Equal([]structs.ConfigEntry{entry1, entry2, entry3}, entries)
require.NoError(t, err)
require.Equal(t, uint64(2), idx)
require.Equal(t, []structs.ConfigEntry{entry1, entry2, entry3}, entries)
// Get all proxy entries
idx, entries, err = s.ConfigEntriesByKind(nil, structs.ProxyDefaults, nil)
require.NoError(err)
require.Equal(uint64(2), idx)
require.Equal([]structs.ConfigEntry{entry1}, entries)
require.NoError(t, err)
require.Equal(t, uint64(2), idx)
require.Equal(t, []structs.ConfigEntry{entry1}, entries)
// Get all service entries
ws := memdb.NewWatchSet()
idx, entries, err = s.ConfigEntriesByKind(ws, structs.ServiceDefaults, nil)
require.NoError(err)
require.Equal(uint64(2), idx)
require.Equal([]structs.ConfigEntry{entry2, entry3}, entries)
require.NoError(t, err)
require.Equal(t, uint64(2), idx)
require.Equal(t, []structs.ConfigEntry{entry2, entry3}, entries)
// Watch should not have fired
require.False(watchFired(ws))
require.False(t, watchFired(ws))
// Now make an update and make sure the watch fires.
require.NoError(s.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
require.NoError(t, s.EnsureConfigEntry(3, &structs.ServiceConfigEntry{
Kind: structs.ServiceDefaults,
Name: "test2",
Protocol: "tcp",
}))
require.True(watchFired(ws))
require.True(t, watchFired(ws))
}
func TestStore_ConfigEntry_GraphValidation(t *testing.T) {

View File

@ -184,25 +184,24 @@ func TestStore_CAConfig_Snapshot_Restore_BlankConfig(t *testing.T) {
}
func TestStore_CARootSetList(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
assert.Nil(t, err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
expected := *ca1
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
assert.Nil(err)
assert.True(ok)
assert.Nil(t, err)
assert.True(t, ok)
// Make sure the index got updated.
assert.Equal(s.maxIndex(tableConnectCARoots), uint64(1))
assert.True(watchFired(ws), "watch fired")
assert.Equal(t, s.maxIndex(tableConnectCARoots), uint64(1))
assert.True(t, watchFired(ws), "watch fired")
// Read it back out and verify it.
@ -212,20 +211,19 @@ func TestStore_CARootSetList(t *testing.T) {
}
ws = memdb.NewWatchSet()
_, roots, err := s.CARoots(ws)
assert.Nil(err)
assert.Len(roots, 1)
assert.Nil(t, err)
assert.Len(t, roots, 1)
actual := roots[0]
assertDeepEqual(t, expected, *actual)
}
func TestStore_CARootSet_emptyID(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
assert.Nil(t, err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
@ -233,29 +231,28 @@ func TestStore_CARootSet_emptyID(t *testing.T) {
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
assert.NotNil(err)
assert.Contains(err.Error(), ErrMissingCARootID.Error())
assert.False(ok)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), ErrMissingCARootID.Error())
assert.False(t, ok)
// Make sure the index got updated.
assert.Equal(s.maxIndex(tableConnectCARoots), uint64(0))
assert.False(watchFired(ws), "watch fired")
assert.Equal(t, s.maxIndex(tableConnectCARoots), uint64(0))
assert.False(t, watchFired(ws), "watch fired")
// Read it back out and verify it.
ws = memdb.NewWatchSet()
_, roots, err := s.CARoots(ws)
assert.Nil(err)
assert.Len(roots, 0)
assert.Nil(t, err)
assert.Len(t, roots, 0)
}
func TestStore_CARootSet_noActive(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
assert.Nil(t, err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
@ -265,19 +262,18 @@ func TestStore_CARootSet_noActive(t *testing.T) {
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
assert.NotNil(err)
assert.Contains(err.Error(), "exactly one active")
assert.False(ok)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "exactly one active")
assert.False(t, ok)
}
func TestStore_CARootSet_multipleActive(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
assert.Nil(t, err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
@ -285,13 +281,12 @@ func TestStore_CARootSet_multipleActive(t *testing.T) {
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
assert.NotNil(err)
assert.Contains(err.Error(), "exactly one active")
assert.False(ok)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "exactly one active")
assert.False(t, ok)
}
func TestStore_CARootActive_valid(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid value
@ -303,33 +298,31 @@ func TestStore_CARootActive_valid(t *testing.T) {
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2, ca3})
assert.Nil(err)
assert.True(ok)
assert.Nil(t, err)
assert.True(t, ok)
// Query
ws := memdb.NewWatchSet()
idx, res, err := s.CARootActive(ws)
assert.Equal(idx, uint64(1))
assert.Nil(err)
assert.NotNil(res)
assert.Equal(ca2.ID, res.ID)
assert.Equal(t, idx, uint64(1))
assert.Nil(t, err)
assert.NotNil(t, res)
assert.Equal(t, ca2.ID, res.ID)
}
// Test that querying the active CA returns the correct value.
func TestStore_CARootActive_none(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Querying with no results returns nil.
ws := memdb.NewWatchSet()
idx, res, err := s.CARootActive(ws)
assert.Equal(idx, uint64(0))
assert.Nil(res)
assert.Nil(err)
assert.Equal(t, idx, uint64(0))
assert.Nil(t, res)
assert.Nil(t, err)
}
func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create some intentions.
@ -351,8 +344,8 @@ func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
// Now create
ok, err := s.CARootSetCAS(1, 0, roots)
assert.Nil(err)
assert.True(ok)
assert.Nil(t, err)
assert.True(t, ok)
// Snapshot the queries.
snap := s.Snapshot()
@ -360,34 +353,33 @@ func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
// Alter the real state store.
ok, err = s.CARootSetCAS(2, 1, roots[:1])
assert.Nil(err)
assert.True(ok)
assert.Nil(t, err)
assert.True(t, ok)
// Verify the snapshot.
assert.Equal(snap.LastIndex(), uint64(1))
assert.Equal(t, snap.LastIndex(), uint64(1))
dump, err := snap.CARoots()
assert.Nil(err)
assert.Equal(roots, dump)
assert.Nil(t, err)
assert.Equal(t, roots, dump)
// Restore the values into a new state store.
func() {
s := testStateStore(t)
restore := s.Restore()
for _, r := range dump {
assert.Nil(restore.CARoot(r))
assert.Nil(t, restore.CARoot(r))
}
restore.Commit()
// Read the restored values back out and verify that they match.
idx, actual, err := s.CARoots(nil)
assert.Nil(err)
assert.Equal(idx, uint64(2))
assert.Equal(roots, actual)
assert.Nil(t, err)
assert.Equal(t, idx, uint64(2))
assert.Equal(t, roots, actual)
}()
}
func TestStore_CABuiltinProvider(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
{
@ -398,13 +390,13 @@ func TestStore_CABuiltinProvider(t *testing.T) {
}
ok, err := s.CASetProviderState(0, expected)
assert.NoError(err)
assert.True(ok)
assert.NoError(t, err)
assert.True(t, ok)
idx, state, err := s.CAProviderState(expected.ID)
assert.NoError(err)
assert.Equal(idx, uint64(0))
assert.Equal(expected, state)
assert.NoError(t, err)
assert.Equal(t, idx, uint64(0))
assert.Equal(t, expected, state)
}
{
@ -415,13 +407,13 @@ func TestStore_CABuiltinProvider(t *testing.T) {
}
ok, err := s.CASetProviderState(1, expected)
assert.NoError(err)
assert.True(ok)
assert.NoError(t, err)
assert.True(t, ok)
idx, state, err := s.CAProviderState(expected.ID)
assert.NoError(err)
assert.Equal(idx, uint64(1))
assert.Equal(expected, state)
assert.NoError(t, err)
assert.Equal(t, idx, uint64(1))
assert.Equal(t, expected, state)
}
{
@ -429,21 +421,20 @@ func TestStore_CABuiltinProvider(t *testing.T) {
// numbers will initialize from the max index of the provider table.
// That's why this first serial is 2 and not 1.
sn, err := s.CAIncrementProviderSerialNumber(10)
assert.NoError(err)
assert.Equal(uint64(2), sn)
assert.NoError(t, err)
assert.Equal(t, uint64(2), sn)
sn, err = s.CAIncrementProviderSerialNumber(10)
assert.NoError(err)
assert.Equal(uint64(3), sn)
assert.NoError(t, err)
assert.Equal(t, uint64(3), sn)
sn, err = s.CAIncrementProviderSerialNumber(10)
assert.NoError(err)
assert.Equal(uint64(4), sn)
assert.NoError(t, err)
assert.Equal(t, uint64(4), sn)
}
}
func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create multiple state entries.
@ -462,8 +453,8 @@ func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
for i, state := range before {
ok, err := s.CASetProviderState(uint64(98+i), state)
assert.NoError(err)
assert.True(ok)
assert.NoError(t, err)
assert.True(t, ok)
}
// Take a snapshot.
@ -477,26 +468,26 @@ func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
RootCert: "d",
}
ok, err := s.CASetProviderState(100, after)
assert.NoError(err)
assert.True(ok)
assert.NoError(t, err)
assert.True(t, ok)
snapped, err := snap.CAProviderState()
assert.NoError(err)
assert.Equal(before, snapped)
assert.NoError(t, err)
assert.Equal(t, before, snapped)
// Restore onto a new state store.
s2 := testStateStore(t)
restore := s2.Restore()
for _, entry := range snapped {
assert.NoError(restore.CAProviderState(entry))
assert.NoError(t, restore.CAProviderState(entry))
}
restore.Commit()
// Verify the restored values match those from before the snapshot.
for _, state := range before {
idx, res, err := s2.CAProviderState(state.ID)
assert.NoError(err)
assert.Equal(idx, uint64(99))
assert.Equal(state, res)
assert.NoError(t, err)
assert.Equal(t, idx, uint64(99))
assert.Equal(t, state, res)
}
}

View File

@ -46,14 +46,13 @@ func testBothIntentionFormats(t *testing.T, f func(t *testing.T, s *Store, legac
func TestStore_IntentionGet_none(t *testing.T) {
testBothIntentionFormats(t, func(t *testing.T, s *Store, legacy bool) {
assert := assert.New(t)
// Querying with no results returns nil.
ws := memdb.NewWatchSet()
idx, _, res, err := s.IntentionGet(ws, testUUID())
assert.Equal(uint64(1), idx)
assert.Nil(res)
assert.Nil(err)
assert.Equal(t, uint64(1), idx)
assert.Nil(t, res)
assert.Nil(t, err)
})
}

View File

@ -5,8 +5,9 @@ import (
"strings"
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/agent/structs"
)
func TestStateStore_PreparedQuery_isUUID(t *testing.T) {
@ -663,7 +664,7 @@ func TestStateStore_PreparedQueryResolve(t *testing.T) {
Regexp: "^prod-(.*)$",
},
Service: structs.ServiceQuery{
Service: "${match(1)}-master",
Service: "${match(1)}-primary",
},
}
if err := s.PreparedQuerySet(5, tmpl2); err != nil {
@ -705,7 +706,7 @@ func TestStateStore_PreparedQueryResolve(t *testing.T) {
Regexp: "^prod-(.*)$",
},
Service: structs.ServiceQuery{
Service: "redis-foobar-master",
Service: "redis-foobar-primary",
},
RaftIndex: structs.RaftIndex{
CreateIndex: 5,

View File

@ -18,7 +18,6 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
}
t.Parallel()
require := require.New(t)
s := testACLTokensStateStore(t)
// Setup token and wait for good state
@ -37,14 +36,14 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
go publisher.Run(ctx)
s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription)
require.NoError(err)
require.NoError(t, err)
defer sub.Unsubscribe()
eventCh := testRunSub(sub)
// Stream should get EndOfSnapshot
e := assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot())
require.True(t, e.IsEndOfSnapshot())
// Update an unrelated token.
token2 := &structs.ACLToken{
@ -52,7 +51,7 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
SecretID: "72e81982-7a0f-491f-a60e-c9c802ac1402",
}
token2.SetHash(false)
require.NoError(s.ACLTokenSet(3, token2.Clone()))
require.NoError(t, s.ACLTokenSet(3, token2.Clone()))
// Ensure there's no reset event.
assertNoEvent(t, eventCh)
@ -64,11 +63,11 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
Description: "something else",
}
token3.SetHash(false)
require.NoError(s.ACLTokenSet(4, token3.Clone()))
require.NoError(t, s.ACLTokenSet(4, token3.Clone()))
// Ensure the reset event was sent.
err = assertErr(t, eventCh)
require.Equal(stream.ErrSubForceClosed, err)
require.Equal(t, stream.ErrSubForceClosed, err)
// Register another subscription.
subscription2 := &stream.SubscribeRequest{
@ -77,27 +76,27 @@ func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
Token: token.SecretID,
}
sub2, err := publisher.Subscribe(subscription2)
require.NoError(err)
require.NoError(t, err)
defer sub2.Unsubscribe()
eventCh2 := testRunSub(sub2)
// Expect initial EoS
e = assertEvent(t, eventCh2)
require.True(e.IsEndOfSnapshot())
require.True(t, e.IsEndOfSnapshot())
// Delete the unrelated token.
require.NoError(s.ACLTokenDeleteByAccessor(5, token2.AccessorID, nil))
require.NoError(t, s.ACLTokenDeleteByAccessor(5, token2.AccessorID, nil))
// Ensure there's no reset event.
assertNoEvent(t, eventCh2)
// Delete the token used by the subscriber.
require.NoError(s.ACLTokenDeleteByAccessor(6, token.AccessorID, nil))
require.NoError(t, s.ACLTokenDeleteByAccessor(6, token.AccessorID, nil))
// Ensure the reset event was sent.
err = assertErr(t, eventCh2)
require.Equal(stream.ErrSubForceClosed, err)
require.Equal(t, stream.ErrSubForceClosed, err)
}
func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
@ -106,7 +105,6 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
}
t.Parallel()
require := require.New(t)
s := testACLTokensStateStore(t)
// Create token and wait for good state
@ -125,14 +123,14 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
go publisher.Run(ctx)
s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription)
require.NoError(err)
require.NoError(t, err)
defer sub.Unsubscribe()
eventCh := testRunSub(sub)
// Ignore the end of snapshot event
e := assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
require.True(t, e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
// Update an unrelated policy.
policy2 := structs.ACLPolicy{
@ -143,7 +141,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Datacenters: []string{"dc1"},
}
policy2.SetHash(false)
require.NoError(s.ACLPolicySet(3, &policy2))
require.NoError(t, s.ACLPolicySet(3, &policy2))
// Ensure there's no reset event.
assertNoEvent(t, eventCh)
@ -157,7 +155,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Datacenters: []string{"dc1"},
}
policy3.SetHash(false)
require.NoError(s.ACLPolicySet(4, &policy3))
require.NoError(t, s.ACLPolicySet(4, &policy3))
// Ensure the reset event was sent.
assertReset(t, eventCh, true)
@ -169,27 +167,27 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Token: token.SecretID,
}
sub, err = publisher.Subscribe(subscription2)
require.NoError(err)
require.NoError(t, err)
defer sub.Unsubscribe()
eventCh = testRunSub(sub)
// Ignore the end of snapshot event
e = assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
require.True(t, e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
// Delete the unrelated policy.
require.NoError(s.ACLPolicyDeleteByID(5, testPolicyID_C, nil))
require.NoError(t, s.ACLPolicyDeleteByID(5, testPolicyID_C, nil))
// Ensure there's no reload event.
assertNoEvent(t, eventCh)
// Delete the policy used by the subscriber.
require.NoError(s.ACLPolicyDeleteByID(6, testPolicyID_A, nil))
require.NoError(t, s.ACLPolicyDeleteByID(6, testPolicyID_A, nil))
// Ensure the reload event was sent.
err = assertErr(t, eventCh)
require.Equal(stream.ErrSubForceClosed, err)
require.Equal(t, stream.ErrSubForceClosed, err)
// Register another subscription.
subscription3 := &stream.SubscribeRequest{
@ -198,14 +196,14 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Token: token.SecretID,
}
sub, err = publisher.Subscribe(subscription3)
require.NoError(err)
require.NoError(t, err)
defer sub.Unsubscribe()
eventCh = testRunSub(sub)
// Ignore the end of snapshot event
e = assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
require.True(t, e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
// Now update the policy used in role B, but not directly in the token.
policy4 := structs.ACLPolicy{
@ -216,7 +214,7 @@ func TestStore_IntegrationWithEventPublisher_ACLPolicyUpdate(t *testing.T) {
Datacenters: []string{"dc1"},
}
policy4.SetHash(false)
require.NoError(s.ACLPolicySet(7, &policy4))
require.NoError(t, s.ACLPolicySet(7, &policy4))
// Ensure the reset event was sent.
assertReset(t, eventCh, true)
@ -228,7 +226,6 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
}
t.Parallel()
require := require.New(t)
s := testACLTokensStateStore(t)
// Create token and wait for good state
@ -247,13 +244,13 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
go publisher.Run(ctx)
s.db.publisher = publisher
sub, err := publisher.Subscribe(subscription)
require.NoError(err)
require.NoError(t, err)
eventCh := testRunSub(sub)
// Stream should get EndOfSnapshot
e := assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot())
require.True(t, e.IsEndOfSnapshot())
// Update an unrelated role (the token has role testRoleID_B).
role := structs.ACLRole{
@ -262,7 +259,7 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
Description: "test",
}
role.SetHash(false)
require.NoError(s.ACLRoleSet(3, &role))
require.NoError(t, s.ACLRoleSet(3, &role))
// Ensure there's no reload event.
assertNoEvent(t, eventCh)
@ -274,7 +271,7 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
Description: "changed",
}
role2.SetHash(false)
require.NoError(s.ACLRoleSet(4, &role2))
require.NoError(t, s.ACLRoleSet(4, &role2))
// Ensure the reload event was sent.
assertReset(t, eventCh, false)
@ -286,22 +283,22 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
Token: token.SecretID,
}
sub, err = publisher.Subscribe(subscription2)
require.NoError(err)
require.NoError(t, err)
eventCh = testRunSub(sub)
// Ignore the end of snapshot event
e = assertEvent(t, eventCh)
require.True(e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
require.True(t, e.IsEndOfSnapshot(), "event should be a EoS got %v", e)
// Delete the unrelated policy.
require.NoError(s.ACLRoleDeleteByID(5, testRoleID_A, nil))
require.NoError(t, s.ACLRoleDeleteByID(5, testRoleID_A, nil))
// Ensure there's no reload event.
assertNoEvent(t, eventCh)
// Delete the policy used by the subscriber.
require.NoError(s.ACLRoleDeleteByID(6, testRoleID_B, nil))
require.NoError(t, s.ACLRoleDeleteByID(6, testRoleID_B, nil))
// Ensure the reload event was sent.
assertReset(t, eventCh, false)
@ -422,26 +419,14 @@ type nodePayload struct {
node *structs.ServiceNode
}
func (p nodePayload) MatchesKey(key, _, partition string) bool {
if key == "" && partition == "" {
return true
}
if p.node == nil {
return false
}
if structs.PartitionOrDefault(partition) != p.node.PartitionOrDefault() {
return false
}
return p.key == key
}
func (p nodePayload) HasReadPermission(acl.Authorizer) bool {
return true
}
func (p nodePayload) Subject() stream.Subject {
return stream.Subject(p.node.PartitionOrDefault() + "/" + p.node.NamespaceOrDefault() + "/" + p.key)
}
func createTokenAndWaitForACLEventPublish(t *testing.T, s *Store) *structs.ACLToken {
token := &structs.ACLToken{
AccessorID: "3af117a9-2233-4cf4-8ff8-3c749c9906b4",

View File

@ -14,6 +14,11 @@ import (
// events which match the Topic.
type Topic fmt.Stringer
// Subject identifies a portion of a topic for which a subscriber wishes to
// receive events (e.g. health events for a particular service) usually the
// normalized resource name (including partition and namespace if applicable).
type Subject string
// Event is a structure with identifiers and a payload. Events are Published to
// EventPublisher and returned to Subscribers.
type Event struct {
@ -26,18 +31,16 @@ type Event struct {
// should not modify the state of the payload if the Event is being submitted to
// EventPublisher.Publish.
type Payload interface {
// MatchesKey must return true if the Payload should be included in a
// subscription requested with the key, namespace, and partition.
//
// Generally this means that the payload matches the key, namespace, and
// partition or the payload is a special framing event that should be
// returned to every subscription.
MatchesKey(key, namespace, partition string) bool
// HasReadPermission uses the acl.Authorizer to determine if the items in the
// Payload are visible to the request. It returns true if the payload is
// authorized for Read, otherwise returns false.
HasReadPermission(authz acl.Authorizer) bool
// Subject is used to identify which subscribers should be notified of this
// event - e.g. those subscribing to health events for a particular service.
// it is usually the normalized resource name (including the partition and
// namespace if applicable).
Subject() Subject
}
// PayloadEvents is a Payload that may be returned by Subscription.Next when
@ -81,14 +84,6 @@ func (p *PayloadEvents) filter(f func(Event) bool) bool {
return true
}
// MatchesKey filters the PayloadEvents to those which match the key,
// namespace, and partition.
func (p *PayloadEvents) MatchesKey(key, namespace, partition string) bool {
return p.filter(func(event Event) bool {
return event.Payload.MatchesKey(key, namespace, partition)
})
}
func (p *PayloadEvents) Len() int {
return len(p.Items)
}
@ -101,6 +96,14 @@ func (p *PayloadEvents) HasReadPermission(authz acl.Authorizer) bool {
})
}
// Subject is required to satisfy the Payload interface but is not implemented
// by PayloadEvents. PayloadEvents structs are constructed by Subscription.Next
// *after* Subject has been used to dispatch the enclosed events to the correct
// buffer.
func (PayloadEvents) Subject() Subject {
panic("PayloadEvents does not implement Subject")
}
// IsEndOfSnapshot returns true if this is a framing event that indicates the
// snapshot has completed. Subsequent events from Subscription.Next will be
// streamed as they occur.
@ -117,12 +120,15 @@ func (e Event) IsNewSnapshotToFollow() bool {
type framingEvent struct{}
func (framingEvent) MatchesKey(string, string, string) bool {
func (framingEvent) HasReadPermission(acl.Authorizer) bool {
return true
}
func (framingEvent) HasReadPermission(acl.Authorizer) bool {
return true
// Subject is required by the Payload interface but is not implemented by
// framing events, as they are typically *manually* appended to the correct
// buffer and do not need to be routed using a Subject.
func (framingEvent) Subject() Subject {
panic("framing events do not implement Subject")
}
type endOfSnapshot struct {
@ -137,12 +143,15 @@ type closeSubscriptionPayload struct {
tokensSecretIDs []string
}
func (closeSubscriptionPayload) MatchesKey(string, string, string) bool {
func (closeSubscriptionPayload) HasReadPermission(acl.Authorizer) bool {
return false
}
func (closeSubscriptionPayload) HasReadPermission(acl.Authorizer) bool {
return false
// Subject is required by the Payload interface but it is not implemented by
// closeSubscriptionPayload, as this event type is handled separately and not
// actually appended to the buffer.
func (closeSubscriptionPayload) Subject() Subject {
panic("closeSubscriptionPayload does not implement Subject")
}
// NewCloseSubscriptionEvent returns a special Event that is handled by the

View File

@ -20,16 +20,16 @@ type EventPublisher struct {
// seconds.
snapCacheTTL time.Duration
// This lock protects the topicBuffers, and snapCache
// This lock protects the snapCache, topicBuffers and topicBuffer.refs.
lock sync.RWMutex
// topicBuffers stores the head of the linked-list buffer to publish events to
// topicBuffers stores the head of the linked-list buffers to publish events to
// for a topic.
topicBuffers map[Topic]*eventBuffer
topicBuffers map[topicSubject]*topicBuffer
// snapCache if a cache of EventSnapshots indexed by topic and key.
// snapCache if a cache of EventSnapshots indexed by topic and subject.
// TODO(streaming): new snapshotCache struct for snapCache and snapCacheTTL
snapCache map[Topic]map[string]*eventSnapshot
snapCache map[topicSubject]*eventSnapshot
subscriptions *subscriptions
@ -41,6 +41,13 @@ type EventPublisher struct {
snapshotHandlers SnapshotHandlers
}
// topicSubject is used as a map key when accessing topic buffers and cached
// snapshots.
type topicSubject struct {
Topic Topic
Subject Subject
}
type subscriptions struct {
// lock for byToken. If both subscription.lock and EventPublisher.lock need
// to be held, EventPublisher.lock MUST always be acquired first.
@ -54,6 +61,14 @@ type subscriptions struct {
byToken map[string]map[*SubscribeRequest]*Subscription
}
// topicBuffer augments the eventBuffer with a reference counter, enabling
// clean up of unused buffers once there are no longer any subscribers for
// the given topic and key.
type topicBuffer struct {
refs int // refs is guarded by EventPublisher.lock.
buf *eventBuffer
}
// SnapshotHandlers is a mapping of Topic to a function which produces a snapshot
// of events for the SubscribeRequest. Events are appended to the snapshot using SnapshotAppender.
// The nil Topic is reserved and should not be used.
@ -79,8 +94,8 @@ type SnapshotAppender interface {
func NewEventPublisher(handlers SnapshotHandlers, snapCacheTTL time.Duration) *EventPublisher {
e := &EventPublisher{
snapCacheTTL: snapCacheTTL,
topicBuffers: make(map[Topic]*eventBuffer),
snapCache: make(map[Topic]map[string]*eventSnapshot),
topicBuffers: make(map[topicSubject]*topicBuffer),
snapCache: make(map[topicSubject]*eventSnapshot),
publishCh: make(chan []Event, 64),
subscriptions: &subscriptions{
byToken: make(map[string]map[*SubscribeRequest]*Subscription),
@ -116,36 +131,59 @@ func (e *EventPublisher) Run(ctx context.Context) {
// publishEvent appends the events to any applicable topic buffers. It handles
// any closeSubscriptionPayload events by closing associated subscriptions.
func (e *EventPublisher) publishEvent(events []Event) {
eventsByTopic := make(map[Topic][]Event)
groupedEvents := make(map[topicSubject][]Event)
for _, event := range events {
if unsubEvent, ok := event.Payload.(closeSubscriptionPayload); ok {
e.subscriptions.closeSubscriptionsForTokens(unsubEvent.tokensSecretIDs)
continue
}
eventsByTopic[event.Topic] = append(eventsByTopic[event.Topic], event)
groupKey := topicSubject{event.Topic, event.Payload.Subject()}
groupedEvents[groupKey] = append(groupedEvents[groupKey], event)
}
e.lock.Lock()
defer e.lock.Unlock()
for topic, events := range eventsByTopic {
e.getTopicBuffer(topic).Append(events)
for groupKey, events := range groupedEvents {
// Note: bufferForPublishing returns nil if there are no subscribers for the
// given topic and subject, in which case events will be dropped on the floor and
// future subscribers will catch up by consuming the snapshot.
if buf := e.bufferForPublishing(groupKey); buf != nil {
buf.Append(events)
}
}
}
// getTopicBuffer for the topic. Creates a new event buffer if one does not
// already exist.
// bufferForSubscription returns the topic event buffer to which events for the
// given topic and key will be appended. If no such buffer exists, a new buffer
// will be created.
//
// EventPublisher.lock must be held to call this method.
func (e *EventPublisher) getTopicBuffer(topic Topic) *eventBuffer {
buf, ok := e.topicBuffers[topic]
// Warning: e.lock MUST be held when calling this function.
func (e *EventPublisher) bufferForSubscription(key topicSubject) *topicBuffer {
buf, ok := e.topicBuffers[key]
if !ok {
buf = newEventBuffer()
e.topicBuffers[topic] = buf
buf = &topicBuffer{
buf: newEventBuffer(),
}
e.topicBuffers[key] = buf
}
return buf
}
// bufferForPublishing returns the event buffer to which events for the given
// topic and key should be appended. nil will be returned if there are no
// subscribers for the given topic and key.
//
// Warning: e.lock MUST be held when calling this function.
func (e *EventPublisher) bufferForPublishing(key topicSubject) *eventBuffer {
buf, ok := e.topicBuffers[key]
if !ok {
return nil
}
return buf.buf
}
// Subscribe returns a new Subscription for the given request. A subscription
// will receive an initial snapshot of events matching the request if req.Index > 0.
// After the snapshot, events will be streamed as they are created.
@ -163,7 +201,34 @@ func (e *EventPublisher) Subscribe(req *SubscribeRequest) (*Subscription, error)
e.lock.Lock()
defer e.lock.Unlock()
topicHead := e.getTopicBuffer(req.Topic).Head()
topicBuf := e.bufferForSubscription(req.topicSubject())
topicBuf.refs++
// freeBuf is used to free the topic buffer once there are no remaining
// subscribers for the given topic and key.
//
// Note: it's called by Subcription.Unsubscribe which has its own side-effects
// that are made without holding e.lock (so there's a moment where the ref
// counter is inconsistent with the subscription map) — in practice this is
// fine, we don't need these things to be strongly consistent. The alternative
// would be to hold both locks, which introduces the risk of deadlocks.
freeBuf := func() {
e.lock.Lock()
defer e.lock.Unlock()
topicBuf.refs--
if topicBuf.refs == 0 {
delete(e.topicBuffers, req.topicSubject())
// Evict cached snapshot too because the topic buffer will have been spliced
// onto it. If we don't do this, any new subscribers started before the cache
// TTL is reached will get "stuck" waiting on the old buffer.
delete(e.snapCache, req.topicSubject())
}
}
topicHead := topicBuf.buf.Head()
// If the client view is fresh, resume the stream.
if req.Index > 0 && topicHead.HasEventIndex(req.Index) {
@ -173,7 +238,7 @@ func (e *EventPublisher) Subscribe(req *SubscribeRequest) (*Subscription, error)
// the subscription will receive new events.
next, _ := topicHead.NextNoBlock()
buf.AppendItem(next)
return e.subscriptions.add(req, subscriptionHead), nil
return e.subscriptions.add(req, subscriptionHead, freeBuf), nil
}
snapFromCache := e.getCachedSnapshotLocked(req)
@ -186,7 +251,7 @@ func (e *EventPublisher) Subscribe(req *SubscribeRequest) (*Subscription, error)
// If the request.Index is 0 the client has no view, send a full snapshot.
if req.Index == 0 {
return e.subscriptions.add(req, snapFromCache.First), nil
return e.subscriptions.add(req, snapFromCache.First, freeBuf), nil
}
// otherwise the request has an Index, the client view is stale and must be reset
@ -197,11 +262,17 @@ func (e *EventPublisher) Subscribe(req *SubscribeRequest) (*Subscription, error)
Payload: newSnapshotToFollow{},
}})
result.buffer.AppendItem(snapFromCache.First)
return e.subscriptions.add(req, result.First), nil
return e.subscriptions.add(req, result.First, freeBuf), nil
}
func (s *subscriptions) add(req *SubscribeRequest, head *bufferItem) *Subscription {
sub := newSubscription(*req, head, s.unsubscribe(req))
func (s *subscriptions) add(req *SubscribeRequest, head *bufferItem, freeBuf func()) *Subscription {
// We wrap freeBuf in a sync.Once as it's expected that Subscription.unsub is
// idempotent, but freeBuf decrements the reference counter on every call.
var once sync.Once
sub := newSubscription(*req, head, func() {
s.unsubscribe(req)
once.Do(freeBuf)
})
s.lock.Lock()
defer s.lock.Unlock()
@ -228,24 +299,17 @@ func (s *subscriptions) closeSubscriptionsForTokens(tokenSecretIDs []string) {
}
}
// unsubscribe returns a function that the subscription will call to remove
// itself from the subsByToken.
// This function is returned as a closure so that the caller doesn't need to keep
// track of the SubscriptionRequest, and can not accidentally call unsubscribe with the
// wrong pointer.
func (s *subscriptions) unsubscribe(req *SubscribeRequest) func() {
return func() {
s.lock.Lock()
defer s.lock.Unlock()
func (s *subscriptions) unsubscribe(req *SubscribeRequest) {
s.lock.Lock()
defer s.lock.Unlock()
subsByToken, ok := s.byToken[req.Token]
if !ok {
return
}
delete(subsByToken, req)
if len(subsByToken) == 0 {
delete(s.byToken, req.Token)
}
subsByToken, ok := s.byToken[req.Token]
if !ok {
return
}
delete(subsByToken, req)
if len(subsByToken) == 0 {
delete(s.byToken, req.Token)
}
}
@ -262,13 +326,7 @@ func (s *subscriptions) closeAll() {
// EventPublisher.lock must be held to call this method.
func (e *EventPublisher) getCachedSnapshotLocked(req *SubscribeRequest) *eventSnapshot {
topicSnaps, ok := e.snapCache[req.Topic]
if !ok {
topicSnaps = make(map[string]*eventSnapshot)
e.snapCache[req.Topic] = topicSnaps
}
snap, ok := topicSnaps[snapCacheKey(req)]
snap, ok := e.snapCache[req.topicSubject()]
if ok && snap.err() == nil {
return snap
}
@ -280,16 +338,12 @@ func (e *EventPublisher) setCachedSnapshotLocked(req *SubscribeRequest, snap *ev
if e.snapCacheTTL == 0 {
return
}
e.snapCache[req.Topic][snapCacheKey(req)] = snap
e.snapCache[req.topicSubject()] = snap
// Setup a cache eviction
time.AfterFunc(e.snapCacheTTL, func() {
e.lock.Lock()
defer e.lock.Unlock()
delete(e.snapCache[req.Topic], snapCacheKey(req))
delete(e.snapCache, req.topicSubject())
})
}
func snapCacheKey(req *SubscribeRequest) string {
return req.Partition + "/" + req.Namespace + "/" + req.Key
}

Some files were not shown because too many files have changed in this diff Show More