Merge branch 'master' of github.com:hashicorp/consul into WinService
This commit is contained in:
commit
2182e289a3
|
@ -0,0 +1,3 @@
|
|||
pkg/
|
||||
.git
|
||||
bin/
|
10
CHANGELOG.md
10
CHANGELOG.md
|
@ -1,13 +1,21 @@
|
|||
## UNRELEASED
|
||||
|
||||
## 1.2.0 (June 26, 2018)
|
||||
|
||||
FEATURES:
|
||||
|
||||
* **Connect Feature Beta**: This version includes a major new feature for Consul named Connect. Connect enables secure service-to-service communication with automatic TLS encryption and identity-based authorization. For more details and links to demos and getting started guides, see the [announcement blog post](https://www.hashicorp.com/blog/consul-1-2-service-mesh).
|
||||
* Connect must be enabled explicitly in configuration so upgrading a cluster will not affect any existing functionality until it's enabled.
|
||||
* This is a Beta feature, we don't recommend enabling this in production yet. Please see the documentation for more information.
|
||||
* dns: Enable PTR record lookups for services with IPs that have no registered node [[PR-4083](https://github.com/hashicorp/consul/pull/4083)]
|
||||
* ui: Default to serving the new UI. Setting the `CONSUL_UI_LEGACY` environment variable to `1` or `true` will revert to serving the old UI
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
* agent: A Consul user-agent string is now sent to providers when making retry-join requests [GH-4013](https://github.com/hashicorp/consul/pull/4013)
|
||||
* agent: A Consul user-agent string is now sent to providers when making retry-join requests [[GH-4013](https://github.com/hashicorp/consul/issues/4013)](https://github.com/hashicorp/consul/pull/4013)
|
||||
* client: Add metrics for failed RPCs [PR-4220](https://github.com/hashicorp/consul/pull/4220)
|
||||
* agent: Add configuration entry to control including TXT records for node meta in DNS responses [PR-4215](https://github.com/hashicorp/consul/pull/4215)
|
||||
* client: Make RPC rate limit configuration reloadable [GH-4012](https://github.com/hashicorp/consul/issues/4012)
|
||||
|
||||
BUG FIXES:
|
||||
|
||||
|
|
166
GNUmakefile
166
GNUmakefile
|
@ -16,35 +16,97 @@ GOTEST_PKGS ?= "./..."
|
|||
else
|
||||
GOTEST_PKGS=$(shell go list ./... | sed 's/github.com\/hashicorp\/consul/./' | egrep -v "^($(GOTEST_PKGS_EXCLUDE))$$")
|
||||
endif
|
||||
GOOS=$(shell go env GOOS)
|
||||
GOARCH=$(shell go env GOARCH)
|
||||
GOOS?=$(shell go env GOOS)
|
||||
GOARCH?=$(shell go env GOARCH)
|
||||
GOPATH=$(shell go env GOPATH)
|
||||
|
||||
ASSETFS_PATH?=agent/bindata_assetfs.go
|
||||
# Get the git commit
|
||||
GIT_COMMIT=$(shell git rev-parse --short HEAD)
|
||||
GIT_DIRTY=$(shell test -n "`git status --porcelain`" && echo "+CHANGES" || true)
|
||||
GIT_DESCRIBE=$(shell git describe --tags --always)
|
||||
GIT_COMMIT?=$(shell git rev-parse --short HEAD)
|
||||
GIT_DIRTY?=$(shell test -n "`git status --porcelain`" && echo "+CHANGES" || true)
|
||||
GIT_DESCRIBE?=$(shell git describe --tags --always)
|
||||
GIT_IMPORT=github.com/hashicorp/consul/version
|
||||
GOLDFLAGS=-X $(GIT_IMPORT).GitCommit=$(GIT_COMMIT)$(GIT_DIRTY) -X $(GIT_IMPORT).GitDescribe=$(GIT_DESCRIBE)
|
||||
|
||||
ifeq ($(FORCE_REBUILD),1)
|
||||
NOCACHE=--no-cache
|
||||
else
|
||||
NOCACHE=
|
||||
endif
|
||||
|
||||
DOCKER_BUILD_QUIET?=1
|
||||
ifeq (${DOCKER_BUILD_QUIET},1)
|
||||
QUIET=-q
|
||||
else
|
||||
QUIET=
|
||||
endif
|
||||
|
||||
CONSUL_DEV_IMAGE?=consul-dev
|
||||
GO_BUILD_TAG?=consul-build-go
|
||||
UI_BUILD_TAG?=consul-build-ui
|
||||
UI_LEGACY_BUILD_TAG?=consul-build-ui-legacy
|
||||
BUILD_CONTAINER_NAME?=consul-builder
|
||||
|
||||
DIST_TAG?=1
|
||||
DIST_BUILD?=1
|
||||
DIST_SIGN?=1
|
||||
|
||||
ifdef DIST_VERSION
|
||||
DIST_VERSION_ARG=-v "$(DIST_VERSION)"
|
||||
else
|
||||
DIST_VERSION_ARG=
|
||||
endif
|
||||
|
||||
ifdef DIST_RELEASE_DATE
|
||||
DIST_DATE_ARG=-d "$(DIST_RELEASE_DATE)"
|
||||
else
|
||||
DIST_DATE_ARG=
|
||||
endif
|
||||
|
||||
ifdef DIST_PRERELEASE
|
||||
DIST_REL_ARG=-r "$(DIST_PRERELEASE)"
|
||||
else
|
||||
DIST_REL_ARG=
|
||||
endif
|
||||
|
||||
PUB_GIT?=1
|
||||
PUB_WEBSITE?=1
|
||||
|
||||
ifeq ($(PUB_GIT),1)
|
||||
PUB_GIT_ARG=-g
|
||||
else
|
||||
PUB_GIT_ARG=
|
||||
endif
|
||||
|
||||
ifeq ($(PUB_WEBSITE),1)
|
||||
PUB_WEBSITE_ARG=-g
|
||||
else
|
||||
PUB_WEBSITE_ARG=
|
||||
endif
|
||||
|
||||
export GO_BUILD_TAG
|
||||
export UI_BUILD_TAG
|
||||
export UI_LEGACY_BUILD_TAG
|
||||
export BUILD_CONTAINER_NAME
|
||||
export GIT_COMMIT
|
||||
export GIT_DIRTY
|
||||
export GIT_DESCRIBE
|
||||
export GOTAGS
|
||||
export GOLDFLAGS
|
||||
|
||||
# all builds binaries for all targets
|
||||
all: bin
|
||||
|
||||
bin: tools
|
||||
@mkdir -p bin/
|
||||
@GOTAGS='$(GOTAGS)' sh -c "'$(CURDIR)/scripts/build.sh'"
|
||||
bin: tools dev-build
|
||||
|
||||
# dev creates binaries for testing locally - these are put into ./bin and $GOPATH
|
||||
dev: changelogfmt vendorfmt dev-build
|
||||
|
||||
dev-build:
|
||||
@echo "--> Building consul"
|
||||
mkdir -p pkg/$(GOOS)_$(GOARCH)/ bin/
|
||||
go install -ldflags '$(GOLDFLAGS)' -tags '$(GOTAGS)'
|
||||
cp $(GOPATH)/bin/consul bin/
|
||||
cp $(GOPATH)/bin/consul pkg/$(GOOS)_$(GOARCH)
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/build-local.sh -o $(GOOS) -a $(GOARCH)
|
||||
|
||||
dev-docker:
|
||||
@docker build -t '$(CONSUL_DEV_IMAGE)' --build-arg 'GIT_COMMIT=$(GIT_COMMIT)' --build-arg 'GIT_DIRTY=$(GIT_DIRTY)' --build-arg 'GIT_DESCRIBE=$(GIT_DESCRIBE)' -f $(CURDIR)/build-support/docker/Consul-Dev.dockerfile $(CURDIR)
|
||||
|
||||
vendorfmt:
|
||||
@echo "--> Formatting vendor/vendor.json"
|
||||
|
@ -57,12 +119,17 @@ changelogfmt:
|
|||
|
||||
# linux builds a linux package independent of the source platform
|
||||
linux:
|
||||
mkdir -p pkg/linux_amd64/
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags '$(GOLDFLAGS)' -tags '$(GOTAGS)' -o pkg/linux_amd64/consul
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/build-local.sh -o linux -a amd64
|
||||
|
||||
# dist builds binaries for all platforms and packages them for distribution
|
||||
dist:
|
||||
@GOTAGS='$(GOTAGS)' sh -c "'$(CURDIR)/scripts/dist.sh'"
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/release.sh -t '$(DIST_TAG)' -b '$(DIST_BUILD)' -S '$(DIST_SIGN)' $(DIST_VERSION_ARG) $(DIST_DATE_ARG) $(DIST_REL_ARG)
|
||||
|
||||
publish:
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/publish.sh $(PUB_GIT_ARG) $(PUB_WEBSITE_ARG)
|
||||
|
||||
dev-tree:
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/dev.sh
|
||||
|
||||
cov:
|
||||
gocov test $(GOFILES) | gocov-html > /tmp/coverage.html
|
||||
|
@ -78,10 +145,16 @@ test: other-consul dev-build vet
|
|||
@# _something_ to stop them terminating us due to inactivity...
|
||||
{ go test $(GOTEST_FLAGS) -tags '$(GOTAGS)' -timeout 5m $(GOTEST_PKGS) 2>&1 ; echo $$? > exit-code ; } | tee test.log | egrep '^(ok|FAIL)\s*github.com/hashicorp/consul'
|
||||
@echo "Exit code: $$(cat exit-code)" >> test.log
|
||||
@grep -A5 'DATA RACE' test.log || true
|
||||
@grep -A10 'panic: test timed out' test.log || true
|
||||
@grep -A1 -- '--- SKIP:' test.log || true
|
||||
@grep -A1 -- '--- FAIL:' test.log || true
|
||||
@# This prints all the race report between ====== lines
|
||||
@awk '/^WARNING: DATA RACE/ {do_print=1; print "=================="} do_print==1 {print} /^={10,}/ {do_print=0}' test.log || true
|
||||
@grep -A10 'panic: ' test.log || true
|
||||
@# Prints all the failure output until the next non-indented line - testify
|
||||
@# helpers often output multiple lines for readability but useless if we can't
|
||||
@# see them. Un-intuitive order of matches is necessary. No || true because
|
||||
@# awk always returns true even if there is no match and it breaks non-bash
|
||||
@# shells locally.
|
||||
@awk '/^[^[:space:]]/ {do_print=0} /--- SKIP/ {do_print=1} do_print==1 {print}' test.log
|
||||
@awk '/^[^[:space:]]/ {do_print=0} /--- FAIL/ {do_print=1} do_print==1 {print}' test.log
|
||||
@grep '^FAIL' test.log || true
|
||||
@if [ "$$(cat exit-code)" == "0" ] ; then echo "PASS" ; exit 0 ; else exit 1 ; fi
|
||||
|
||||
|
@ -111,20 +184,57 @@ vet:
|
|||
exit 1; \
|
||||
fi
|
||||
|
||||
# Build the static web ui and build static assets inside a Docker container, the
|
||||
# same way a release build works. This implicitly does a "make static-assets" at
|
||||
# the end.
|
||||
ui:
|
||||
@sh -c "'$(CURDIR)/scripts/ui.sh'"
|
||||
|
||||
# If you've run "make ui" manually then this will get called for you. This is
|
||||
# also run as part of the release build script when it verifies that there are no
|
||||
# changes to the UI assets that aren't checked in.
|
||||
static-assets:
|
||||
@go-bindata-assetfs -pkg agent -prefix pkg -o agent/bindata_assetfs.go ./pkg/web_ui/...
|
||||
@go-bindata-assetfs -pkg agent -prefix pkg -o $(ASSETFS_PATH) ./pkg/web_ui/...
|
||||
$(MAKE) format
|
||||
|
||||
|
||||
# Build the static web ui and build static assets inside a Docker container
|
||||
ui: ui-legacy-docker ui-docker static-assets-docker
|
||||
|
||||
tools:
|
||||
go get -u -v $(GOTOOLS)
|
||||
|
||||
.PHONY: all ci bin dev dist cov test cover format vet ui static-assets tools vendorfmt
|
||||
version:
|
||||
@echo -n "Version: "
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/version.sh
|
||||
@echo -n "Version + release: "
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/version.sh -r
|
||||
@echo -n "Version + git: "
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/version.sh -g
|
||||
@echo -n "Version + release + git: "
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/version.sh -r -g
|
||||
|
||||
|
||||
docker-images: go-build-image ui-build-image ui-legacy-build-image
|
||||
|
||||
go-build-image:
|
||||
@echo "Building Golang build container"
|
||||
@docker build $(NOCACHE) $(QUIET) --build-arg 'GOTOOLS=$(GOTOOLS)' -t $(GO_BUILD_TAG) - < build-support/docker/Build-Go.dockerfile
|
||||
|
||||
ui-build-image:
|
||||
@echo "Building UI build container"
|
||||
@docker build $(NOCACHE) $(QUIET) -t $(UI_BUILD_TAG) - < build-support/docker/Build-UI.dockerfile
|
||||
|
||||
ui-legacy-build-image:
|
||||
@echo "Building Legacy UI build container"
|
||||
@docker build $(NOCACHE) $(QUIET) -t $(UI_LEGACY_BUILD_TAG) - < build-support/docker/Build-UI-Legacy.dockerfile
|
||||
|
||||
static-assets-docker: go-build-image
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/build-docker.sh static-assets
|
||||
|
||||
consul-docker: go-build-image
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/build-docker.sh consul
|
||||
|
||||
ui-docker: ui-build-image
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/build-docker.sh ui
|
||||
|
||||
ui-legacy-docker: ui-legacy-build-image
|
||||
@$(SHELL) $(CURDIR)/build-support/scripts/build-docker.sh ui-legacy
|
||||
|
||||
|
||||
.PHONY: all ci bin dev dist cov test cover format vet ui static-assets tools vendorfmt
|
||||
.PHONY: docker-images go-build-image ui-build-image ui-legacy-build-image static-assets-docker consul-docker ui-docker ui-legacy-docker version
|
||||
|
|
|
@ -25,6 +25,9 @@ Consul provides several key features:
|
|||
* **Multi-Datacenter** - Consul is built to be datacenter aware, and can
|
||||
support any number of regions without complex configuration.
|
||||
|
||||
* **Service Segmentation** - Consul Connect enables secure service-to-service
|
||||
communication with automatic TLS encryption and identity-based authorization.
|
||||
|
||||
Consul runs on Linux, Mac OS X, FreeBSD, Solaris, and Windows. A commercial
|
||||
version called [Consul Enterprise](https://www.hashicorp.com/products/consul)
|
||||
is also available.
|
||||
|
|
91
acl/acl.go
91
acl/acl.go
|
@ -60,6 +60,17 @@ type ACL interface {
|
|||
// EventWrite determines if a specific event may be fired.
|
||||
EventWrite(string) bool
|
||||
|
||||
// IntentionDefaultAllow determines the default authorized behavior
|
||||
// when no intentions match a Connect request.
|
||||
IntentionDefaultAllow() bool
|
||||
|
||||
// IntentionRead determines if a specific intention can be read.
|
||||
IntentionRead(string) bool
|
||||
|
||||
// IntentionWrite determines if a specific intention can be
|
||||
// created, modified, or deleted.
|
||||
IntentionWrite(string) bool
|
||||
|
||||
// KeyList checks for permission to list keys under a prefix
|
||||
KeyList(string) bool
|
||||
|
||||
|
@ -154,6 +165,18 @@ func (s *StaticACL) EventWrite(string) bool {
|
|||
return s.defaultAllow
|
||||
}
|
||||
|
||||
func (s *StaticACL) IntentionDefaultAllow() bool {
|
||||
return s.defaultAllow
|
||||
}
|
||||
|
||||
func (s *StaticACL) IntentionRead(string) bool {
|
||||
return s.defaultAllow
|
||||
}
|
||||
|
||||
func (s *StaticACL) IntentionWrite(string) bool {
|
||||
return s.defaultAllow
|
||||
}
|
||||
|
||||
func (s *StaticACL) KeyRead(string) bool {
|
||||
return s.defaultAllow
|
||||
}
|
||||
|
@ -275,6 +298,9 @@ type PolicyACL struct {
|
|||
// agentRules contains the agent policies
|
||||
agentRules *radix.Tree
|
||||
|
||||
// intentionRules contains the service intention policies
|
||||
intentionRules *radix.Tree
|
||||
|
||||
// keyRules contains the key policies
|
||||
keyRules *radix.Tree
|
||||
|
||||
|
@ -308,6 +334,7 @@ func New(parent ACL, policy *Policy, sentinel sentinel.Evaluator) (*PolicyACL, e
|
|||
p := &PolicyACL{
|
||||
parent: parent,
|
||||
agentRules: radix.New(),
|
||||
intentionRules: radix.New(),
|
||||
keyRules: radix.New(),
|
||||
nodeRules: radix.New(),
|
||||
serviceRules: radix.New(),
|
||||
|
@ -347,6 +374,25 @@ func New(parent ACL, policy *Policy, sentinel sentinel.Evaluator) (*PolicyACL, e
|
|||
sentinelPolicy: sp.Sentinel,
|
||||
}
|
||||
p.serviceRules.Insert(sp.Name, policyRule)
|
||||
|
||||
// Determine the intention. The intention could be blank (not set).
|
||||
// If the intention is not set, the value depends on the value of
|
||||
// the service policy.
|
||||
intention := sp.Intentions
|
||||
if intention == "" {
|
||||
switch sp.Policy {
|
||||
case PolicyRead, PolicyWrite:
|
||||
intention = PolicyRead
|
||||
default:
|
||||
intention = PolicyDeny
|
||||
}
|
||||
}
|
||||
|
||||
policyRule = PolicyRule{
|
||||
aclPolicy: intention,
|
||||
sentinelPolicy: sp.Sentinel,
|
||||
}
|
||||
p.intentionRules.Insert(sp.Name, policyRule)
|
||||
}
|
||||
|
||||
// Load the session policy
|
||||
|
@ -455,6 +501,51 @@ func (p *PolicyACL) EventWrite(name string) bool {
|
|||
return p.parent.EventWrite(name)
|
||||
}
|
||||
|
||||
// IntentionDefaultAllow returns whether the default behavior when there are
|
||||
// no matching intentions is to allow or deny.
|
||||
func (p *PolicyACL) IntentionDefaultAllow() bool {
|
||||
// We always go up, this can't be determined by a policy.
|
||||
return p.parent.IntentionDefaultAllow()
|
||||
}
|
||||
|
||||
// IntentionRead checks if writing (creating, updating, or deleting) of an
|
||||
// intention is allowed.
|
||||
func (p *PolicyACL) IntentionRead(prefix string) bool {
|
||||
// Check for an exact rule or catch-all
|
||||
_, rule, ok := p.intentionRules.LongestPrefix(prefix)
|
||||
if ok {
|
||||
pr := rule.(PolicyRule)
|
||||
switch pr.aclPolicy {
|
||||
case PolicyRead, PolicyWrite:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// No matching rule, use the parent.
|
||||
return p.parent.IntentionRead(prefix)
|
||||
}
|
||||
|
||||
// IntentionWrite checks if writing (creating, updating, or deleting) of an
|
||||
// intention is allowed.
|
||||
func (p *PolicyACL) IntentionWrite(prefix string) bool {
|
||||
// Check for an exact rule or catch-all
|
||||
_, rule, ok := p.intentionRules.LongestPrefix(prefix)
|
||||
if ok {
|
||||
pr := rule.(PolicyRule)
|
||||
switch pr.aclPolicy {
|
||||
case PolicyWrite:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// No matching rule, use the parent.
|
||||
return p.parent.IntentionWrite(prefix)
|
||||
}
|
||||
|
||||
// KeyRead returns if a key is allowed to be read
|
||||
func (p *PolicyACL) KeyRead(key string) bool {
|
||||
// Look for a matching rule
|
||||
|
|
|
@ -53,6 +53,12 @@ func TestStaticACL(t *testing.T) {
|
|||
if !all.EventWrite("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !all.IntentionDefaultAllow() {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !all.IntentionWrite("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !all.KeyRead("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
|
@ -123,6 +129,12 @@ func TestStaticACL(t *testing.T) {
|
|||
if none.EventWrite("") {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
if none.IntentionDefaultAllow() {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
if none.IntentionWrite("foo") {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
if none.KeyRead("foobar") {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
|
@ -187,6 +199,12 @@ func TestStaticACL(t *testing.T) {
|
|||
if !manage.EventWrite("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !manage.IntentionDefaultAllow() {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !manage.IntentionWrite("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !manage.KeyRead("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
|
@ -305,8 +323,14 @@ func TestPolicyACL(t *testing.T) {
|
|||
Policy: PolicyDeny,
|
||||
},
|
||||
&ServicePolicy{
|
||||
Name: "barfoo",
|
||||
Policy: PolicyWrite,
|
||||
Name: "barfoo",
|
||||
Policy: PolicyWrite,
|
||||
Intentions: PolicyWrite,
|
||||
},
|
||||
&ServicePolicy{
|
||||
Name: "intbaz",
|
||||
Policy: PolicyWrite,
|
||||
Intentions: PolicyDeny,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -344,6 +368,31 @@ func TestPolicyACL(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test the intentions
|
||||
type intentioncase struct {
|
||||
inp string
|
||||
read bool
|
||||
write bool
|
||||
}
|
||||
icases := []intentioncase{
|
||||
{"other", true, false},
|
||||
{"foo", true, false},
|
||||
{"bar", false, false},
|
||||
{"foobar", true, false},
|
||||
{"barfo", false, false},
|
||||
{"barfoo", true, true},
|
||||
{"barfoo2", true, true},
|
||||
{"intbaz", false, false},
|
||||
}
|
||||
for _, c := range icases {
|
||||
if c.read != acl.IntentionRead(c.inp) {
|
||||
t.Fatalf("Read fail: %#v", c)
|
||||
}
|
||||
if c.write != acl.IntentionWrite(c.inp) {
|
||||
t.Fatalf("Write fail: %#v", c)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the services
|
||||
type servicecase struct {
|
||||
inp string
|
||||
|
@ -414,6 +463,11 @@ func TestPolicyACL(t *testing.T) {
|
|||
t.Fatalf("Prepared query fail: %#v", c)
|
||||
}
|
||||
}
|
||||
|
||||
// Check default intentions bubble up
|
||||
if !acl.IntentionDefaultAllow() {
|
||||
t.Fatal("should allow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyACL_Parent(t *testing.T) {
|
||||
|
@ -567,6 +621,11 @@ func TestPolicyACL_Parent(t *testing.T) {
|
|||
if acl.Snapshot() {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
|
||||
// Check default intentions
|
||||
if acl.IntentionDefaultAllow() {
|
||||
t.Fatal("should not allow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyACL_Agent(t *testing.T) {
|
||||
|
|
|
@ -73,6 +73,11 @@ type ServicePolicy struct {
|
|||
Name string `hcl:",key"`
|
||||
Policy string
|
||||
Sentinel Sentinel
|
||||
|
||||
// Intentions is the policy for intentions where this service is the
|
||||
// destination. This may be empty, in which case the Policy determines
|
||||
// the intentions policy.
|
||||
Intentions string
|
||||
}
|
||||
|
||||
func (s *ServicePolicy) GoString() string {
|
||||
|
@ -197,6 +202,9 @@ func Parse(rules string, sentinel sentinel.Evaluator) (*Policy, error) {
|
|||
if !isPolicyValid(sp.Policy) {
|
||||
return nil, fmt.Errorf("Invalid service policy: %#v", sp)
|
||||
}
|
||||
if sp.Intentions != "" && !isPolicyValid(sp.Intentions) {
|
||||
return nil, fmt.Errorf("Invalid service intentions policy: %#v", sp)
|
||||
}
|
||||
if err := isSentinelValid(sentinel, sp.Policy, sp.Sentinel); err != nil {
|
||||
return nil, fmt.Errorf("Invalid service Sentinel policy: %#v, got error:%v", sp, err)
|
||||
}
|
||||
|
|
|
@ -4,8 +4,85 @@ import (
|
|||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParse_table(t *testing.T) {
|
||||
// Note that the table tests are newer than other tests. Many of the
|
||||
// other aspects of policy parsing are tested in older tests below. New
|
||||
// parsing tests should be added to this table as its easier to maintain.
|
||||
cases := []struct {
|
||||
Name string
|
||||
Input string
|
||||
Expected *Policy
|
||||
Err string
|
||||
}{
|
||||
{
|
||||
"service no intentions",
|
||||
`
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
}
|
||||
`,
|
||||
&Policy{
|
||||
Services: []*ServicePolicy{
|
||||
{
|
||||
Name: "foo",
|
||||
Policy: "write",
|
||||
},
|
||||
},
|
||||
},
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"service intentions",
|
||||
`
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
intentions = "read"
|
||||
}
|
||||
`,
|
||||
&Policy{
|
||||
Services: []*ServicePolicy{
|
||||
{
|
||||
Name: "foo",
|
||||
Policy: "write",
|
||||
Intentions: "read",
|
||||
},
|
||||
},
|
||||
},
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"service intention: invalid value",
|
||||
`
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
intentions = "foo"
|
||||
}
|
||||
`,
|
||||
nil,
|
||||
"service intentions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
actual, err := Parse(tc.Input, nil)
|
||||
assert.Equal(tc.Err != "", err != nil, err)
|
||||
if err != nil {
|
||||
assert.Contains(err.Error(), tc.Err)
|
||||
return
|
||||
}
|
||||
assert.Equal(tc.Expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestACLPolicy_Parse_HCL(t *testing.T) {
|
||||
inp := `
|
||||
agent "foo" {
|
||||
|
|
13
agent/acl.go
13
agent/acl.go
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/local"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/golang-lru"
|
||||
|
@ -239,6 +240,18 @@ func (a *Agent) resolveToken(id string) (acl.ACL, error) {
|
|||
return a.acls.lookupACL(a, id)
|
||||
}
|
||||
|
||||
// resolveProxyToken attempts to resolve an ACL ID to a local proxy token.
|
||||
// If a local proxy isn't found with that token, nil is returned.
|
||||
func (a *Agent) resolveProxyToken(id string) *local.ManagedProxy {
|
||||
for _, p := range a.State.Proxies() {
|
||||
if p.ProxyToken == id {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// vetServiceRegister makes sure the service registration action is allowed by
|
||||
// the given token.
|
||||
func (a *Agent) vetServiceRegister(token string, service *structs.NodeService) error {
|
||||
|
|
638
agent/agent.go
638
agent/agent.go
|
@ -21,16 +21,20 @@ import (
|
|||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/ae"
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/cache-types"
|
||||
"github.com/hashicorp/consul/agent/checks"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/consul"
|
||||
"github.com/hashicorp/consul/agent/local"
|
||||
"github.com/hashicorp/consul/agent/proxy"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/agent/systemd"
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/ipaddr"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/lib/file"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/consul/watch"
|
||||
|
@ -46,6 +50,9 @@ const (
|
|||
// Path to save agent service definitions
|
||||
servicesDir = "services"
|
||||
|
||||
// Path to save agent proxy definitions
|
||||
proxyDir = "proxies"
|
||||
|
||||
// Path to save local agent checks
|
||||
checksDir = "checks"
|
||||
checkStateDir = "checks/state"
|
||||
|
@ -73,6 +80,7 @@ type delegate interface {
|
|||
SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error
|
||||
Shutdown() error
|
||||
Stats() map[string]map[string]string
|
||||
ReloadConfig(config *consul.Config) error
|
||||
enterpriseDelegate
|
||||
}
|
||||
|
||||
|
@ -118,6 +126,9 @@ type Agent struct {
|
|||
// and the remote state.
|
||||
sync *ae.StateSyncer
|
||||
|
||||
// cache is the in-memory cache for data the Agent requests.
|
||||
cache *cache.Cache
|
||||
|
||||
// checkReapAfter maps the check ID to a timeout after which we should
|
||||
// reap its associated service
|
||||
checkReapAfter map[types.CheckID]time.Duration
|
||||
|
@ -194,6 +205,9 @@ type Agent struct {
|
|||
// be updated at runtime, so should always be used instead of going to
|
||||
// the configuration directly.
|
||||
tokens *token.Store
|
||||
|
||||
// proxyManager is the proxy process manager for managed Connect proxies.
|
||||
proxyManager *proxy.Manager
|
||||
}
|
||||
|
||||
func New(c *config.RuntimeConfig) (*Agent, error) {
|
||||
|
@ -246,6 +260,8 @@ func LocalConfig(cfg *config.RuntimeConfig) local.Config {
|
|||
NodeID: cfg.NodeID,
|
||||
NodeName: cfg.NodeName,
|
||||
TaggedAddresses: map[string]string{},
|
||||
ProxyBindMinPort: cfg.ConnectProxyBindMinPort,
|
||||
ProxyBindMaxPort: cfg.ConnectProxyBindMaxPort,
|
||||
}
|
||||
for k, v := range cfg.TaggedAddresses {
|
||||
lc.TaggedAddresses[k] = v
|
||||
|
@ -288,6 +304,9 @@ func (a *Agent) Start() error {
|
|||
// regular and on-demand state synchronizations (anti-entropy).
|
||||
a.sync = ae.NewStateSyncer(a.State, c.AEInterval, a.shutdownCh, a.logger)
|
||||
|
||||
// create the cache
|
||||
a.cache = cache.New(nil)
|
||||
|
||||
// create the config for the rpc server/client
|
||||
consulCfg, err := a.consulConfig()
|
||||
if err != nil {
|
||||
|
@ -324,10 +343,17 @@ func (a *Agent) Start() error {
|
|||
a.State.Delegate = a.delegate
|
||||
a.State.TriggerSyncChanges = a.sync.SyncChanges.Trigger
|
||||
|
||||
// Register the cache. We do this much later so the delegate is
|
||||
// populated from above.
|
||||
a.registerCache()
|
||||
|
||||
// Load checks/services/metadata.
|
||||
if err := a.loadServices(c); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.loadProxies(c); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.loadChecks(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -335,6 +361,28 @@ func (a *Agent) Start() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// create the proxy process manager and start it. This is purposely
|
||||
// done here after the local state above is loaded in so we can have
|
||||
// a more accurate initial state view.
|
||||
if !c.ConnectTestDisableManagedProxies {
|
||||
a.proxyManager = proxy.NewManager()
|
||||
a.proxyManager.AllowRoot = a.config.ConnectProxyAllowManagedRoot
|
||||
a.proxyManager.State = a.State
|
||||
a.proxyManager.Logger = a.logger
|
||||
if a.config.DataDir != "" {
|
||||
// DataDir is required for all non-dev mode agents, but we want
|
||||
// to allow setting the data dir for demos and so on for the agent,
|
||||
// so do the check above instead.
|
||||
a.proxyManager.DataDir = filepath.Join(a.config.DataDir, "proxy")
|
||||
|
||||
// Restore from our snapshot (if it exists)
|
||||
if err := a.proxyManager.Restore(a.proxyManager.SnapshotPath()); err != nil {
|
||||
a.logger.Printf("[WARN] agent: error restoring proxy state: %s", err)
|
||||
}
|
||||
}
|
||||
go a.proxyManager.Run()
|
||||
}
|
||||
|
||||
// Start watching for critical services to deregister, based on their
|
||||
// checks.
|
||||
go a.reapServices()
|
||||
|
@ -604,6 +652,16 @@ func (a *Agent) reloadWatches(cfg *config.RuntimeConfig) error {
|
|||
return fmt.Errorf("Handler type '%s' not recognized", params["handler_type"])
|
||||
}
|
||||
|
||||
// Don't let people use connect watches via this mechanism for now as it
|
||||
// needs thought about how to do securely and shouldn't be necessary. Note
|
||||
// that if the type assertion fails an type is not a string then
|
||||
// ParseExample below will error so we don't need to handle that case.
|
||||
if typ, ok := params["type"].(string); ok {
|
||||
if strings.HasPrefix(typ, "connect_") {
|
||||
return fmt.Errorf("Watch type %s is not allowed in agent config", typ)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the watches, excluding 'handler' and 'args'
|
||||
wp, err := watch.ParseExempt(params, []string{"handler", "args"})
|
||||
if err != nil {
|
||||
|
@ -877,6 +935,47 @@ func (a *Agent) consulConfig() (*consul.Config, error) {
|
|||
base.TLSCipherSuites = a.config.TLSCipherSuites
|
||||
base.TLSPreferServerCipherSuites = a.config.TLSPreferServerCipherSuites
|
||||
|
||||
// Copy the Connect CA bootstrap config
|
||||
if a.config.ConnectEnabled {
|
||||
base.ConnectEnabled = true
|
||||
|
||||
// Allow config to specify cluster_id provided it's a valid UUID. This is
|
||||
// meant only for tests where a deterministic ID makes fixtures much simpler
|
||||
// to work with but since it's only read on initial cluster bootstrap it's not
|
||||
// that much of a liability in production. The worst a user could do is
|
||||
// configure logically separate clusters with same ID by mistake but we can
|
||||
// avoid documenting this is even an option.
|
||||
if clusterID, ok := a.config.ConnectCAConfig["cluster_id"]; ok {
|
||||
if cIDStr, ok := clusterID.(string); ok {
|
||||
if _, err := uuid.ParseUUID(cIDStr); err == nil {
|
||||
// Valid UUID configured, use that
|
||||
base.CAConfig.ClusterID = cIDStr
|
||||
}
|
||||
}
|
||||
if base.CAConfig.ClusterID == "" {
|
||||
// If the tried to specify an ID but typoed it don't ignore as they will
|
||||
// then bootstrap with a new ID and have to throw away the whole cluster
|
||||
// and start again.
|
||||
a.logger.Println("[ERR] connect CA config cluster_id specified but " +
|
||||
"is not a valid UUID, aborting startup")
|
||||
return nil, fmt.Errorf("cluster_id was supplied but was not a valid UUID")
|
||||
}
|
||||
}
|
||||
|
||||
if a.config.ConnectCAProvider != "" {
|
||||
base.CAConfig.Provider = a.config.ConnectCAProvider
|
||||
|
||||
// Merge with the default config if it's the consul provider.
|
||||
if a.config.ConnectCAProvider == "consul" {
|
||||
for k, v := range a.config.ConnectCAConfig {
|
||||
base.CAConfig.Config[k] = v
|
||||
}
|
||||
} else {
|
||||
base.CAConfig.Config = a.config.ConnectCAConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup the user event callback
|
||||
base.UserEventHandler = func(e serf.UserEvent) {
|
||||
select {
|
||||
|
@ -1009,7 +1108,7 @@ func (a *Agent) setupNodeID(config *config.RuntimeConfig) error {
|
|||
}
|
||||
|
||||
// For dev mode we have no filesystem access so just make one.
|
||||
if a.config.DevMode {
|
||||
if a.config.DataDir == "" {
|
||||
id, err := a.makeNodeID()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1223,6 +1322,24 @@ func (a *Agent) ShutdownAgent() error {
|
|||
chk.Stop()
|
||||
}
|
||||
|
||||
// Stop the proxy manager
|
||||
if a.proxyManager != nil {
|
||||
// If persistence is disabled (implies DevMode but a subset of DevMode) then
|
||||
// don't leave the proxies running since the agent will not be able to
|
||||
// recover them later.
|
||||
if a.config.DataDir == "" {
|
||||
a.logger.Printf("[WARN] agent: dev mode disabled persistence, killing " +
|
||||
"all proxies since we can't recover them")
|
||||
if err := a.proxyManager.Kill(); err != nil {
|
||||
a.logger.Printf("[WARN] agent: error shutting down proxy manager: %s", err)
|
||||
}
|
||||
} else {
|
||||
if err := a.proxyManager.Close(); err != nil {
|
||||
a.logger.Printf("[WARN] agent: error shutting down proxy manager: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if a.delegate != nil {
|
||||
err = a.delegate.Shutdown()
|
||||
|
@ -1492,7 +1609,7 @@ func (a *Agent) persistService(service *structs.NodeService) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return writeFileAtomic(svcPath, encoded)
|
||||
return file.WriteAtomic(svcPath, encoded)
|
||||
}
|
||||
|
||||
// purgeService removes a persisted service definition file from the data dir
|
||||
|
@ -1504,6 +1621,39 @@ func (a *Agent) purgeService(serviceID string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// persistedProxy is used to wrap a proxy definition and bundle it with an Proxy
|
||||
// token so we can continue to authenticate the running proxy after a restart.
|
||||
type persistedProxy struct {
|
||||
ProxyToken string
|
||||
Proxy *structs.ConnectManagedProxy
|
||||
}
|
||||
|
||||
// persistProxy saves a proxy definition to a JSON file in the data dir
|
||||
func (a *Agent) persistProxy(proxy *local.ManagedProxy) error {
|
||||
proxyPath := filepath.Join(a.config.DataDir, proxyDir,
|
||||
stringHash(proxy.Proxy.ProxyService.ID))
|
||||
|
||||
wrapped := persistedProxy{
|
||||
ProxyToken: proxy.ProxyToken,
|
||||
Proxy: proxy.Proxy,
|
||||
}
|
||||
encoded, err := json.Marshal(wrapped)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return file.WriteAtomic(proxyPath, encoded)
|
||||
}
|
||||
|
||||
// purgeProxy removes a persisted proxy definition file from the data dir
|
||||
func (a *Agent) purgeProxy(proxyID string) error {
|
||||
proxyPath := filepath.Join(a.config.DataDir, proxyDir, stringHash(proxyID))
|
||||
if _, err := os.Stat(proxyPath); err == nil {
|
||||
return os.Remove(proxyPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistCheck saves a check definition to the local agent's state directory
|
||||
func (a *Agent) persistCheck(check *structs.HealthCheck, chkType *structs.CheckType) error {
|
||||
checkPath := filepath.Join(a.config.DataDir, checksDir, checkIDHash(check.CheckID))
|
||||
|
@ -1520,7 +1670,7 @@ func (a *Agent) persistCheck(check *structs.HealthCheck, chkType *structs.CheckT
|
|||
return err
|
||||
}
|
||||
|
||||
return writeFileAtomic(checkPath, encoded)
|
||||
return file.WriteAtomic(checkPath, encoded)
|
||||
}
|
||||
|
||||
// purgeCheck removes a persisted check definition file from the data dir
|
||||
|
@ -1532,43 +1682,6 @@ func (a *Agent) purgeCheck(checkID types.CheckID) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// writeFileAtomic writes the given contents to a temporary file in the same
|
||||
// directory, does an fsync and then renames the file to its real path
|
||||
func writeFileAtomic(path string, contents []byte) error {
|
||||
uuid, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tempPath := fmt.Sprintf("%s-%s.tmp", path, uuid)
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
fh, err := os.OpenFile(tempPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fh.Write(contents); err != nil {
|
||||
fh.Close()
|
||||
os.Remove(tempPath)
|
||||
return err
|
||||
}
|
||||
if err := fh.Sync(); err != nil {
|
||||
fh.Close()
|
||||
os.Remove(tempPath)
|
||||
return err
|
||||
}
|
||||
if err := fh.Close(); err != nil {
|
||||
os.Remove(tempPath)
|
||||
return err
|
||||
}
|
||||
if err := os.Rename(tempPath, path); err != nil {
|
||||
os.Remove(tempPath)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddService is used to add a service entry.
|
||||
// This entry is persistent and the agent will make a best effort to
|
||||
// ensure it is registered
|
||||
|
@ -1622,7 +1735,7 @@ func (a *Agent) AddService(service *structs.NodeService, chkTypes []*structs.Che
|
|||
a.State.AddService(service, token)
|
||||
|
||||
// Persist the service to a file
|
||||
if persist && !a.config.DevMode {
|
||||
if persist && a.config.DataDir != "" {
|
||||
if err := a.persistService(service); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1921,7 +2034,7 @@ func (a *Agent) AddCheck(check *structs.HealthCheck, chkType *structs.CheckType,
|
|||
}
|
||||
|
||||
// Persist the check
|
||||
if persist && !a.config.DevMode {
|
||||
if persist && a.config.DataDir != "" {
|
||||
return a.persistCheck(check, chkType)
|
||||
}
|
||||
|
||||
|
@ -1973,6 +2086,277 @@ func (a *Agent) RemoveCheck(checkID types.CheckID, persist bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// AddProxy adds a new local Connect Proxy instance to be managed by the agent.
|
||||
//
|
||||
// It REQUIRES that the service that is being proxied is already present in the
|
||||
// local state. Note that this is only used for agent-managed proxies so we can
|
||||
// ensure that we always make this true. For externally managed and registered
|
||||
// proxies we explicitly allow the proxy to be registered first to make
|
||||
// bootstrap ordering of a new service simpler but the same is not true here
|
||||
// since this is only ever called when setting up a _managed_ proxy which was
|
||||
// registered as part of a service registration either from config or HTTP API
|
||||
// call.
|
||||
//
|
||||
// The restoredProxyToken argument should only be used when restoring proxy
|
||||
// definitions from disk; new proxies must leave it blank to get a new token
|
||||
// assigned. We need to restore from disk to enable to continue authenticating
|
||||
// running proxies that already had that credential injected.
|
||||
func (a *Agent) AddProxy(proxy *structs.ConnectManagedProxy, persist bool,
|
||||
restoredProxyToken string) error {
|
||||
// Lookup the target service token in state if there is one.
|
||||
token := a.State.ServiceToken(proxy.TargetServiceID)
|
||||
|
||||
// Copy the basic proxy structure so it isn't modified w/ defaults
|
||||
proxyCopy := *proxy
|
||||
proxy = &proxyCopy
|
||||
if err := a.applyProxyDefaults(proxy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the proxy to local state first since we may need to assign a port which
|
||||
// needs to be coordinate under state lock. AddProxy will generate the
|
||||
// NodeService for the proxy populated with the allocated (or configured) port
|
||||
// and an ID, but it doesn't add it to the agent directly since that could
|
||||
// deadlock and we may need to coordinate adding it and persisting etc.
|
||||
proxyState, err := a.State.AddProxy(proxy, token, restoredProxyToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
proxyService := proxyState.Proxy.ProxyService
|
||||
|
||||
// Register proxy TCP check. The built in proxy doesn't listen publically
|
||||
// until it's loaded certs so this ensures we won't route traffic until it's
|
||||
// ready.
|
||||
proxyCfg, err := a.applyProxyConfigDefaults(proxyState.Proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chkTypes := []*structs.CheckType{
|
||||
&structs.CheckType{
|
||||
Name: "Connect Proxy Listening",
|
||||
TCP: fmt.Sprintf("%s:%d", proxyCfg["bind_address"],
|
||||
proxyCfg["bind_port"]),
|
||||
Interval: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
err = a.AddService(proxyService, chkTypes, persist, token)
|
||||
if err != nil {
|
||||
// Remove the state too
|
||||
a.State.RemoveProxy(proxyService.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
// Persist the proxy
|
||||
if persist && a.config.DataDir != "" {
|
||||
return a.persistProxy(proxyState)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyProxyConfigDefaults takes a *structs.ConnectManagedProxy and returns
|
||||
// it's Config map merged with any defaults from the Agent's config. It would be
|
||||
// nicer if this were defined as a method on structs.ConnectManagedProxy but we
|
||||
// can't do that because ot the import cycle it causes with agent/config.
|
||||
func (a *Agent) applyProxyConfigDefaults(p *structs.ConnectManagedProxy) (map[string]interface{}, error) {
|
||||
if p == nil || p.ProxyService == nil {
|
||||
// Should never happen but protect from panic
|
||||
return nil, fmt.Errorf("invalid proxy state")
|
||||
}
|
||||
|
||||
// Lookup the target service
|
||||
target := a.State.Service(p.TargetServiceID)
|
||||
if target == nil {
|
||||
// Can happen during deregistration race between proxy and scheduler.
|
||||
return nil, fmt.Errorf("unknown target service ID: %s", p.TargetServiceID)
|
||||
}
|
||||
|
||||
// Merge globals defaults
|
||||
config := make(map[string]interface{})
|
||||
for k, v := range a.config.ConnectProxyDefaultConfig {
|
||||
if _, ok := config[k]; !ok {
|
||||
config[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Copy config from the proxy
|
||||
for k, v := range p.Config {
|
||||
config[k] = v
|
||||
}
|
||||
|
||||
// Set defaults for anything that is still not specified but required.
|
||||
// Note that these are not included in the content hash. Since we expect
|
||||
// them to be static in general but some like the default target service
|
||||
// port might not be. In that edge case services can set that explicitly
|
||||
// when they re-register which will be caught though.
|
||||
if _, ok := config["bind_port"]; !ok {
|
||||
config["bind_port"] = p.ProxyService.Port
|
||||
}
|
||||
if _, ok := config["bind_address"]; !ok {
|
||||
// Default to binding to the same address the agent is configured to
|
||||
// bind to.
|
||||
config["bind_address"] = a.config.BindAddr.String()
|
||||
}
|
||||
if _, ok := config["local_service_address"]; !ok {
|
||||
// Default to localhost and the port the service registered with
|
||||
config["local_service_address"] = fmt.Sprintf("127.0.0.1:%d", target.Port)
|
||||
}
|
||||
|
||||
// Basic type conversions for expected types.
|
||||
if raw, ok := config["bind_port"]; ok {
|
||||
switch v := raw.(type) {
|
||||
case float64:
|
||||
// Common since HCL/JSON parse as float64
|
||||
config["bind_port"] = int(v)
|
||||
|
||||
// NOTE(mitchellh): No default case since errors and validation
|
||||
// are handled by the ServiceDefinition.Validate function.
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// applyProxyDefaults modifies the given proxy by applying any configured
|
||||
// defaults, such as the default execution mode, command, etc.
|
||||
func (a *Agent) applyProxyDefaults(proxy *structs.ConnectManagedProxy) error {
|
||||
// Set the default exec mode
|
||||
if proxy.ExecMode == structs.ProxyExecModeUnspecified {
|
||||
mode, err := structs.NewProxyExecMode(a.config.ConnectProxyDefaultExecMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxy.ExecMode = mode
|
||||
}
|
||||
if proxy.ExecMode == structs.ProxyExecModeUnspecified {
|
||||
proxy.ExecMode = structs.ProxyExecModeDaemon
|
||||
}
|
||||
|
||||
// Set the default command to the globally configured default
|
||||
if len(proxy.Command) == 0 {
|
||||
switch proxy.ExecMode {
|
||||
case structs.ProxyExecModeDaemon:
|
||||
proxy.Command = a.config.ConnectProxyDefaultDaemonCommand
|
||||
|
||||
case structs.ProxyExecModeScript:
|
||||
proxy.Command = a.config.ConnectProxyDefaultScriptCommand
|
||||
}
|
||||
}
|
||||
|
||||
// If there is no globally configured default we need to get the
|
||||
// default command so we can do "consul connect proxy"
|
||||
if len(proxy.Command) == 0 {
|
||||
command, err := defaultProxyCommand(a.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxy.Command = command
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProxy stops and removes a local proxy instance.
|
||||
func (a *Agent) RemoveProxy(proxyID string, persist bool) error {
|
||||
// Validate proxyID
|
||||
if proxyID == "" {
|
||||
return fmt.Errorf("proxyID missing")
|
||||
}
|
||||
|
||||
// Remove the proxy from the local state
|
||||
p, err := a.State.RemoveProxy(proxyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the proxy service as well. The proxy ID is also the ID
|
||||
// of the servie, but we might as well use the service pointer.
|
||||
if err := a.RemoveService(p.Proxy.ProxyService.ID, persist); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if persist && a.config.DataDir != "" {
|
||||
return a.purgeProxy(proxyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyProxyToken takes a token and attempts to verify it against the
|
||||
// targetService name. If targetProxy is specified, then the local proxy token
|
||||
// must exactly match the given proxy ID. cert, config, etc.).
|
||||
//
|
||||
// The given token may be a local-only proxy token or it may be an ACL token. We
|
||||
// will attempt to verify the local proxy token first.
|
||||
//
|
||||
// The effective ACL token is returned along with a boolean which is true if the
|
||||
// match was against a proxy token rather than an ACL token, and any error. In
|
||||
// the case the token matches a proxy token, then the ACL token used to register
|
||||
// that proxy's target service is returned for use in any RPC calls the proxy
|
||||
// needs to make on behalf of that service. If the token was an ACL token
|
||||
// already then it is always returned. Provided error is nil, a valid ACL token
|
||||
// is always returned.
|
||||
func (a *Agent) verifyProxyToken(token, targetService,
|
||||
targetProxy string) (string, bool, error) {
|
||||
// If we specify a target proxy, we look up that proxy directly. Otherwise,
|
||||
// we resolve with any proxy we can find.
|
||||
var proxy *local.ManagedProxy
|
||||
if targetProxy != "" {
|
||||
proxy = a.State.Proxy(targetProxy)
|
||||
if proxy == nil {
|
||||
return "", false, fmt.Errorf("unknown proxy service ID: %q", targetProxy)
|
||||
}
|
||||
|
||||
// If the token DOESN'T match, then we reset the proxy which will
|
||||
// cause the logic below to fall back to normal ACLs. Otherwise,
|
||||
// we keep the proxy set because we also have to verify that the
|
||||
// target service matches on the proxy.
|
||||
if token != proxy.ProxyToken {
|
||||
proxy = nil
|
||||
}
|
||||
} else {
|
||||
proxy = a.resolveProxyToken(token)
|
||||
}
|
||||
|
||||
// The existence of a token isn't enough, we also need to verify
|
||||
// that the service name of the matching proxy matches our target
|
||||
// service.
|
||||
if proxy != nil {
|
||||
// Get the target service since we only have the name. The nil
|
||||
// check below should never be true since a proxy token always
|
||||
// represents the existence of a local service.
|
||||
target := a.State.Service(proxy.Proxy.TargetServiceID)
|
||||
if target == nil {
|
||||
return "", false, fmt.Errorf("proxy target service not found: %q",
|
||||
proxy.Proxy.TargetServiceID)
|
||||
}
|
||||
|
||||
if target.Service != targetService {
|
||||
return "", false, acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Resolve the actual ACL token used to register the proxy/service and
|
||||
// return that for use in RPC calls.
|
||||
return a.State.ServiceToken(proxy.Proxy.TargetServiceID), true, nil
|
||||
}
|
||||
|
||||
// Doesn't match, we have to do a full token resolution. The required
|
||||
// permission for any proxy-related endpoint is service:write, since
|
||||
// to register a proxy you require that permission and sensitive data
|
||||
// is usually present in the configuration.
|
||||
rule, err := a.resolveToken(token)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
if rule != nil && !rule.ServiceWrite(targetService, nil) {
|
||||
return "", false, acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return token, false, nil
|
||||
}
|
||||
|
||||
func (a *Agent) cancelCheckMonitors(checkID types.CheckID) {
|
||||
// Stop any monitors
|
||||
delete(a.checkReapAfter, checkID)
|
||||
|
@ -2017,7 +2401,7 @@ func (a *Agent) updateTTLCheck(checkID types.CheckID, status, output string) err
|
|||
check.SetStatus(status, output)
|
||||
|
||||
// We don't write any files in dev mode so bail here.
|
||||
if a.config.DevMode {
|
||||
if a.config.DataDir == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -2366,6 +2750,96 @@ func (a *Agent) unloadChecks() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// loadProxies will load connect proxy definitions from configuration and
|
||||
// persisted definitions on disk, and load them into the local agent.
|
||||
func (a *Agent) loadProxies(conf *config.RuntimeConfig) error {
|
||||
for _, svc := range conf.Services {
|
||||
if svc.Connect != nil {
|
||||
proxy, err := svc.ConnectManagedProxy()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding proxy: %s", err)
|
||||
}
|
||||
if proxy == nil {
|
||||
continue
|
||||
}
|
||||
if err := a.AddProxy(proxy, false, ""); err != nil {
|
||||
return fmt.Errorf("failed adding proxy: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load any persisted proxies
|
||||
proxyDir := filepath.Join(a.config.DataDir, proxyDir)
|
||||
files, err := ioutil.ReadDir(proxyDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Failed reading proxies dir %q: %s", proxyDir, err)
|
||||
}
|
||||
for _, fi := range files {
|
||||
// Skip all dirs
|
||||
if fi.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip all partially written temporary files
|
||||
if strings.HasSuffix(fi.Name(), "tmp") {
|
||||
a.logger.Printf("[WARN] agent: Ignoring temporary proxy file %v", fi.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
// Open the file for reading
|
||||
file := filepath.Join(proxyDir, fi.Name())
|
||||
fh, err := os.Open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed opening proxy file %q: %s", file, err)
|
||||
}
|
||||
|
||||
// Read the contents into a buffer
|
||||
buf, err := ioutil.ReadAll(fh)
|
||||
fh.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading proxy file %q: %s", file, err)
|
||||
}
|
||||
|
||||
// Try decoding the proxy definition
|
||||
var p persistedProxy
|
||||
if err := json.Unmarshal(buf, &p); err != nil {
|
||||
a.logger.Printf("[ERR] agent: Failed decoding proxy file %q: %s", file, err)
|
||||
continue
|
||||
}
|
||||
proxyID := p.Proxy.ProxyService.ID
|
||||
|
||||
if a.State.Proxy(proxyID) != nil {
|
||||
// Purge previously persisted proxy. This allows config to be preferred
|
||||
// over services persisted from the API.
|
||||
a.logger.Printf("[DEBUG] agent: proxy %q exists, not restoring from %q",
|
||||
proxyID, file)
|
||||
if err := a.purgeProxy(proxyID); err != nil {
|
||||
return fmt.Errorf("failed purging proxy %q: %s", proxyID, err)
|
||||
}
|
||||
} else {
|
||||
a.logger.Printf("[DEBUG] agent: restored proxy definition %q from %q",
|
||||
proxyID, file)
|
||||
if err := a.AddProxy(p.Proxy, false, p.ProxyToken); err != nil {
|
||||
return fmt.Errorf("failed adding proxy %q: %s", proxyID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// unloadProxies will deregister all proxies known to the local agent.
|
||||
func (a *Agent) unloadProxies() error {
|
||||
for id := range a.State.Proxies() {
|
||||
if err := a.RemoveProxy(id, false); err != nil {
|
||||
return fmt.Errorf("Failed deregistering proxy '%s': %s", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// snapshotCheckState is used to snapshot the current state of the health
|
||||
// checks. This is done before we reload our checks, so that we can properly
|
||||
// restore into the same state.
|
||||
|
@ -2491,6 +2965,11 @@ func (a *Agent) DisableNodeMaintenance() {
|
|||
a.logger.Printf("[INFO] agent: Node left maintenance mode")
|
||||
}
|
||||
|
||||
func (a *Agent) loadLimits(conf *config.RuntimeConfig) {
|
||||
a.config.RPCRateLimit = conf.RPCRateLimit
|
||||
a.config.RPCMaxBurst = conf.RPCMaxBurst
|
||||
}
|
||||
|
||||
func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
|
||||
// Bulk update the services and checks
|
||||
a.PauseSync()
|
||||
|
@ -2502,6 +2981,9 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
|
|||
|
||||
// First unload all checks, services, and metadata. This lets us begin the reload
|
||||
// with a clean slate.
|
||||
if err := a.unloadProxies(); err != nil {
|
||||
return fmt.Errorf("Failed unloading proxies: %s", err)
|
||||
}
|
||||
if err := a.unloadServices(); err != nil {
|
||||
return fmt.Errorf("Failed unloading services: %s", err)
|
||||
}
|
||||
|
@ -2514,6 +2996,9 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
|
|||
if err := a.loadServices(newCfg); err != nil {
|
||||
return fmt.Errorf("Failed reloading services: %s", err)
|
||||
}
|
||||
if err := a.loadProxies(newCfg); err != nil {
|
||||
return fmt.Errorf("Failed reloading proxies: %s", err)
|
||||
}
|
||||
if err := a.loadChecks(newCfg); err != nil {
|
||||
return fmt.Errorf("Failed reloading checks: %s", err)
|
||||
}
|
||||
|
@ -2525,10 +3010,75 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
|
|||
return fmt.Errorf("Failed reloading watches: %v", err)
|
||||
}
|
||||
|
||||
a.loadLimits(newCfg)
|
||||
|
||||
// create the config for the rpc server/client
|
||||
consulCfg, err := a.consulConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := a.delegate.ReloadConfig(consulCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update filtered metrics
|
||||
metrics.UpdateFilter(newCfg.TelemetryAllowedPrefixes, newCfg.TelemetryBlockedPrefixes)
|
||||
metrics.UpdateFilter(newCfg.Telemetry.AllowedPrefixes,
|
||||
newCfg.Telemetry.BlockedPrefixes)
|
||||
|
||||
a.State.SetDiscardCheckOutput(newCfg.DiscardCheckOutput)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerCache configures the cache and registers all the supported
|
||||
// types onto the cache. This is NOT safe to call multiple times so
|
||||
// care should be taken to call this exactly once after the cache
|
||||
// field has been initialized.
|
||||
func (a *Agent) registerCache() {
|
||||
a.cache.RegisterType(cachetype.ConnectCARootName, &cachetype.ConnectCARoot{
|
||||
RPC: a.delegate,
|
||||
}, &cache.RegisterOptions{
|
||||
// Maintain a blocking query, retry dropped connections quickly
|
||||
Refresh: true,
|
||||
RefreshTimer: 0 * time.Second,
|
||||
RefreshTimeout: 10 * time.Minute,
|
||||
})
|
||||
|
||||
a.cache.RegisterType(cachetype.ConnectCALeafName, &cachetype.ConnectCALeaf{
|
||||
RPC: a.delegate,
|
||||
Cache: a.cache,
|
||||
}, &cache.RegisterOptions{
|
||||
// Maintain a blocking query, retry dropped connections quickly
|
||||
Refresh: true,
|
||||
RefreshTimer: 0 * time.Second,
|
||||
RefreshTimeout: 10 * time.Minute,
|
||||
})
|
||||
|
||||
a.cache.RegisterType(cachetype.IntentionMatchName, &cachetype.IntentionMatch{
|
||||
RPC: a.delegate,
|
||||
}, &cache.RegisterOptions{
|
||||
// Maintain a blocking query, retry dropped connections quickly
|
||||
Refresh: true,
|
||||
RefreshTimer: 0 * time.Second,
|
||||
RefreshTimeout: 10 * time.Minute,
|
||||
})
|
||||
}
|
||||
|
||||
// defaultProxyCommand returns the default Connect managed proxy command.
|
||||
func defaultProxyCommand(agentCfg *config.RuntimeConfig) ([]string, error) {
|
||||
// Get the path to the current exectuable. This is cached once by the
|
||||
// library so this is effectively just a variable read.
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// "consul connect proxy" default value for managed daemon proxy
|
||||
cmd := []string{execPath, "connect", "proxy"}
|
||||
|
||||
if agentCfg != nil && agentCfg.LogLevel != "INFO" {
|
||||
cmd = append(cmd, "-log-level", agentCfg.LogLevel)
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
|
|
@ -4,12 +4,21 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/mitchellh/hashstructure"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/cache-types"
|
||||
"github.com/hashicorp/consul/agent/checks"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/ipaddr"
|
||||
|
@ -97,7 +106,7 @@ func (s *HTTPServer) AgentMetrics(resp http.ResponseWriter, req *http.Request) (
|
|||
return nil, acl.ErrPermissionDenied
|
||||
}
|
||||
if enablePrometheusOutput(req) {
|
||||
if s.agent.config.TelemetryPrometheusRetentionTime < 1 {
|
||||
if s.agent.config.Telemetry.PrometheusRetentionTime < 1 {
|
||||
resp.WriteHeader(http.StatusUnsupportedMediaType)
|
||||
fmt.Fprint(resp, "Prometheus is not enable since its retention time is not positive")
|
||||
return nil, nil
|
||||
|
@ -153,25 +162,49 @@ func (s *HTTPServer) AgentServices(resp http.ResponseWriter, req *http.Request)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
proxies := s.agent.State.Proxies()
|
||||
|
||||
// Convert into api.AgentService since that includes Connect config but so far
|
||||
// NodeService doesn't need to internally. They are otherwise identical since
|
||||
// that is the struct used in client for reading the one we output here
|
||||
// anyway.
|
||||
agentSvcs := make(map[string]*api.AgentService)
|
||||
|
||||
// Use empty list instead of nil
|
||||
for id, s := range services {
|
||||
if s.Tags == nil || s.Meta == nil {
|
||||
clone := *s
|
||||
if s.Tags == nil {
|
||||
clone.Tags = make([]string, 0)
|
||||
} else {
|
||||
clone.Tags = s.Tags
|
||||
}
|
||||
if s.Meta == nil {
|
||||
clone.Meta = make(map[string]string)
|
||||
} else {
|
||||
clone.Meta = s.Meta
|
||||
}
|
||||
services[id] = &clone
|
||||
as := &api.AgentService{
|
||||
Kind: api.ServiceKind(s.Kind),
|
||||
ID: s.ID,
|
||||
Service: s.Service,
|
||||
Tags: s.Tags,
|
||||
Meta: s.Meta,
|
||||
Port: s.Port,
|
||||
Address: s.Address,
|
||||
EnableTagOverride: s.EnableTagOverride,
|
||||
CreateIndex: s.CreateIndex,
|
||||
ModifyIndex: s.ModifyIndex,
|
||||
ProxyDestination: s.ProxyDestination,
|
||||
}
|
||||
if as.Tags == nil {
|
||||
as.Tags = []string{}
|
||||
}
|
||||
if as.Meta == nil {
|
||||
as.Meta = map[string]string{}
|
||||
}
|
||||
// Attach Connect configs if the exist
|
||||
if proxy, ok := proxies[id+"-proxy"]; ok {
|
||||
as.Connect = &api.AgentServiceConnect{
|
||||
Proxy: &api.AgentServiceConnectProxy{
|
||||
ExecMode: api.ProxyExecMode(proxy.Proxy.ExecMode.String()),
|
||||
Command: proxy.Proxy.Command,
|
||||
Config: proxy.Proxy.Config,
|
||||
},
|
||||
}
|
||||
}
|
||||
agentSvcs[id] = as
|
||||
}
|
||||
|
||||
return services, nil
|
||||
return agentSvcs, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) AgentChecks(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
|
@ -554,6 +587,14 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Run validation. This is the same validation that would happen on
|
||||
// the catalog endpoint so it helps ensure the sync will work properly.
|
||||
if err := ns.Validate(); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Verify the check type.
|
||||
chkTypes, err := args.CheckTypes()
|
||||
if err != nil {
|
||||
|
@ -576,10 +617,30 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Get any proxy registrations
|
||||
proxy, err := args.ConnectManagedProxy()
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If we have a proxy, verify that we're allowed to add a proxy via the API
|
||||
if proxy != nil && !s.agent.config.ConnectProxyAllowManagedAPIRegistration {
|
||||
return nil, &BadRequestError{
|
||||
Reason: "Managed proxy registration via the API is disallowed."}
|
||||
}
|
||||
|
||||
// Add the service.
|
||||
if err := s.agent.AddService(ns, chkTypes, true, token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Add proxy (which will add proxy service so do it before we trigger sync)
|
||||
if proxy != nil {
|
||||
if err := s.agent.AddProxy(proxy, true, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
s.syncChanges()
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -594,9 +655,27 @@ func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Verify this isn't a proxy
|
||||
if s.agent.State.Proxy(serviceID) != nil {
|
||||
return nil, &BadRequestError{
|
||||
Reason: "Managed proxy service cannot be deregistered directly. " +
|
||||
"Deregister the service that has a managed proxy to automatically " +
|
||||
"deregister the managed proxy itself."}
|
||||
}
|
||||
|
||||
if err := s.agent.RemoveService(serviceID, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove the associated managed proxy if it exists
|
||||
for proxyID, p := range s.agent.State.Proxies() {
|
||||
if p.Proxy.TargetServiceID == serviceID {
|
||||
if err := s.agent.RemoveProxy(proxyID, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.syncChanges()
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -828,3 +907,394 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in
|
|||
s.agent.logger.Printf("[INFO] agent: Updated agent's ACL token %q", target)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// AgentConnectCARoots returns the trusted CA roots.
|
||||
func (s *HTTPServer) AgentConnectCARoots(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
var args structs.DCSpecificRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
raw, m, err := s.agent.cache.Get(cachetype.ConnectCARootName, &args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer setCacheMeta(resp, &m)
|
||||
|
||||
// Add cache hit
|
||||
|
||||
reply, ok := raw.(*structs.IndexedCARoots)
|
||||
if !ok {
|
||||
// This should never happen, but we want to protect against panics
|
||||
return nil, fmt.Errorf("internal error: response type not correct")
|
||||
}
|
||||
defer setMeta(resp, &reply.QueryMeta)
|
||||
|
||||
return *reply, nil
|
||||
}
|
||||
|
||||
// AgentConnectCALeafCert returns the certificate bundle for a service
|
||||
// instance. This supports blocking queries to update the returned bundle.
|
||||
func (s *HTTPServer) AgentConnectCALeafCert(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Get the service name. Note that this is the name of the sevice,
|
||||
// not the ID of the service instance.
|
||||
serviceName := strings.TrimPrefix(req.URL.Path, "/v1/agent/connect/ca/leaf/")
|
||||
|
||||
args := cachetype.ConnectCALeafRequest{
|
||||
Service: serviceName, // Need name not ID
|
||||
}
|
||||
var qOpts structs.QueryOptions
|
||||
// Store DC in the ConnectCALeafRequest but query opts separately
|
||||
if done := s.parse(resp, req, &args.Datacenter, &qOpts); done {
|
||||
return nil, nil
|
||||
}
|
||||
args.MinQueryIndex = qOpts.MinQueryIndex
|
||||
|
||||
// Verify the proxy token. This will check both the local proxy token
|
||||
// as well as the ACL if the token isn't local.
|
||||
effectiveToken, _, err := s.agent.verifyProxyToken(qOpts.Token, serviceName, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args.Token = effectiveToken
|
||||
|
||||
raw, m, err := s.agent.cache.Get(cachetype.ConnectCALeafName, &args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer setCacheMeta(resp, &m)
|
||||
|
||||
reply, ok := raw.(*structs.IssuedCert)
|
||||
if !ok {
|
||||
// This should never happen, but we want to protect against panics
|
||||
return nil, fmt.Errorf("internal error: response type not correct")
|
||||
}
|
||||
setIndex(resp, reply.ModifyIndex)
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// GET /v1/agent/connect/proxy/:proxy_service_id
|
||||
//
|
||||
// Returns the local proxy config for the identified proxy. Requires token=
|
||||
// param with the correct local ProxyToken (not ACL token).
|
||||
func (s *HTTPServer) AgentConnectProxyConfig(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Get the proxy ID. Note that this is the ID of a proxy's service instance.
|
||||
id := strings.TrimPrefix(req.URL.Path, "/v1/agent/connect/proxy/")
|
||||
|
||||
// Maybe block
|
||||
var queryOpts structs.QueryOptions
|
||||
if parseWait(resp, req, &queryOpts) {
|
||||
// parseWait returns an error itself
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Parse the token
|
||||
var token string
|
||||
s.parseToken(req, &token)
|
||||
|
||||
// Parse hash specially since it's only this endpoint that uses it currently.
|
||||
// Eventually this should happen in parseWait and end up in QueryOptions but I
|
||||
// didn't want to make very general changes right away.
|
||||
hash := req.URL.Query().Get("hash")
|
||||
|
||||
return s.agentLocalBlockingQuery(resp, hash, &queryOpts,
|
||||
func(ws memdb.WatchSet) (string, interface{}, error) {
|
||||
// Retrieve the proxy specified
|
||||
proxy := s.agent.State.Proxy(id)
|
||||
if proxy == nil {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(resp, "unknown proxy service ID: %s", id)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Lookup the target service as a convenience
|
||||
target := s.agent.State.Service(proxy.Proxy.TargetServiceID)
|
||||
if target == nil {
|
||||
// Not found since this endpoint is only useful for agent-managed proxies so
|
||||
// service missing means the service was deregistered racily with this call.
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(resp, "unknown target service ID: %s", proxy.Proxy.TargetServiceID)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Validate the ACL token
|
||||
_, isProxyToken, err := s.agent.verifyProxyToken(token, target.Service, id)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Watch the proxy for changes
|
||||
ws.Add(proxy.WatchCh)
|
||||
|
||||
hash, err := hashstructure.Hash(proxy.Proxy, nil)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
contentHash := fmt.Sprintf("%x", hash)
|
||||
|
||||
// Set defaults
|
||||
config, err := s.agent.applyProxyConfigDefaults(proxy.Proxy)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Only merge in telemetry config from agent if the requested is
|
||||
// authorized with a proxy token. This prevents us leaking potentially
|
||||
// sensitive config like Circonus API token via a public endpoint. Proxy
|
||||
// tokens are only ever generated in-memory and passed via ENV to a child
|
||||
// proxy process so potential for abuse here seems small. This endpoint in
|
||||
// general is only useful for managed proxies now so it should _always_ be
|
||||
// true that auth is via a proxy token but inconvenient for testing if we
|
||||
// lock it down so strictly.
|
||||
if isProxyToken {
|
||||
// Add telemetry config. Copy the global config so we can customize the
|
||||
// prefix.
|
||||
telemetryCfg := s.agent.config.Telemetry
|
||||
telemetryCfg.MetricsPrefix = telemetryCfg.MetricsPrefix + ".proxy." + target.ID
|
||||
|
||||
// First see if the user has specified telemetry
|
||||
if userRaw, ok := config["telemetry"]; ok {
|
||||
// User specified domething, see if it is compatible with agent
|
||||
// telemetry config:
|
||||
var uCfg lib.TelemetryConfig
|
||||
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
Result: &uCfg,
|
||||
// Make sure that if the user passes something that isn't just a
|
||||
// simple override of a valid TelemetryConfig that we fail so that we
|
||||
// don't clobber their custom config.
|
||||
ErrorUnused: true,
|
||||
})
|
||||
if err == nil {
|
||||
if err = dec.Decode(userRaw); err == nil {
|
||||
// It did decode! Merge any unspecified fields from agent config.
|
||||
uCfg.MergeDefaults(&telemetryCfg)
|
||||
config["telemetry"] = uCfg
|
||||
}
|
||||
}
|
||||
// Failed to decode, just keep user's config["telemetry"] verbatim
|
||||
// with no agent merge.
|
||||
} else {
|
||||
// Add agent telemetry config.
|
||||
config["telemetry"] = telemetryCfg
|
||||
}
|
||||
}
|
||||
|
||||
reply := &api.ConnectProxyConfig{
|
||||
ProxyServiceID: proxy.Proxy.ProxyService.ID,
|
||||
TargetServiceID: target.ID,
|
||||
TargetServiceName: target.Service,
|
||||
ContentHash: contentHash,
|
||||
ExecMode: api.ProxyExecMode(proxy.Proxy.ExecMode.String()),
|
||||
Command: proxy.Proxy.Command,
|
||||
Config: config,
|
||||
}
|
||||
return contentHash, reply, nil
|
||||
})
|
||||
}
|
||||
|
||||
type agentLocalBlockingFunc func(ws memdb.WatchSet) (string, interface{}, error)
|
||||
|
||||
// agentLocalBlockingQuery performs a blocking query in a generic way against
|
||||
// local agent state that has no RPC or raft to back it. It uses `hash` paramter
|
||||
// instead of an `index`. The resp is needed to write the `X-Consul-ContentHash`
|
||||
// header back on return no Status nor body content is ever written to it.
|
||||
func (s *HTTPServer) agentLocalBlockingQuery(resp http.ResponseWriter, hash string,
|
||||
queryOpts *structs.QueryOptions, fn agentLocalBlockingFunc) (interface{}, error) {
|
||||
|
||||
// If we are not blocking we can skip tracking and allocating - nil WatchSet
|
||||
// is still valid to call Add on and will just be a no op.
|
||||
var ws memdb.WatchSet
|
||||
var timeout *time.Timer
|
||||
|
||||
if hash != "" {
|
||||
// TODO(banks) at least define these defaults somewhere in a const. Would be
|
||||
// nice not to duplicate the ones in consul/rpc.go too...
|
||||
wait := queryOpts.MaxQueryTime
|
||||
if wait == 0 {
|
||||
wait = 5 * time.Minute
|
||||
}
|
||||
if wait > 10*time.Minute {
|
||||
wait = 10 * time.Minute
|
||||
}
|
||||
// Apply a small amount of jitter to the request.
|
||||
wait += lib.RandomStagger(wait / 16)
|
||||
timeout = time.NewTimer(wait)
|
||||
}
|
||||
|
||||
for {
|
||||
// Must reset this every loop in case the Watch set is already closed but
|
||||
// hash remains same. In that case we'll need to re-block on ws.Watch()
|
||||
// again.
|
||||
ws = memdb.NewWatchSet()
|
||||
curHash, curResp, err := fn(ws)
|
||||
if err != nil {
|
||||
return curResp, err
|
||||
}
|
||||
// Return immediately if there is no timeout, the hash is different or the
|
||||
// Watch returns true (indicating timeout fired). Note that Watch on a nil
|
||||
// WatchSet immediately returns false which would incorrectly cause this to
|
||||
// loop and repeat again, however we rely on the invariant that ws == nil
|
||||
// IFF timeout == nil in which case the Watch call is never invoked.
|
||||
if timeout == nil || hash != curHash || ws.Watch(timeout.C) {
|
||||
resp.Header().Set("X-Consul-ContentHash", curHash)
|
||||
return curResp, err
|
||||
}
|
||||
// Watch returned false indicating a change was detected, loop and repeat
|
||||
// the callback to load the new value.
|
||||
}
|
||||
}
|
||||
|
||||
// AgentConnectAuthorize
|
||||
//
|
||||
// POST /v1/agent/connect/authorize
|
||||
//
|
||||
// Note: when this logic changes, consider if the Intention.Check RPC method
|
||||
// also needs to be updated.
|
||||
func (s *HTTPServer) AgentConnectAuthorize(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Fetch the token
|
||||
var token string
|
||||
s.parseToken(req, &token)
|
||||
|
||||
// Decode the request from the request body
|
||||
var authReq structs.ConnectAuthorizeRequest
|
||||
if err := decodeBody(req, &authReq, nil); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Request decode failed: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// We need to have a target to check intentions
|
||||
if authReq.Target == "" {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Target service must be specified")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Parse the certificate URI from the client ID
|
||||
uriRaw, err := url.Parse(authReq.ClientCertURI)
|
||||
if err != nil {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: false,
|
||||
Reason: fmt.Sprintf("Client ID must be a URI: %s", err),
|
||||
}, nil
|
||||
}
|
||||
uri, err := connect.ParseCertURI(uriRaw)
|
||||
if err != nil {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: false,
|
||||
Reason: fmt.Sprintf("Invalid client ID: %s", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
uriService, ok := uri.(*connect.SpiffeIDService)
|
||||
if !ok {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: false,
|
||||
Reason: "Client ID must be a valid SPIFFE service URI",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// We need to verify service:write permissions for the given token.
|
||||
// We do this manually here since the RPC request below only verifies
|
||||
// service:read.
|
||||
rule, err := s.agent.resolveToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rule != nil && !rule.ServiceWrite(authReq.Target, nil) {
|
||||
return nil, acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Validate the trust domain matches ours. Later we will support explicit
|
||||
// external federation but not built yet.
|
||||
rootArgs := &structs.DCSpecificRequest{Datacenter: s.agent.config.Datacenter}
|
||||
raw, _, err := s.agent.cache.Get(cachetype.ConnectCARootName, rootArgs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roots, ok := raw.(*structs.IndexedCARoots)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal error: roots response type not correct")
|
||||
}
|
||||
if roots.TrustDomain == "" {
|
||||
return nil, fmt.Errorf("connect CA not bootstrapped yet")
|
||||
}
|
||||
if roots.TrustDomain != strings.ToLower(uriService.Host) {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: false,
|
||||
Reason: fmt.Sprintf("Identity from an external trust domain: %s",
|
||||
uriService.Host),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO(banks): Implement revocation list checking here.
|
||||
|
||||
// Get the intentions for this target service.
|
||||
args := &structs.IntentionQueryRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
Match: &structs.IntentionQueryMatch{
|
||||
Type: structs.IntentionMatchDestination,
|
||||
Entries: []structs.IntentionMatchEntry{
|
||||
{
|
||||
Namespace: structs.IntentionDefaultNamespace,
|
||||
Name: authReq.Target,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
args.Token = token
|
||||
|
||||
raw, m, err := s.agent.cache.Get(cachetype.IntentionMatchName, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setCacheMeta(resp, &m)
|
||||
|
||||
reply, ok := raw.(*structs.IndexedIntentionMatches)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal error: response type not correct")
|
||||
}
|
||||
if len(reply.Matches) != 1 {
|
||||
return nil, fmt.Errorf("Internal error loading matches")
|
||||
}
|
||||
|
||||
// Test the authorization for each match
|
||||
for _, ixn := range reply.Matches[0] {
|
||||
if auth, ok := uriService.Authorize(ixn); ok {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: auth,
|
||||
Reason: fmt.Sprintf("Matched intention: %s", ixn.String()),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// No match, we need to determine the default behavior. We do this by
|
||||
// specifying the anonymous token token, which will get that behavior.
|
||||
// The default behavior if ACLs are disabled is to allow connections
|
||||
// to mimic the behavior of Consul itself: everything is allowed if
|
||||
// ACLs are disabled.
|
||||
rule, err = s.agent.resolveToken("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authz := true
|
||||
reason := "ACLs disabled, access is allowed by default"
|
||||
if rule != nil {
|
||||
authz = rule.IntentionDefaultAllow()
|
||||
reason = "Default behavior configured by ACLs"
|
||||
}
|
||||
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: authz,
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// connectAuthorizeResp is the response format/structure for the
|
||||
// /v1/agent/connect/authorize endpoint.
|
||||
type connectAuthorizeResp struct {
|
||||
Authorized bool // True if authorized, false if not
|
||||
Reason string // Reason for the Authorized value (whether true or false)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -16,6 +16,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/checks"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
|
@ -23,6 +24,8 @@ import (
|
|||
"github.com/hashicorp/consul/types"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func externalIP() (string, error) {
|
||||
|
@ -51,10 +54,62 @@ func TestAgent_MultiStartStop(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAgent_ConnectClusterIDConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hcl string
|
||||
wantClusterID string
|
||||
wantPanic bool
|
||||
}{
|
||||
{
|
||||
name: "default TestAgent has fixed cluster id",
|
||||
hcl: "",
|
||||
wantClusterID: connect.TestClusterID,
|
||||
},
|
||||
{
|
||||
name: "no cluster ID specified sets to test ID",
|
||||
hcl: "connect { enabled = true }",
|
||||
wantClusterID: connect.TestClusterID,
|
||||
},
|
||||
{
|
||||
name: "non-UUID cluster_id is fatal",
|
||||
hcl: `connect {
|
||||
enabled = true
|
||||
ca_config {
|
||||
cluster_id = "fake-id"
|
||||
}
|
||||
}`,
|
||||
wantClusterID: "",
|
||||
wantPanic: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Indirection to support panic recovery cleanly
|
||||
testFn := func() {
|
||||
a := &TestAgent{Name: "test", HCL: tt.hcl}
|
||||
a.ExpectConfigError = tt.wantPanic
|
||||
a.Start()
|
||||
defer a.Shutdown()
|
||||
|
||||
cfg := a.consulConfig()
|
||||
assert.Equal(t, tt.wantClusterID, cfg.CAConfig.ClusterID)
|
||||
}
|
||||
|
||||
if tt.wantPanic {
|
||||
require.Panics(t, testFn)
|
||||
} else {
|
||||
testFn()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_StartStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
// defer a.Shutdown()
|
||||
defer a.Shutdown()
|
||||
|
||||
if err := a.Leave(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
|
@ -1294,6 +1349,187 @@ func TestAgent_PurgeServiceOnDuplicate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAgent_PersistProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
|
||||
cfg := `
|
||||
server = false
|
||||
bootstrap = false
|
||||
data_dir = "` + dataDir + `"
|
||||
`
|
||||
a := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
|
||||
a.Start()
|
||||
defer os.RemoveAll(dataDir)
|
||||
defer a.Shutdown()
|
||||
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
// Add a service to proxy (precondition for AddProxy)
|
||||
svc1 := &structs.NodeService{
|
||||
ID: "redis",
|
||||
Service: "redis",
|
||||
Tags: []string{"foo"},
|
||||
Port: 8000,
|
||||
}
|
||||
require.NoError(a.AddService(svc1, nil, true, ""))
|
||||
|
||||
// Add a proxy for it
|
||||
proxy := &structs.ConnectManagedProxy{
|
||||
TargetServiceID: svc1.ID,
|
||||
Command: []string{"/bin/sleep", "3600"},
|
||||
}
|
||||
|
||||
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash("redis-proxy"))
|
||||
|
||||
// Proxy is not persisted unless requested
|
||||
require.NoError(a.AddProxy(proxy, false, ""))
|
||||
_, err := os.Stat(file)
|
||||
require.Error(err, "proxy should not be persisted")
|
||||
|
||||
// Proxy is persisted if requested
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
_, err = os.Stat(file)
|
||||
require.NoError(err, "proxy should be persisted")
|
||||
|
||||
content, err := ioutil.ReadFile(file)
|
||||
require.NoError(err)
|
||||
|
||||
var gotProxy persistedProxy
|
||||
require.NoError(json.Unmarshal(content, &gotProxy))
|
||||
assert.Equal(proxy.Command, gotProxy.Proxy.Command)
|
||||
assert.Len(gotProxy.ProxyToken, 36) // sanity check for UUID
|
||||
|
||||
// Updates service definition on disk
|
||||
proxy.Config = map[string]interface{}{
|
||||
"foo": "bar",
|
||||
}
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
|
||||
content, err = ioutil.ReadFile(file)
|
||||
require.NoError(err)
|
||||
|
||||
require.NoError(json.Unmarshal(content, &gotProxy))
|
||||
assert.Equal(gotProxy.Proxy.Command, proxy.Command)
|
||||
assert.Equal(gotProxy.Proxy.Config, proxy.Config)
|
||||
assert.Len(gotProxy.ProxyToken, 36) // sanity check for UUID
|
||||
|
||||
a.Shutdown()
|
||||
|
||||
// Should load it back during later start
|
||||
a2 := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
|
||||
a2.Start()
|
||||
defer a2.Shutdown()
|
||||
|
||||
restored := a2.State.Proxy("redis-proxy")
|
||||
require.NotNil(restored)
|
||||
assert.Equal(gotProxy.ProxyToken, restored.ProxyToken)
|
||||
// Ensure the port that was auto picked at random is the same again
|
||||
assert.Equal(gotProxy.Proxy.ProxyService.Port, restored.Proxy.ProxyService.Port)
|
||||
assert.Equal(gotProxy.Proxy.Command, restored.Proxy.Command)
|
||||
}
|
||||
|
||||
func TestAgent_PurgeProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
// Add a service to proxy (precondition for AddProxy)
|
||||
svc1 := &structs.NodeService{
|
||||
ID: "redis",
|
||||
Service: "redis",
|
||||
Tags: []string{"foo"},
|
||||
Port: 8000,
|
||||
}
|
||||
require.NoError(a.AddService(svc1, nil, true, ""))
|
||||
|
||||
// Add a proxy for it
|
||||
proxy := &structs.ConnectManagedProxy{
|
||||
TargetServiceID: svc1.ID,
|
||||
Command: []string{"/bin/sleep", "3600"},
|
||||
}
|
||||
proxyID := "redis-proxy"
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
|
||||
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash("redis-proxy"))
|
||||
|
||||
// Not removed
|
||||
require.NoError(a.RemoveProxy(proxyID, false))
|
||||
_, err := os.Stat(file)
|
||||
require.NoError(err, "should not be removed")
|
||||
|
||||
// Re-add the proxy
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
|
||||
// Removed
|
||||
require.NoError(a.RemoveProxy(proxyID, true))
|
||||
_, err = os.Stat(file)
|
||||
require.Error(err, "should be removed")
|
||||
}
|
||||
|
||||
func TestAgent_PurgeProxyOnDuplicate(t *testing.T) {
|
||||
t.Parallel()
|
||||
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
|
||||
cfg := `
|
||||
data_dir = "` + dataDir + `"
|
||||
server = false
|
||||
bootstrap = false
|
||||
`
|
||||
a := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
|
||||
a.Start()
|
||||
defer a.Shutdown()
|
||||
defer os.RemoveAll(dataDir)
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
// Add a service to proxy (precondition for AddProxy)
|
||||
svc1 := &structs.NodeService{
|
||||
ID: "redis",
|
||||
Service: "redis",
|
||||
Tags: []string{"foo"},
|
||||
Port: 8000,
|
||||
}
|
||||
require.NoError(a.AddService(svc1, nil, true, ""))
|
||||
|
||||
// Add a proxy for it
|
||||
proxy := &structs.ConnectManagedProxy{
|
||||
TargetServiceID: svc1.ID,
|
||||
Command: []string{"/bin/sleep", "3600"},
|
||||
}
|
||||
proxyID := "redis-proxy"
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
|
||||
a.Shutdown()
|
||||
|
||||
// Try bringing the agent back up with the service already
|
||||
// existing in the config
|
||||
a2 := &TestAgent{Name: t.Name() + "-a2", HCL: cfg + `
|
||||
service = {
|
||||
id = "redis"
|
||||
name = "redis"
|
||||
tags = ["bar"]
|
||||
port = 9000
|
||||
connect {
|
||||
proxy {
|
||||
command = ["/bin/sleep", "3600"]
|
||||
}
|
||||
}
|
||||
}
|
||||
`, DataDir: dataDir}
|
||||
a2.Start()
|
||||
defer a2.Shutdown()
|
||||
|
||||
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash(proxyID))
|
||||
_, err := os.Stat(file)
|
||||
require.Error(err, "should have removed remote state")
|
||||
|
||||
result := a2.State.Proxy(proxyID)
|
||||
require.NotNil(result)
|
||||
require.Equal(proxy.Command, result.Proxy.Command)
|
||||
}
|
||||
|
||||
func TestAgent_PersistCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
|
||||
|
@ -1629,6 +1865,96 @@ func TestAgent_unloadServices(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAgent_loadProxies(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
service = {
|
||||
id = "rabbitmq"
|
||||
name = "rabbitmq"
|
||||
port = 5672
|
||||
token = "abc123"
|
||||
connect {
|
||||
proxy {
|
||||
config {
|
||||
bind_port = 1234
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
|
||||
services := a.State.Services()
|
||||
if _, ok := services["rabbitmq"]; !ok {
|
||||
t.Fatalf("missing service")
|
||||
}
|
||||
if token := a.State.ServiceToken("rabbitmq"); token != "abc123" {
|
||||
t.Fatalf("bad: %s", token)
|
||||
}
|
||||
if _, ok := services["rabbitmq-proxy"]; !ok {
|
||||
t.Fatalf("missing proxy service")
|
||||
}
|
||||
if token := a.State.ServiceToken("rabbitmq-proxy"); token != "abc123" {
|
||||
t.Fatalf("bad: %s", token)
|
||||
}
|
||||
proxies := a.State.Proxies()
|
||||
if _, ok := proxies["rabbitmq-proxy"]; !ok {
|
||||
t.Fatalf("missing proxy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_loadProxies_nilProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
service = {
|
||||
id = "rabbitmq"
|
||||
name = "rabbitmq"
|
||||
port = 5672
|
||||
token = "abc123"
|
||||
connect {
|
||||
}
|
||||
}
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
|
||||
services := a.State.Services()
|
||||
require.Contains(t, services, "rabbitmq")
|
||||
require.Equal(t, "abc123", a.State.ServiceToken("rabbitmq"))
|
||||
require.NotContains(t, services, "rabbitme-proxy")
|
||||
require.Empty(t, a.State.Proxies())
|
||||
}
|
||||
|
||||
func TestAgent_unloadProxies(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
service = {
|
||||
id = "rabbitmq"
|
||||
name = "rabbitmq"
|
||||
port = 5672
|
||||
token = "abc123"
|
||||
connect {
|
||||
proxy {
|
||||
config {
|
||||
bind_port = 1234
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
|
||||
// Sanity check it's there
|
||||
require.NotNil(t, a.State.Proxy("rabbitmq-proxy"))
|
||||
|
||||
// Unload all proxies
|
||||
if err := a.unloadProxies(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if len(a.State.Proxies()) != 0 {
|
||||
t.Fatalf("should have unloaded proxies")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_Service_MaintenanceMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
|
@ -2179,6 +2505,18 @@ func TestAgent_reloadWatches(t *testing.T) {
|
|||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
|
||||
// Should fail to reload with connect watches
|
||||
newConf.Watches = []map[string]interface{}{
|
||||
{
|
||||
"type": "connect_roots",
|
||||
"key": "asdf",
|
||||
"args": []interface{}{"ls"},
|
||||
},
|
||||
}
|
||||
if err := a.reloadWatches(&newConf); err == nil || !strings.Contains(err.Error(), "not allowed in agent config") {
|
||||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
|
||||
// Should still succeed with only HTTPS addresses
|
||||
newConf.HTTPSAddrs = newConf.HTTPAddrs
|
||||
newConf.HTTPAddrs = make([]net.Addr, 0)
|
||||
|
@ -2226,3 +2564,217 @@ func TestAgent_reloadWatchesHTTPS(t *testing.T) {
|
|||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
node_name = "node1"
|
||||
|
||||
connect {
|
||||
proxy_defaults {
|
||||
exec_mode = "script"
|
||||
daemon_command = ["foo", "bar"]
|
||||
script_command = ["bar", "foo"]
|
||||
}
|
||||
}
|
||||
|
||||
ports {
|
||||
proxy_min_port = 20000
|
||||
proxy_max_port = 20000
|
||||
}
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register a target service we can use
|
||||
reg := &structs.NodeService{
|
||||
Service: "web",
|
||||
Port: 8080,
|
||||
}
|
||||
require.NoError(t, a.AddService(reg, nil, false, ""))
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
proxy, wantProxy *structs.ConnectManagedProxy
|
||||
wantTCPCheck string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
desc: "basic proxy adding, unregistered service",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
TargetServiceID: "db", // non-existent service.
|
||||
},
|
||||
// Target service must be registered.
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
desc: "basic proxy adding, registered service",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
desc: "default global exec mode",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantProxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeScript,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
desc: "default daemon command",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantProxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"foo", "bar"},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
desc: "default script command",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeScript,
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantProxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeScript,
|
||||
Command: []string{"bar", "foo"},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
desc: "managed proxy with custom bind port",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bind_address": "127.10.10.10",
|
||||
"bind_port": 1234,
|
||||
},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantTCPCheck: "127.10.10.10:1234",
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
{
|
||||
// This test is necessary since JSON and HCL both will parse
|
||||
// numbers as a float64.
|
||||
desc: "managed proxy with custom bind port (float64)",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bind_address": "127.10.10.10",
|
||||
"bind_port": float64(1234),
|
||||
},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantTCPCheck: "127.10.10.10:1234",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
err := a.AddProxy(tt.proxy, false, "")
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Test the ID was created as we expect.
|
||||
got := a.State.Proxy("web-proxy")
|
||||
wantProxy := tt.wantProxy
|
||||
if wantProxy == nil {
|
||||
wantProxy = tt.proxy
|
||||
}
|
||||
wantProxy.ProxyService = got.Proxy.ProxyService
|
||||
require.Equal(wantProxy, got.Proxy)
|
||||
|
||||
// Ensure a TCP check was created for the service.
|
||||
gotCheck := a.State.Check("service:web-proxy")
|
||||
require.NotNil(gotCheck)
|
||||
require.Equal("Connect Proxy Listening", gotCheck.Name)
|
||||
|
||||
// Confusingly, a.State.Check("service:web-proxy") will return the state
|
||||
// but it's Definition field will be empty. This appears to be expected
|
||||
// when adding Checks as part of `AddService`. Notice how `AddService`
|
||||
// tests in this file don't assert on that state but instead look at the
|
||||
// agent's check state directly to ensure the right thing was registered.
|
||||
// We'll do the same for now.
|
||||
gotTCP, ok := a.checkTCPs["service:web-proxy"]
|
||||
require.True(ok)
|
||||
wantTCPCheck := tt.wantTCPCheck
|
||||
if wantTCPCheck == "" {
|
||||
wantTCPCheck = "127.0.0.1:20000"
|
||||
}
|
||||
require.Equal(wantTCPCheck, gotTCP.TCP)
|
||||
require.Equal(10*time.Second, gotTCP.Interval)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
node_name = "node1"
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
require := require.New(t)
|
||||
|
||||
// Register a target service we can use
|
||||
reg := &structs.NodeService{
|
||||
Service: "web",
|
||||
Port: 8080,
|
||||
}
|
||||
require.NoError(a.AddService(reg, nil, false, ""))
|
||||
|
||||
// Add a proxy for web
|
||||
pReg := &structs.ConnectManagedProxy{
|
||||
TargetServiceID: "web",
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"foo"},
|
||||
}
|
||||
require.NoError(a.AddProxy(pReg, false, ""))
|
||||
|
||||
// Test the ID was created as we expect.
|
||||
gotProxy := a.State.Proxy("web-proxy")
|
||||
require.NotNil(gotProxy)
|
||||
|
||||
err := a.RemoveProxy("web-proxy", false)
|
||||
require.NoError(err)
|
||||
|
||||
gotProxy = a.State.Proxy("web-proxy")
|
||||
require.Nil(gotProxy)
|
||||
require.Nil(a.State.Service("web-proxy"), "web-proxy service")
|
||||
|
||||
// Removing invalid proxy should be an error
|
||||
err = a.RemoveProxy("foobar", false)
|
||||
require.Error(err)
|
||||
}
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,240 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// Recommended name for registration.
|
||||
const ConnectCALeafName = "connect-ca-leaf"
|
||||
|
||||
// ConnectCALeaf supports fetching and generating Connect leaf
|
||||
// certificates.
|
||||
type ConnectCALeaf struct {
|
||||
caIndex uint64 // Current index for CA roots
|
||||
|
||||
issuedCertsLock sync.RWMutex
|
||||
issuedCerts map[string]*structs.IssuedCert
|
||||
|
||||
RPC RPC // RPC client for remote requests
|
||||
Cache *cache.Cache // Cache that has CA root certs via ConnectCARoot
|
||||
}
|
||||
|
||||
func (c *ConnectCALeaf) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
|
||||
var result cache.FetchResult
|
||||
|
||||
// Get the correct type
|
||||
reqReal, ok := req.(*ConnectCALeafRequest)
|
||||
if !ok {
|
||||
return result, fmt.Errorf(
|
||||
"Internal cache failure: request wrong type: %T", req)
|
||||
}
|
||||
|
||||
// This channel watches our overall timeout. The other goroutines
|
||||
// launched in this function should end all around the same time so
|
||||
// they clean themselves up.
|
||||
timeoutCh := time.After(opts.Timeout)
|
||||
|
||||
// Kick off the goroutine that waits for new CA roots. The channel buffer
|
||||
// is so that the goroutine doesn't block forever if we return for other
|
||||
// reasons.
|
||||
newRootCACh := make(chan error, 1)
|
||||
go c.waitNewRootCA(reqReal.Datacenter, newRootCACh, opts.Timeout)
|
||||
|
||||
// Get our prior cert (if we had one) and use that to determine our
|
||||
// expiration time. If no cert exists, we expire immediately since we
|
||||
// need to generate.
|
||||
c.issuedCertsLock.RLock()
|
||||
lastCert := c.issuedCerts[reqReal.Service]
|
||||
c.issuedCertsLock.RUnlock()
|
||||
|
||||
var leafExpiryCh <-chan time.Time
|
||||
if lastCert != nil {
|
||||
// Determine how long we wait until triggering. If we've already
|
||||
// expired, we trigger immediately.
|
||||
if expiryDur := lastCert.ValidBefore.Sub(time.Now()); expiryDur > 0 {
|
||||
leafExpiryCh = time.After(expiryDur - 1*time.Hour)
|
||||
// TODO(mitchellh): 1 hour buffer is hardcoded above
|
||||
}
|
||||
}
|
||||
|
||||
if leafExpiryCh == nil {
|
||||
// If the channel is still nil then it means we need to generate
|
||||
// a cert no matter what: we either don't have an existing one or
|
||||
// it is expired.
|
||||
leafExpiryCh = time.After(0)
|
||||
}
|
||||
|
||||
// Block on the events that wake us up.
|
||||
select {
|
||||
case <-timeoutCh:
|
||||
// On a timeout, we just return the empty result and no error.
|
||||
// It isn't an error to timeout, its just the limit of time the
|
||||
// caching system wants us to block for. By returning an empty result
|
||||
// the caching system will ignore.
|
||||
return result, nil
|
||||
|
||||
case err := <-newRootCACh:
|
||||
// A new root CA triggers us to refresh the leaf certificate.
|
||||
// If there was an error while getting the root CA then we return.
|
||||
// Otherwise, we leave the select statement and move to generation.
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
case <-leafExpiryCh:
|
||||
// The existing leaf certificate is expiring soon, so we generate a
|
||||
// new cert with a healthy overlapping validity period (determined
|
||||
// by the above channel).
|
||||
}
|
||||
|
||||
// Need to lookup RootCAs response to discover trust domain. First just lookup
|
||||
// with no blocking info - this should be a cache hit most of the time.
|
||||
rawRoots, _, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{
|
||||
Datacenter: reqReal.Datacenter,
|
||||
})
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
roots, ok := rawRoots.(*structs.IndexedCARoots)
|
||||
if !ok {
|
||||
return result, errors.New("invalid RootCA response type")
|
||||
}
|
||||
if roots.TrustDomain == "" {
|
||||
return result, errors.New("cluster has no CA bootstrapped")
|
||||
}
|
||||
|
||||
// Build the service ID
|
||||
serviceID := &connect.SpiffeIDService{
|
||||
Host: roots.TrustDomain,
|
||||
Datacenter: reqReal.Datacenter,
|
||||
Namespace: "default",
|
||||
Service: reqReal.Service,
|
||||
}
|
||||
|
||||
// Create a new private key
|
||||
pk, pkPEM, err := connect.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Create a CSR.
|
||||
csr, err := connect.CreateCSR(serviceID, pk)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Request signing
|
||||
var reply structs.IssuedCert
|
||||
args := structs.CASignRequest{
|
||||
WriteRequest: structs.WriteRequest{Token: reqReal.Token},
|
||||
Datacenter: reqReal.Datacenter,
|
||||
CSR: csr,
|
||||
}
|
||||
if err := c.RPC.RPC("ConnectCA.Sign", &args, &reply); err != nil {
|
||||
return result, err
|
||||
}
|
||||
reply.PrivateKeyPEM = pkPEM
|
||||
|
||||
// Lock the issued certs map so we can insert it. We only insert if
|
||||
// we didn't happen to get a newer one. This should never happen since
|
||||
// the Cache should ensure only one Fetch per service, but we sanity
|
||||
// check just in case.
|
||||
c.issuedCertsLock.Lock()
|
||||
defer c.issuedCertsLock.Unlock()
|
||||
lastCert = c.issuedCerts[reqReal.Service]
|
||||
if lastCert == nil || lastCert.ModifyIndex < reply.ModifyIndex {
|
||||
if c.issuedCerts == nil {
|
||||
c.issuedCerts = make(map[string]*structs.IssuedCert)
|
||||
}
|
||||
|
||||
c.issuedCerts[reqReal.Service] = &reply
|
||||
lastCert = &reply
|
||||
}
|
||||
|
||||
result.Value = lastCert
|
||||
result.Index = lastCert.ModifyIndex
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// waitNewRootCA blocks until a new root CA is available or the timeout is
|
||||
// reached (on timeout ErrTimeout is returned on the channel).
|
||||
func (c *ConnectCALeaf) waitNewRootCA(datacenter string, ch chan<- error,
|
||||
timeout time.Duration) {
|
||||
// We always want to block on at least an initial value. If this isn't
|
||||
minIndex := atomic.LoadUint64(&c.caIndex)
|
||||
if minIndex == 0 {
|
||||
minIndex = 1
|
||||
}
|
||||
|
||||
// Fetch some new roots. This will block until our MinQueryIndex is
|
||||
// matched or the timeout is reached.
|
||||
rawRoots, _, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{
|
||||
Datacenter: datacenter,
|
||||
QueryOptions: structs.QueryOptions{
|
||||
MinQueryIndex: minIndex,
|
||||
MaxQueryTime: timeout,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ch <- err
|
||||
return
|
||||
}
|
||||
|
||||
roots, ok := rawRoots.(*structs.IndexedCARoots)
|
||||
if !ok {
|
||||
// This should never happen but we don't want to even risk a panic
|
||||
ch <- fmt.Errorf(
|
||||
"internal error: CA root cache returned bad type: %T", rawRoots)
|
||||
return
|
||||
}
|
||||
|
||||
// We do a loop here because there can be multiple waitNewRootCA calls
|
||||
// happening simultaneously. Each Fetch kicks off one call. These are
|
||||
// multiplexed through Cache.Get which should ensure we only ever
|
||||
// actually make a single RPC call. However, there is a race to set
|
||||
// the caIndex field so do a basic CAS loop here.
|
||||
for {
|
||||
// We only set our index if its newer than what is previously set.
|
||||
old := atomic.LoadUint64(&c.caIndex)
|
||||
if old == roots.Index || old > roots.Index {
|
||||
break
|
||||
}
|
||||
|
||||
// Set the new index atomically. If the caIndex value changed
|
||||
// in the meantime, retry.
|
||||
if atomic.CompareAndSwapUint64(&c.caIndex, old, roots.Index) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger the channel since we updated.
|
||||
ch <- nil
|
||||
}
|
||||
|
||||
// ConnectCALeafRequest is the cache.Request implementation for the
|
||||
// ConnectCALeaf cache type. This is implemented here and not in structs
|
||||
// since this is only used for cache-related requests and not forwarded
|
||||
// directly to any Consul servers.
|
||||
type ConnectCALeafRequest struct {
|
||||
Token string
|
||||
Datacenter string
|
||||
Service string // Service name, not ID
|
||||
MinQueryIndex uint64
|
||||
}
|
||||
|
||||
func (r *ConnectCALeafRequest) CacheInfo() cache.RequestInfo {
|
||||
return cache.RequestInfo{
|
||||
Token: r.Token,
|
||||
Key: r.Service,
|
||||
Datacenter: r.Datacenter,
|
||||
MinIndex: r.MinQueryIndex,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test that after an initial signing, new CA roots (new ID) will
|
||||
// trigger a blocking query to execute.
|
||||
func TestConnectCALeaf_changingRoots(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
|
||||
typ, rootsCh := testCALeafType(t, rpc)
|
||||
defer close(rootsCh)
|
||||
rootsCh <- structs.IndexedCARoots{
|
||||
ActiveRootID: "1",
|
||||
TrustDomain: "fake-trust-domain.consul",
|
||||
QueryMeta: structs.QueryMeta{Index: 1},
|
||||
}
|
||||
|
||||
// Instrument ConnectCA.Sign to return signed cert
|
||||
var resp *structs.IssuedCert
|
||||
var idx uint64
|
||||
rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
reply := args.Get(2).(*structs.IssuedCert)
|
||||
reply.ValidBefore = time.Now().Add(12 * time.Hour)
|
||||
reply.CreateIndex = atomic.AddUint64(&idx, 1)
|
||||
reply.ModifyIndex = reply.CreateIndex
|
||||
resp = reply
|
||||
})
|
||||
|
||||
// We'll reuse the fetch options and request
|
||||
opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second}
|
||||
req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"}
|
||||
|
||||
// First fetch should return immediately
|
||||
fetchCh := TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("shouldn't block waiting for fetch")
|
||||
case result := <-fetchCh:
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 1,
|
||||
}, result)
|
||||
}
|
||||
|
||||
// Second fetch should block with set index
|
||||
fetchCh = TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case result := <-fetchCh:
|
||||
t.Fatalf("should not return: %#v", result)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Let's send in new roots, which should trigger the sign req
|
||||
rootsCh <- structs.IndexedCARoots{
|
||||
ActiveRootID: "2",
|
||||
TrustDomain: "fake-trust-domain.consul",
|
||||
QueryMeta: structs.QueryMeta{Index: 2},
|
||||
}
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("shouldn't block waiting for fetch")
|
||||
case result := <-fetchCh:
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 2,
|
||||
}, result)
|
||||
}
|
||||
|
||||
// Third fetch should block
|
||||
fetchCh = TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case result := <-fetchCh:
|
||||
t.Fatalf("should not return: %#v", result)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that after an initial signing, an expiringLeaf will trigger a
|
||||
// blocking query to resign.
|
||||
func TestConnectCALeaf_expiringLeaf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
|
||||
typ, rootsCh := testCALeafType(t, rpc)
|
||||
defer close(rootsCh)
|
||||
rootsCh <- structs.IndexedCARoots{
|
||||
ActiveRootID: "1",
|
||||
TrustDomain: "fake-trust-domain.consul",
|
||||
QueryMeta: structs.QueryMeta{Index: 1},
|
||||
}
|
||||
|
||||
// Instrument ConnectCA.Sign to
|
||||
var resp *structs.IssuedCert
|
||||
var idx uint64
|
||||
rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
reply := args.Get(2).(*structs.IssuedCert)
|
||||
reply.CreateIndex = atomic.AddUint64(&idx, 1)
|
||||
reply.ModifyIndex = reply.CreateIndex
|
||||
|
||||
// This sets the validity to 0 on the first call, and
|
||||
// 12 hours+ on subsequent calls. This means that our first
|
||||
// cert expires immediately.
|
||||
reply.ValidBefore = time.Now().Add((12 * time.Hour) *
|
||||
time.Duration(reply.CreateIndex-1))
|
||||
|
||||
resp = reply
|
||||
})
|
||||
|
||||
// We'll reuse the fetch options and request
|
||||
opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second}
|
||||
req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"}
|
||||
|
||||
// First fetch should return immediately
|
||||
fetchCh := TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("shouldn't block waiting for fetch")
|
||||
case result := <-fetchCh:
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 1,
|
||||
}, result)
|
||||
}
|
||||
|
||||
// Second fetch should return immediately despite there being
|
||||
// no updated CA roots, because we issued an expired cert.
|
||||
fetchCh = TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("shouldn't block waiting for fetch")
|
||||
case result := <-fetchCh:
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 2,
|
||||
}, result)
|
||||
}
|
||||
|
||||
// Third fetch should block since the cert is not expiring and
|
||||
// we also didn't update CA certs.
|
||||
fetchCh = TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case result := <-fetchCh:
|
||||
t.Fatalf("should not return: %#v", result)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
// testCALeafType returns a *ConnectCALeaf that is pre-configured to
|
||||
// use the given RPC implementation for "ConnectCA.Sign" operations.
|
||||
func testCALeafType(t *testing.T, rpc RPC) (*ConnectCALeaf, chan structs.IndexedCARoots) {
|
||||
// This creates an RPC implementation that will block until the
|
||||
// value is sent on the channel. This lets us control when the
|
||||
// next values show up.
|
||||
rootsCh := make(chan structs.IndexedCARoots, 10)
|
||||
rootsRPC := &testGatedRootsRPC{ValueCh: rootsCh}
|
||||
|
||||
// Create a cache
|
||||
c := cache.TestCache(t)
|
||||
c.RegisterType(ConnectCARootName, &ConnectCARoot{RPC: rootsRPC}, &cache.RegisterOptions{
|
||||
// Disable refresh so that the gated channel controls the
|
||||
// request directly. Otherwise, we get background refreshes and
|
||||
// it screws up the ordering of the channel reads of the
|
||||
// testGatedRootsRPC implementation.
|
||||
Refresh: false,
|
||||
})
|
||||
|
||||
// Create the leaf type
|
||||
return &ConnectCALeaf{RPC: rpc, Cache: c}, rootsCh
|
||||
}
|
||||
|
||||
// testGatedRootsRPC will send each subsequent value on the channel as the
|
||||
// RPC response, blocking if it is waiting for a value on the channel. This
|
||||
// can be used to control when background fetches are returned and what they
|
||||
// return.
|
||||
//
|
||||
// This should be used with Refresh = false for the registration options so
|
||||
// automatic refreshes don't mess up the channel read ordering.
|
||||
type testGatedRootsRPC struct {
|
||||
ValueCh chan structs.IndexedCARoots
|
||||
}
|
||||
|
||||
func (r *testGatedRootsRPC) RPC(method string, args interface{}, reply interface{}) error {
|
||||
if method != "ConnectCA.Roots" {
|
||||
return fmt.Errorf("invalid RPC method: %s", method)
|
||||
}
|
||||
|
||||
replyReal := reply.(*structs.IndexedCARoots)
|
||||
*replyReal = <-r.ValueCh
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// Recommended name for registration.
|
||||
const ConnectCARootName = "connect-ca-root"
|
||||
|
||||
// ConnectCARoot supports fetching the Connect CA roots. This is a
|
||||
// straightforward cache type since it only has to block on the given
|
||||
// index and return the data.
|
||||
type ConnectCARoot struct {
|
||||
RPC RPC
|
||||
}
|
||||
|
||||
func (c *ConnectCARoot) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
|
||||
var result cache.FetchResult
|
||||
|
||||
// The request should be a DCSpecificRequest.
|
||||
reqReal, ok := req.(*structs.DCSpecificRequest)
|
||||
if !ok {
|
||||
return result, fmt.Errorf(
|
||||
"Internal cache failure: request wrong type: %T", req)
|
||||
}
|
||||
|
||||
// Set the minimum query index to our current index so we block
|
||||
reqReal.QueryOptions.MinQueryIndex = opts.MinIndex
|
||||
reqReal.QueryOptions.MaxQueryTime = opts.Timeout
|
||||
|
||||
// Fetch
|
||||
var reply structs.IndexedCARoots
|
||||
if err := c.RPC.RPC("ConnectCA.Roots", reqReal, &reply); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
result.Value = &reply
|
||||
result.Index = reply.QueryMeta.Index
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnectCARoot(t *testing.T) {
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
typ := &ConnectCARoot{RPC: rpc}
|
||||
|
||||
// Expect the proper RPC call. This also sets the expected value
|
||||
// since that is return-by-pointer in the arguments.
|
||||
var resp *structs.IndexedCARoots
|
||||
rpc.On("RPC", "ConnectCA.Roots", mock.Anything, mock.Anything).Return(nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
req := args.Get(1).(*structs.DCSpecificRequest)
|
||||
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
|
||||
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
|
||||
|
||||
reply := args.Get(2).(*structs.IndexedCARoots)
|
||||
reply.QueryMeta.Index = 48
|
||||
resp = reply
|
||||
})
|
||||
|
||||
// Fetch
|
||||
result, err := typ.Fetch(cache.FetchOptions{
|
||||
MinIndex: 24,
|
||||
Timeout: 1 * time.Second,
|
||||
}, &structs.DCSpecificRequest{Datacenter: "dc1"})
|
||||
require.Nil(err)
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 48,
|
||||
}, result)
|
||||
}
|
||||
|
||||
func TestConnectCARoot_badReqType(t *testing.T) {
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
typ := &ConnectCARoot{RPC: rpc}
|
||||
|
||||
// Fetch
|
||||
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
|
||||
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
|
||||
require.NotNil(err)
|
||||
require.Contains(err.Error(), "wrong type")
|
||||
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// Recommended name for registration.
|
||||
const IntentionMatchName = "intention-match"
|
||||
|
||||
// IntentionMatch supports fetching the intentions via match queries.
|
||||
type IntentionMatch struct {
|
||||
RPC RPC
|
||||
}
|
||||
|
||||
func (c *IntentionMatch) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
|
||||
var result cache.FetchResult
|
||||
|
||||
// The request should be an IntentionQueryRequest.
|
||||
reqReal, ok := req.(*structs.IntentionQueryRequest)
|
||||
if !ok {
|
||||
return result, fmt.Errorf(
|
||||
"Internal cache failure: request wrong type: %T", req)
|
||||
}
|
||||
|
||||
// Set the minimum query index to our current index so we block
|
||||
reqReal.MinQueryIndex = opts.MinIndex
|
||||
reqReal.MaxQueryTime = opts.Timeout
|
||||
|
||||
// Fetch
|
||||
var reply structs.IndexedIntentionMatches
|
||||
if err := c.RPC.RPC("Intention.Match", reqReal, &reply); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
result.Value = &reply
|
||||
result.Index = reply.Index
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntentionMatch(t *testing.T) {
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
typ := &IntentionMatch{RPC: rpc}
|
||||
|
||||
// Expect the proper RPC call. This also sets the expected value
|
||||
// since that is return-by-pointer in the arguments.
|
||||
var resp *structs.IndexedIntentionMatches
|
||||
rpc.On("RPC", "Intention.Match", mock.Anything, mock.Anything).Return(nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
req := args.Get(1).(*structs.IntentionQueryRequest)
|
||||
require.Equal(uint64(24), req.MinQueryIndex)
|
||||
require.Equal(1*time.Second, req.MaxQueryTime)
|
||||
|
||||
reply := args.Get(2).(*structs.IndexedIntentionMatches)
|
||||
reply.Index = 48
|
||||
resp = reply
|
||||
})
|
||||
|
||||
// Fetch
|
||||
result, err := typ.Fetch(cache.FetchOptions{
|
||||
MinIndex: 24,
|
||||
Timeout: 1 * time.Second,
|
||||
}, &structs.IntentionQueryRequest{Datacenter: "dc1"})
|
||||
require.NoError(err)
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 48,
|
||||
}, result)
|
||||
}
|
||||
|
||||
func TestIntentionMatch_badReqType(t *testing.T) {
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
typ := &IntentionMatch{RPC: rpc}
|
||||
|
||||
// Fetch
|
||||
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
|
||||
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), "wrong type")
|
||||
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Code generated by mockery v1.0.0
|
||||
package cachetype
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockRPC is an autogenerated mock type for the RPC type
|
||||
type MockRPC struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// RPC provides a mock function with given fields: method, args, reply
|
||||
func (_m *MockRPC) RPC(method string, args interface{}, reply interface{}) error {
|
||||
ret := _m.Called(method, args, reply)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok {
|
||||
r0 = rf(method, args, reply)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
package cachetype
|
||||
|
||||
//go:generate mockery -all -inpkg
|
||||
|
||||
// RPC is an interface that an RPC client must implement. This is a helper
|
||||
// interface that is implemented by the agent delegate so that Type
|
||||
// implementations can request RPC access.
|
||||
type RPC interface {
|
||||
RPC(method string, args interface{}, reply interface{}) error
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
)
|
||||
|
||||
// TestRPC returns a mock implementation of the RPC interface.
|
||||
func TestRPC(t testing.T) *MockRPC {
|
||||
// This function is relatively useless but this allows us to perhaps
|
||||
// perform some initialization later.
|
||||
return &MockRPC{}
|
||||
}
|
||||
|
||||
// TestFetchCh returns a channel that returns the result of the Fetch call.
|
||||
// This is useful for testing timing and concurrency with Fetch calls.
|
||||
// Errors will show up as an error type on the resulting channel so a
|
||||
// type switch should be used.
|
||||
func TestFetchCh(
|
||||
t testing.T,
|
||||
typ cache.Type,
|
||||
opts cache.FetchOptions,
|
||||
req cache.Request) <-chan interface{} {
|
||||
resultCh := make(chan interface{})
|
||||
go func() {
|
||||
result, err := typ.Fetch(opts, req)
|
||||
if err != nil {
|
||||
resultCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
resultCh <- result
|
||||
}()
|
||||
|
||||
return resultCh
|
||||
}
|
||||
|
||||
// TestFetchChResult tests that the result from TestFetchCh matches
|
||||
// within a reasonable period of time (it expects it to be "immediate" but
|
||||
// waits some milliseconds).
|
||||
func TestFetchChResult(t testing.T, ch <-chan interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case result := <-ch:
|
||||
if err, ok := result.(error); ok {
|
||||
t.Fatalf("Result was error: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected)
|
||||
}
|
||||
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
|
@ -0,0 +1,536 @@
|
|||
// Package cache provides caching features for data from a Consul server.
|
||||
//
|
||||
// While this is similar in some ways to the "agent/ae" package, a key
|
||||
// difference is that with anti-entropy, the agent is the authoritative
|
||||
// source so it resolves differences the server may have. With caching (this
|
||||
// package), the server is the authoritative source and we do our best to
|
||||
// balance performance and correctness, depending on the type of data being
|
||||
// requested.
|
||||
//
|
||||
// The types of data that can be cached is configurable via the Type interface.
|
||||
// This allows specialized behavior for certain types of data. Each type of
|
||||
// Consul data (CA roots, leaf certs, intentions, KV, catalog, etc.) will
|
||||
// have to be manually implemented. This usually is not much work, see
|
||||
// the "agent/cache-types" package.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
)
|
||||
|
||||
//go:generate mockery -all -inpkg
|
||||
|
||||
// Constants related to refresh backoff. We probably don't ever need to
|
||||
// make these configurable knobs since they primarily exist to lower load.
|
||||
const (
|
||||
CacheRefreshBackoffMin = 3 // 3 attempts before backing off
|
||||
CacheRefreshMaxWait = 1 * time.Minute // maximum backoff wait time
|
||||
)
|
||||
|
||||
// Cache is a agent-local cache of Consul data. Create a Cache using the
|
||||
// New function. A zero-value Cache is not ready for usage and will result
|
||||
// in a panic.
|
||||
//
|
||||
// The types of data to be cached must be registered via RegisterType. Then,
|
||||
// calls to Get specify the type and a Request implementation. The
|
||||
// implementation of Request is usually done directly on the standard RPC
|
||||
// struct in agent/structs. This API makes cache usage a mostly drop-in
|
||||
// replacement for non-cached RPC calls.
|
||||
//
|
||||
// The cache is partitioned by ACL and datacenter. This allows the cache
|
||||
// to be safe for multi-DC queries and for queries where the data is modified
|
||||
// due to ACLs all without the cache having to have any clever logic, at
|
||||
// the slight expense of a less perfect cache.
|
||||
//
|
||||
// The Cache exposes various metrics via go-metrics. Please view the source
|
||||
// searching for "metrics." to see the various metrics exposed. These can be
|
||||
// used to explore the performance of the cache.
|
||||
type Cache struct {
|
||||
// types stores the list of data types that the cache knows how to service.
|
||||
// These can be dynamically registered with RegisterType.
|
||||
typesLock sync.RWMutex
|
||||
types map[string]typeEntry
|
||||
|
||||
// entries contains the actual cache data. Access to entries and
|
||||
// entriesExpiryHeap must be protected by entriesLock.
|
||||
//
|
||||
// entriesExpiryHeap is a heap of *cacheEntry values ordered by
|
||||
// expiry, with the soonest to expire being first in the list (index 0).
|
||||
//
|
||||
// NOTE(mitchellh): The entry map key is currently a string in the format
|
||||
// of "<DC>/<ACL token>/<Request key>" in order to properly partition
|
||||
// requests to different datacenters and ACL tokens. This format has some
|
||||
// big drawbacks: we can't evict by datacenter, ACL token, etc. For an
|
||||
// initial implementation this works and the tests are agnostic to the
|
||||
// internal storage format so changing this should be possible safely.
|
||||
entriesLock sync.RWMutex
|
||||
entries map[string]cacheEntry
|
||||
entriesExpiryHeap *expiryHeap
|
||||
}
|
||||
|
||||
// typeEntry is a single type that is registered with a Cache.
|
||||
type typeEntry struct {
|
||||
Type Type
|
||||
Opts *RegisterOptions
|
||||
}
|
||||
|
||||
// ResultMeta is returned from Get calls along with the value and can be used
|
||||
// to expose information about the cache status for debugging or testing.
|
||||
type ResultMeta struct {
|
||||
// Return whether or not the request was a cache hit
|
||||
Hit bool
|
||||
}
|
||||
|
||||
// Options are options for the Cache.
|
||||
type Options struct {
|
||||
// Nothing currently, reserved.
|
||||
}
|
||||
|
||||
// New creates a new cache with the given RPC client and reasonable defaults.
|
||||
// Further settings can be tweaked on the returned value.
|
||||
func New(*Options) *Cache {
|
||||
// Initialize the heap. The buffer of 1 is really important because
|
||||
// its possible for the expiry loop to trigger the heap to update
|
||||
// itself and it'd block forever otherwise.
|
||||
h := &expiryHeap{NotifyCh: make(chan struct{}, 1)}
|
||||
heap.Init(h)
|
||||
|
||||
c := &Cache{
|
||||
types: make(map[string]typeEntry),
|
||||
entries: make(map[string]cacheEntry),
|
||||
entriesExpiryHeap: h,
|
||||
}
|
||||
|
||||
// Start the expiry watcher
|
||||
go c.runExpiryLoop()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// RegisterOptions are options that can be associated with a type being
|
||||
// registered for the cache. This changes the behavior of the cache for
|
||||
// this type.
|
||||
type RegisterOptions struct {
|
||||
// LastGetTTL is the time that the values returned by this type remain
|
||||
// in the cache after the last get operation. If a value isn't accessed
|
||||
// within this duration, the value is purged from the cache and
|
||||
// background refreshing will cease.
|
||||
LastGetTTL time.Duration
|
||||
|
||||
// Refresh configures whether the data is actively refreshed or if
|
||||
// the data is only refreshed on an explicit Get. The default (false)
|
||||
// is to only request data on explicit Get.
|
||||
Refresh bool
|
||||
|
||||
// RefreshTimer is the time between attempting to refresh data.
|
||||
// If this is zero, then data is refreshed immediately when a fetch
|
||||
// is returned.
|
||||
//
|
||||
// RefreshTimeout determines the maximum query time for a refresh
|
||||
// operation. This is specified as part of the query options and is
|
||||
// expected to be implemented by the Type itself.
|
||||
//
|
||||
// Using these values, various "refresh" mechanisms can be implemented:
|
||||
//
|
||||
// * With a high timer duration and a low timeout, a timer-based
|
||||
// refresh can be set that minimizes load on the Consul servers.
|
||||
//
|
||||
// * With a low timer and high timeout duration, a blocking-query-based
|
||||
// refresh can be set so that changes in server data are recognized
|
||||
// within the cache very quickly.
|
||||
//
|
||||
RefreshTimer time.Duration
|
||||
RefreshTimeout time.Duration
|
||||
}
|
||||
|
||||
// RegisterType registers a cacheable type.
|
||||
//
|
||||
// This makes the type available for Get but does not automatically perform
|
||||
// any prefetching. In order to populate the cache, Get must be called.
|
||||
func (c *Cache) RegisterType(n string, typ Type, opts *RegisterOptions) {
|
||||
if opts == nil {
|
||||
opts = &RegisterOptions{}
|
||||
}
|
||||
if opts.LastGetTTL == 0 {
|
||||
opts.LastGetTTL = 72 * time.Hour // reasonable default is days
|
||||
}
|
||||
|
||||
c.typesLock.Lock()
|
||||
defer c.typesLock.Unlock()
|
||||
c.types[n] = typeEntry{Type: typ, Opts: opts}
|
||||
}
|
||||
|
||||
// Get loads the data for the given type and request. If data satisfying the
|
||||
// minimum index is present in the cache, it is returned immediately. Otherwise,
|
||||
// this will block until the data is available or the request timeout is
|
||||
// reached.
|
||||
//
|
||||
// Multiple Get calls for the same Request (matching CacheKey value) will
|
||||
// block on a single network request.
|
||||
//
|
||||
// The timeout specified by the Request will be the timeout on the cache
|
||||
// Get, and does not correspond to the timeout of any background data
|
||||
// fetching. If the timeout is reached before data satisfying the minimum
|
||||
// index is retrieved, the last known value (maybe nil) is returned. No
|
||||
// error is returned on timeout. This matches the behavior of Consul blocking
|
||||
// queries.
|
||||
func (c *Cache) Get(t string, r Request) (interface{}, ResultMeta, error) {
|
||||
info := r.CacheInfo()
|
||||
if info.Key == "" {
|
||||
metrics.IncrCounter([]string{"consul", "cache", "bypass"}, 1)
|
||||
|
||||
// If no key is specified, then we do not cache this request.
|
||||
// Pass directly through to the backend.
|
||||
return c.fetchDirect(t, r)
|
||||
}
|
||||
|
||||
// Get the actual key for our entry
|
||||
key := c.entryKey(&info)
|
||||
|
||||
// First time through
|
||||
first := true
|
||||
|
||||
// timeoutCh for watching our timeout
|
||||
var timeoutCh <-chan time.Time
|
||||
|
||||
RETRY_GET:
|
||||
// Get the current value
|
||||
c.entriesLock.RLock()
|
||||
entry, ok := c.entries[key]
|
||||
c.entriesLock.RUnlock()
|
||||
|
||||
// If we have a current value and the index is greater than the
|
||||
// currently stored index then we return that right away. If the
|
||||
// index is zero and we have something in the cache we accept whatever
|
||||
// we have.
|
||||
if ok && entry.Valid {
|
||||
if info.MinIndex == 0 || info.MinIndex < entry.Index {
|
||||
meta := ResultMeta{}
|
||||
if first {
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "hit"}, 1)
|
||||
meta.Hit = true
|
||||
}
|
||||
|
||||
// Touch the expiration and fix the heap.
|
||||
c.entriesLock.Lock()
|
||||
entry.Expiry.Reset()
|
||||
c.entriesExpiryHeap.Fix(entry.Expiry)
|
||||
c.entriesLock.Unlock()
|
||||
|
||||
// We purposely do not return an error here since the cache
|
||||
// only works with fetching values that either have a value
|
||||
// or have an error, but not both. The Error may be non-nil
|
||||
// in the entry because of this to note future fetch errors.
|
||||
return entry.Value, meta, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If this isn't our first time through and our last value has an error,
|
||||
// then we return the error. This has the behavior that we don't sit in
|
||||
// a retry loop getting the same error for the entire duration of the
|
||||
// timeout. Instead, we make one effort to fetch a new value, and if
|
||||
// there was an error, we return.
|
||||
if !first && entry.Error != nil {
|
||||
return entry.Value, ResultMeta{}, entry.Error
|
||||
}
|
||||
|
||||
if first {
|
||||
// We increment two different counters for cache misses depending on
|
||||
// whether we're missing because we didn't have the data at all,
|
||||
// or if we're missing because we're blocking on a set index.
|
||||
if info.MinIndex == 0 {
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "miss_new"}, 1)
|
||||
} else {
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "miss_block"}, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// No longer our first time through
|
||||
first = false
|
||||
|
||||
// Set our timeout channel if we must
|
||||
if info.Timeout > 0 && timeoutCh == nil {
|
||||
timeoutCh = time.After(info.Timeout)
|
||||
}
|
||||
|
||||
// At this point, we know we either don't have a value at all or the
|
||||
// value we have is too old. We need to wait for new data.
|
||||
waiterCh, err := c.fetch(t, key, r, true, 0)
|
||||
if err != nil {
|
||||
return nil, ResultMeta{}, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-waiterCh:
|
||||
// Our fetch returned, retry the get from the cache
|
||||
goto RETRY_GET
|
||||
|
||||
case <-timeoutCh:
|
||||
// Timeout on the cache read, just return whatever we have.
|
||||
return entry.Value, ResultMeta{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// entryKey returns the key for the entry in the cache. See the note
|
||||
// about the entry key format in the structure docs for Cache.
|
||||
func (c *Cache) entryKey(r *RequestInfo) string {
|
||||
return fmt.Sprintf("%s/%s/%s", r.Datacenter, r.Token, r.Key)
|
||||
}
|
||||
|
||||
// fetch triggers a new background fetch for the given Request. If a
|
||||
// background fetch is already running for a matching Request, the waiter
|
||||
// channel for that request is returned. The effect of this is that there
|
||||
// is only ever one blocking query for any matching requests.
|
||||
//
|
||||
// If allowNew is true then the fetch should create the cache entry
|
||||
// if it doesn't exist. If this is false, then fetch will do nothing
|
||||
// if the entry doesn't exist. This latter case is to support refreshing.
|
||||
func (c *Cache) fetch(t, key string, r Request, allowNew bool, attempt uint) (<-chan struct{}, error) {
|
||||
// Get the type that we're fetching
|
||||
c.typesLock.RLock()
|
||||
tEntry, ok := c.types[t]
|
||||
c.typesLock.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown type in cache: %s", t)
|
||||
}
|
||||
|
||||
// We acquire a write lock because we may have to set Fetching to true.
|
||||
c.entriesLock.Lock()
|
||||
defer c.entriesLock.Unlock()
|
||||
entry, ok := c.entries[key]
|
||||
|
||||
// If we aren't allowing new values and we don't have an existing value,
|
||||
// return immediately. We return an immediately-closed channel so nothing
|
||||
// blocks.
|
||||
if !ok && !allowNew {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// If we already have an entry and it is actively fetching, then return
|
||||
// the currently active waiter.
|
||||
if ok && entry.Fetching {
|
||||
return entry.Waiter, nil
|
||||
}
|
||||
|
||||
// If we don't have an entry, then create it. The entry must be marked
|
||||
// as invalid so that it isn't returned as a valid value for a zero index.
|
||||
if !ok {
|
||||
entry = cacheEntry{Valid: false, Waiter: make(chan struct{})}
|
||||
}
|
||||
|
||||
// Set that we're fetching to true, which makes it so that future
|
||||
// identical calls to fetch will return the same waiter rather than
|
||||
// perform multiple fetches.
|
||||
entry.Fetching = true
|
||||
c.entries[key] = entry
|
||||
metrics.SetGauge([]string{"consul", "cache", "entries_count"}, float32(len(c.entries)))
|
||||
|
||||
// The actual Fetch must be performed in a goroutine.
|
||||
go func() {
|
||||
// Start building the new entry by blocking on the fetch.
|
||||
result, err := tEntry.Type.Fetch(FetchOptions{
|
||||
MinIndex: entry.Index,
|
||||
Timeout: tEntry.Opts.RefreshTimeout,
|
||||
}, r)
|
||||
|
||||
// Copy the existing entry to start.
|
||||
newEntry := entry
|
||||
newEntry.Fetching = false
|
||||
if result.Value != nil {
|
||||
// A new value was given, so we create a brand new entry.
|
||||
newEntry.Value = result.Value
|
||||
newEntry.Index = result.Index
|
||||
if newEntry.Index < 1 {
|
||||
// Less than one is invalid unless there was an error and in this case
|
||||
// there wasn't since a value was returned. If a badly behaved RPC
|
||||
// returns 0 when it has no data, we might get into a busy loop here. We
|
||||
// set this to minimum of 1 which is safe because no valid user data can
|
||||
// ever be written at raft index 1 due to the bootstrap process for
|
||||
// raft. This insure that any subsequent background refresh request will
|
||||
// always block, but allows the initial request to return immediately
|
||||
// even if there is no data.
|
||||
newEntry.Index = 1
|
||||
}
|
||||
|
||||
// This is a valid entry with a result
|
||||
newEntry.Valid = true
|
||||
}
|
||||
|
||||
// Error handling
|
||||
if err == nil {
|
||||
metrics.IncrCounter([]string{"consul", "cache", "fetch_success"}, 1)
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "fetch_success"}, 1)
|
||||
|
||||
if result.Index > 0 {
|
||||
// Reset the attempts counter so we don't have any backoff
|
||||
attempt = 0
|
||||
} else {
|
||||
// Result having a zero index is an implicit error case. There was no
|
||||
// actual error but it implies the RPC found in index (nothing written
|
||||
// yet for that type) but didn't take care to return safe "1" index. We
|
||||
// don't want to actually treat it like an error by setting
|
||||
// newEntry.Error to something non-nil, but we should guard against 100%
|
||||
// CPU burn hot loops caused by that case which will never block but
|
||||
// also won't backoff either. So we treat it as a failed attempt so that
|
||||
// at least the failure backoff will save our CPU while still
|
||||
// periodically refreshing so normal service can resume when the servers
|
||||
// actually have something to return from the RPC. If we get in this
|
||||
// state it can be considered a bug in the RPC implementation (to ever
|
||||
// return a zero index) however since it can happen this is a safety net
|
||||
// for the future.
|
||||
attempt++
|
||||
}
|
||||
} else {
|
||||
metrics.IncrCounter([]string{"consul", "cache", "fetch_error"}, 1)
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "fetch_error"}, 1)
|
||||
|
||||
// Increment attempt counter
|
||||
attempt++
|
||||
|
||||
// Always set the error. We don't override the value here because
|
||||
// if Valid is true, then we can reuse the Value in the case a
|
||||
// specific index isn't requested. However, for blocking queries,
|
||||
// we want Error to be set so that we can return early with the
|
||||
// error.
|
||||
newEntry.Error = err
|
||||
}
|
||||
|
||||
// Create a new waiter that will be used for the next fetch.
|
||||
newEntry.Waiter = make(chan struct{})
|
||||
|
||||
// Set our entry
|
||||
c.entriesLock.Lock()
|
||||
|
||||
// If this is a new entry (not in the heap yet), then setup the
|
||||
// initial expiry information and insert. If we're already in
|
||||
// the heap we do nothing since we're reusing the same entry.
|
||||
if newEntry.Expiry == nil || newEntry.Expiry.HeapIndex == -1 {
|
||||
newEntry.Expiry = &cacheEntryExpiry{
|
||||
Key: key,
|
||||
TTL: tEntry.Opts.LastGetTTL,
|
||||
}
|
||||
newEntry.Expiry.Reset()
|
||||
heap.Push(c.entriesExpiryHeap, newEntry.Expiry)
|
||||
}
|
||||
|
||||
c.entries[key] = newEntry
|
||||
c.entriesLock.Unlock()
|
||||
|
||||
// Trigger the old waiter
|
||||
close(entry.Waiter)
|
||||
|
||||
// If refresh is enabled, run the refresh in due time. The refresh
|
||||
// below might block, but saves us from spawning another goroutine.
|
||||
if tEntry.Opts.Refresh {
|
||||
c.refresh(tEntry.Opts, attempt, t, key, r)
|
||||
}
|
||||
}()
|
||||
|
||||
return entry.Waiter, nil
|
||||
}
|
||||
|
||||
// fetchDirect fetches the given request with no caching. Because this
|
||||
// bypasses the caching entirely, multiple matching requests will result
|
||||
// in multiple actual RPC calls (unlike fetch).
|
||||
func (c *Cache) fetchDirect(t string, r Request) (interface{}, ResultMeta, error) {
|
||||
// Get the type that we're fetching
|
||||
c.typesLock.RLock()
|
||||
tEntry, ok := c.types[t]
|
||||
c.typesLock.RUnlock()
|
||||
if !ok {
|
||||
return nil, ResultMeta{}, fmt.Errorf("unknown type in cache: %s", t)
|
||||
}
|
||||
|
||||
// Fetch it with the min index specified directly by the request.
|
||||
result, err := tEntry.Type.Fetch(FetchOptions{
|
||||
MinIndex: r.CacheInfo().MinIndex,
|
||||
}, r)
|
||||
if err != nil {
|
||||
return nil, ResultMeta{}, err
|
||||
}
|
||||
|
||||
// Return the result and ignore the rest
|
||||
return result.Value, ResultMeta{}, nil
|
||||
}
|
||||
|
||||
// refresh triggers a fetch for a specific Request according to the
|
||||
// registration options.
|
||||
func (c *Cache) refresh(opts *RegisterOptions, attempt uint, t string, key string, r Request) {
|
||||
// Sanity-check, we should not schedule anything that has refresh disabled
|
||||
if !opts.Refresh {
|
||||
return
|
||||
}
|
||||
|
||||
// If we're over the attempt minimum, start an exponential backoff.
|
||||
if attempt > CacheRefreshBackoffMin {
|
||||
waitTime := (1 << (attempt - CacheRefreshBackoffMin)) * time.Second
|
||||
if waitTime > CacheRefreshMaxWait {
|
||||
waitTime = CacheRefreshMaxWait
|
||||
}
|
||||
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
|
||||
// If we have a timer, wait for it
|
||||
if opts.RefreshTimer > 0 {
|
||||
time.Sleep(opts.RefreshTimer)
|
||||
}
|
||||
|
||||
// Trigger. The "allowNew" field is false because in the time we were
|
||||
// waiting to refresh we may have expired and got evicted. If that
|
||||
// happened, we don't want to create a new entry.
|
||||
c.fetch(t, key, r, false, attempt)
|
||||
}
|
||||
|
||||
// runExpiryLoop is a blocking function that watches the expiration
|
||||
// heap and invalidates entries that have expired.
|
||||
func (c *Cache) runExpiryLoop() {
|
||||
var expiryTimer *time.Timer
|
||||
for {
|
||||
// If we have a previous timer, stop it.
|
||||
if expiryTimer != nil {
|
||||
expiryTimer.Stop()
|
||||
}
|
||||
|
||||
// Get the entry expiring soonest
|
||||
var entry *cacheEntryExpiry
|
||||
var expiryCh <-chan time.Time
|
||||
c.entriesLock.RLock()
|
||||
if len(c.entriesExpiryHeap.Entries) > 0 {
|
||||
entry = c.entriesExpiryHeap.Entries[0]
|
||||
expiryTimer = time.NewTimer(entry.Expires.Sub(time.Now()))
|
||||
expiryCh = expiryTimer.C
|
||||
}
|
||||
c.entriesLock.RUnlock()
|
||||
|
||||
select {
|
||||
case <-c.entriesExpiryHeap.NotifyCh:
|
||||
// Entries changed, so the heap may have changed. Restart loop.
|
||||
|
||||
case <-expiryCh:
|
||||
c.entriesLock.Lock()
|
||||
|
||||
// Entry expired! Remove it.
|
||||
delete(c.entries, entry.Key)
|
||||
heap.Remove(c.entriesExpiryHeap, entry.HeapIndex)
|
||||
|
||||
// This is subtle but important: if we race and simultaneously
|
||||
// evict and fetch a new value, then we set this to -1 to
|
||||
// have it treated as a new value so that the TTL is extended.
|
||||
entry.HeapIndex = -1
|
||||
|
||||
// Set some metrics
|
||||
metrics.IncrCounter([]string{"consul", "cache", "evict_expired"}, 1)
|
||||
metrics.SetGauge([]string{"consul", "cache", "entries_count"}, float32(len(c.entries)))
|
||||
|
||||
c.entriesLock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,760 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test a basic Get with no indexes (and therefore no blocking queries).
|
||||
func TestCacheGet_noIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(1)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should not fetch since we already have a satisfying value
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.True(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test a basic Get with no index and a failed fetch.
|
||||
func TestCacheGet_initError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
fetcherr := fmt.Errorf("error")
|
||||
typ.Static(FetchResult{}, fetcherr).Times(2)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.Error(err)
|
||||
require.Nil(result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should fetch again since our last fetch was an error
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.Error(err)
|
||||
require.Nil(result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test a Get with a request that returns a blank cache key. This should
|
||||
// force a backend request and skip the cache entirely.
|
||||
func TestCacheGet_blankCacheKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(2)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: ""})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should not fetch since we already have a satisfying value
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test that Get blocks on the initial value
|
||||
func TestCacheGet_blockingInitSameKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
triggerCh := make(chan time.Time)
|
||||
typ.Static(FetchResult{Value: 42}, nil).WaitUntil(triggerCh).Times(1)
|
||||
|
||||
// Perform multiple gets
|
||||
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
|
||||
// They should block
|
||||
select {
|
||||
case <-getCh1:
|
||||
t.Fatal("should block (ch1)")
|
||||
case <-getCh2:
|
||||
t.Fatal("should block (ch2)")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Trigger it
|
||||
close(triggerCh)
|
||||
|
||||
// Should return
|
||||
TestCacheGetChResult(t, getCh1, 42)
|
||||
TestCacheGetChResult(t, getCh2, 42)
|
||||
}
|
||||
|
||||
// Test that Get with different cache keys both block on initial value
|
||||
// but that the fetches were both properly called.
|
||||
func TestCacheGet_blockingInitDiffKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Keep track of the keys
|
||||
var keysLock sync.Mutex
|
||||
var keys []string
|
||||
|
||||
// Configure the type
|
||||
triggerCh := make(chan time.Time)
|
||||
typ.Static(FetchResult{Value: 42}, nil).
|
||||
WaitUntil(triggerCh).
|
||||
Times(2).
|
||||
Run(func(args mock.Arguments) {
|
||||
keysLock.Lock()
|
||||
defer keysLock.Unlock()
|
||||
keys = append(keys, args.Get(1).(Request).CacheInfo().Key)
|
||||
})
|
||||
|
||||
// Perform multiple gets
|
||||
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "goodbye"}))
|
||||
|
||||
// They should block
|
||||
select {
|
||||
case <-getCh1:
|
||||
t.Fatal("should block (ch1)")
|
||||
case <-getCh2:
|
||||
t.Fatal("should block (ch2)")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Trigger it
|
||||
close(triggerCh)
|
||||
|
||||
// Should return both!
|
||||
TestCacheGetChResult(t, getCh1, 42)
|
||||
TestCacheGetChResult(t, getCh2, 42)
|
||||
|
||||
// Verify proper keys
|
||||
sort.Strings(keys)
|
||||
require.Equal([]string{"goodbye", "hello"}, keys)
|
||||
}
|
||||
|
||||
// Test a get with an index set will wait until an index that is higher
|
||||
// is set in the cache.
|
||||
func TestCacheGet_blockingIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
triggerCh := make(chan time.Time)
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 42, Index: 6}, nil).WaitUntil(triggerCh)
|
||||
|
||||
// Fetch should block
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Key: "hello", MinIndex: 5}))
|
||||
|
||||
// Should block
|
||||
select {
|
||||
case <-resultCh:
|
||||
t.Fatal("should block")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Wait a bit
|
||||
close(triggerCh)
|
||||
|
||||
// Should return
|
||||
TestCacheGetChResult(t, resultCh, 42)
|
||||
}
|
||||
|
||||
// Test a get with an index set will timeout if the fetch doesn't return
|
||||
// anything.
|
||||
func TestCacheGet_blockingIndexTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
triggerCh := make(chan time.Time)
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 42, Index: 6}, nil).WaitUntil(triggerCh)
|
||||
|
||||
// Fetch should block
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Key: "hello", MinIndex: 5, Timeout: 200 * time.Millisecond}))
|
||||
|
||||
// Should block
|
||||
select {
|
||||
case <-resultCh:
|
||||
t.Fatal("should block")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Should return after more of the timeout
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
require.Equal(t, 12, result)
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
t.Fatal("should've returned")
|
||||
}
|
||||
}
|
||||
|
||||
// Test a get with an index set with requests returning an error
|
||||
// will return that error.
|
||||
func TestCacheGet_blockingIndexError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
var retries uint32
|
||||
fetchErr := fmt.Errorf("test fetch error")
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: nil, Index: 5}, fetchErr).Run(func(args mock.Arguments) {
|
||||
atomic.AddUint32(&retries, 1)
|
||||
})
|
||||
|
||||
// First good fetch to populate catch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Fetch should not block and should return error
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Key: "hello", MinIndex: 7, Timeout: 1 * time.Minute}))
|
||||
TestCacheGetChResult(t, resultCh, nil)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check the number
|
||||
actual := atomic.LoadUint32(&retries)
|
||||
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
|
||||
}
|
||||
|
||||
// Test that if a Type returns an empty value on Fetch that the previous
|
||||
// value is preserved.
|
||||
func TestCacheGet_emptyFetchResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42, Index: 1}, nil).Times(1)
|
||||
typ.Static(FetchResult{Value: nil}, nil)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should not fetch since we already have a satisfying value
|
||||
req = TestRequest(t, RequestInfo{
|
||||
Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test that a type registered with a periodic refresh will perform
|
||||
// that refresh after the timer is up.
|
||||
func TestCacheGet_periodicRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 100 * time.Millisecond,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// This is a bit weird, but we do this to ensure that the final
|
||||
// call to the Fetch (if it happens, depends on timing) just blocks.
|
||||
triggerCh := make(chan time.Time)
|
||||
defer close(triggerCh)
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).WaitUntil(triggerCh)
|
||||
|
||||
// Fetch should block
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Fetch again almost immediately should return old result
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Wait for the timer
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 12)
|
||||
}
|
||||
|
||||
// Test that a type registered with a periodic refresh will perform
|
||||
// that refresh after the timer is up.
|
||||
func TestCacheGet_periodicRefreshMultiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 0 * time.Millisecond,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// This is a bit weird, but we do this to ensure that the final
|
||||
// call to the Fetch (if it happens, depends on timing) just blocks.
|
||||
trigger := make([]chan time.Time, 3)
|
||||
for i := range trigger {
|
||||
trigger[i] = make(chan time.Time)
|
||||
}
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once().WaitUntil(trigger[0])
|
||||
typ.Static(FetchResult{Value: 24, Index: 6}, nil).Once().WaitUntil(trigger[1])
|
||||
typ.Static(FetchResult{Value: 42, Index: 7}, nil).WaitUntil(trigger[2])
|
||||
|
||||
// Fetch should block
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Fetch again almost immediately should return old result
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Trigger the next, sleep a bit, and verify we get the next result
|
||||
close(trigger[0])
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 12)
|
||||
|
||||
// Trigger the next, sleep a bit, and verify we get the next result
|
||||
close(trigger[1])
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 24)
|
||||
}
|
||||
|
||||
// Test that a refresh performs a backoff.
|
||||
func TestCacheGet_periodicRefreshErrorBackoff(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 0,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
var retries uint32
|
||||
fetchErr := fmt.Errorf("test fetch error")
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: nil, Index: 5}, fetchErr).Run(func(args mock.Arguments) {
|
||||
atomic.AddUint32(&retries, 1)
|
||||
})
|
||||
|
||||
// Fetch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Sleep a bit. The refresh will quietly fail in the background. What we
|
||||
// want to verify is that it doesn't retry too much. "Too much" is hard
|
||||
// to measure since its CPU dependent if this test is failing. But due
|
||||
// to the short sleep below, we can calculate about what we'd expect if
|
||||
// backoff IS working.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Fetch should work, we should get a 1 still. Errors are ignored.
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Check the number
|
||||
actual := atomic.LoadUint32(&retries)
|
||||
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
|
||||
}
|
||||
|
||||
// Test that a badly behaved RPC that returns 0 index will perform a backoff.
|
||||
func TestCacheGet_periodicRefreshBadRPCZeroIndexErrorBackoff(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 0,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
var retries uint32
|
||||
typ.Static(FetchResult{Value: 0, Index: 0}, nil).Run(func(args mock.Arguments) {
|
||||
atomic.AddUint32(&retries, 1)
|
||||
})
|
||||
|
||||
// Fetch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 0)
|
||||
|
||||
// Sleep a bit. The refresh will quietly fail in the background. What we
|
||||
// want to verify is that it doesn't retry too much. "Too much" is hard
|
||||
// to measure since its CPU dependent if this test is failing. But due
|
||||
// to the short sleep below, we can calculate about what we'd expect if
|
||||
// backoff IS working.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Fetch should work, we should get a 0 still. Errors are ignored.
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 0)
|
||||
|
||||
// Check the number
|
||||
actual := atomic.LoadUint32(&retries)
|
||||
require.True(t, actual < 10, fmt.Sprintf("%d retries, should be < 10", actual))
|
||||
}
|
||||
|
||||
// Test that fetching with no index makes an initial request with no index, but
|
||||
// then ensures all background refreshes have > 0. This ensures we don't end up
|
||||
// with any index 0 loops from background refreshed while also returning
|
||||
// immediately on the initial request if there is no data written to that table
|
||||
// yet.
|
||||
func TestCacheGet_noIndexSetsOne(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 0,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Simulate "well behaved" RPC with no data yet but returning 1
|
||||
{
|
||||
first := int32(1)
|
||||
|
||||
typ.Static(FetchResult{Value: 0, Index: 1}, nil).Run(func(args mock.Arguments) {
|
||||
opts := args.Get(0).(FetchOptions)
|
||||
isFirst := atomic.SwapInt32(&first, 0)
|
||||
if isFirst == 1 {
|
||||
assert.Equal(t, uint64(0), opts.MinIndex)
|
||||
} else {
|
||||
assert.True(t, opts.MinIndex > 0, "minIndex > 0")
|
||||
}
|
||||
})
|
||||
|
||||
// Fetch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 0)
|
||||
|
||||
// Sleep a bit so background refresh happens
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Same for "badly behaved" RPC that returns 0 index and no data
|
||||
{
|
||||
first := int32(1)
|
||||
|
||||
typ.Static(FetchResult{Value: 0, Index: 0}, nil).Run(func(args mock.Arguments) {
|
||||
opts := args.Get(0).(FetchOptions)
|
||||
isFirst := atomic.SwapInt32(&first, 0)
|
||||
if isFirst == 1 {
|
||||
assert.Equal(t, uint64(0), opts.MinIndex)
|
||||
} else {
|
||||
assert.True(t, opts.MinIndex > 0, "minIndex > 0")
|
||||
}
|
||||
})
|
||||
|
||||
// Fetch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 0)
|
||||
|
||||
// Sleep a bit so background refresh happens
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the backend fetch sets the proper timeout.
|
||||
func TestCacheGet_fetchTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
|
||||
// Register the type with a timeout
|
||||
timeout := 10 * time.Minute
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
RefreshTimeout: timeout,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
var actual time.Duration
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(1).Run(func(args mock.Arguments) {
|
||||
opts := args.Get(0).(FetchOptions)
|
||||
actual = opts.Timeout
|
||||
})
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Test the timeout
|
||||
require.Equal(timeout, actual)
|
||||
}
|
||||
|
||||
// Test that entries expire
|
||||
func TestCacheGet_expire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
|
||||
// Register the type with a timeout
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
LastGetTTL: 400 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(2)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should not fetch, verified via the mock assertions above
|
||||
req = TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.True(meta.Hit)
|
||||
|
||||
// Sleep for the expiry
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Get, should fetch
|
||||
req = TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test that entries reset their TTL on Get
|
||||
func TestCacheGet_expireResetGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
|
||||
// Register the type with a timeout
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
LastGetTTL: 150 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(2)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Fetch multiple times, where the total time is well beyond
|
||||
// the TTL. We should not trigger any fetches during this time.
|
||||
for i := 0; i < 5; i++ {
|
||||
// Sleep a bit
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Get, should not fetch
|
||||
req = TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.True(meta.Hit)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Get, should fetch
|
||||
req = TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test that Get partitions the caches based on DC so two equivalent requests
|
||||
// to different datacenters are automatically cached even if their keys are
|
||||
// the same.
|
||||
func TestCacheGet_partitionDC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", &testPartitionType{}, nil)
|
||||
|
||||
// Perform multiple gets
|
||||
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Datacenter: "dc1", Key: "hello"}))
|
||||
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Datacenter: "dc9", Key: "hello"}))
|
||||
|
||||
// Should return both!
|
||||
TestCacheGetChResult(t, getCh1, "dc1")
|
||||
TestCacheGetChResult(t, getCh2, "dc9")
|
||||
}
|
||||
|
||||
// Test that Get partitions the caches based on token so two equivalent requests
|
||||
// with different ACL tokens do not return the same result.
|
||||
func TestCacheGet_partitionToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", &testPartitionType{}, nil)
|
||||
|
||||
// Perform multiple gets
|
||||
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Token: "", Key: "hello"}))
|
||||
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Token: "foo", Key: "hello"}))
|
||||
|
||||
// Should return both!
|
||||
TestCacheGetChResult(t, getCh1, "")
|
||||
TestCacheGetChResult(t, getCh2, "foo")
|
||||
}
|
||||
|
||||
// testPartitionType implements Type for testing that simply returns a value
|
||||
// comprised of the request DC and ACL token, used for testing cache
|
||||
// partitioning.
|
||||
type testPartitionType struct{}
|
||||
|
||||
func (t *testPartitionType) Fetch(opts FetchOptions, r Request) (FetchResult, error) {
|
||||
info := r.CacheInfo()
|
||||
return FetchResult{
|
||||
Value: fmt.Sprintf("%s%s", info.Datacenter, info.Token),
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cacheEntry stores a single cache entry.
|
||||
//
|
||||
// Note that this isn't a very optimized structure currently. There are
|
||||
// a lot of improvements that can be made here in the long term.
|
||||
type cacheEntry struct {
|
||||
// Fields pertaining to the actual value
|
||||
Value interface{}
|
||||
Error error
|
||||
Index uint64
|
||||
|
||||
// Metadata that is used for internal accounting
|
||||
Valid bool // True if the Value is set
|
||||
Fetching bool // True if a fetch is already active
|
||||
Waiter chan struct{} // Closed when this entry is invalidated
|
||||
|
||||
// Expiry contains information about the expiration of this
|
||||
// entry. This is a pointer as its shared as a value in the
|
||||
// expiryHeap as well.
|
||||
Expiry *cacheEntryExpiry
|
||||
}
|
||||
|
||||
// cacheEntryExpiry contains the expiration information for a cache
|
||||
// entry. Any modifications to this struct should be done only while
|
||||
// the Cache entriesLock is held.
|
||||
type cacheEntryExpiry struct {
|
||||
Key string // Key in the cache map
|
||||
Expires time.Time // Time when entry expires (monotonic clock)
|
||||
TTL time.Duration // TTL for this entry to extend when resetting
|
||||
HeapIndex int // Index in the heap
|
||||
}
|
||||
|
||||
// Reset resets the expiration to be the ttl duration from now.
|
||||
func (e *cacheEntryExpiry) Reset() {
|
||||
e.Expires = time.Now().Add(e.TTL)
|
||||
}
|
||||
|
||||
// expiryHeap is a heap implementation that stores information about
|
||||
// when entires expire. Implements container/heap.Interface.
|
||||
//
|
||||
// All operations on the heap and read/write of the heap contents require
|
||||
// the proper entriesLock to be held on Cache.
|
||||
type expiryHeap struct {
|
||||
Entries []*cacheEntryExpiry
|
||||
|
||||
// NotifyCh is sent a value whenever the 0 index value of the heap
|
||||
// changes. This can be used to detect when the earliest value
|
||||
// changes.
|
||||
//
|
||||
// There is a single edge case where the heap will not automatically
|
||||
// send a notification: if heap.Fix is called manually and the index
|
||||
// changed is 0 and the change doesn't result in any moves (stays at index
|
||||
// 0), then we won't detect the change. To work around this, please
|
||||
// always call the expiryHeap.Fix method instead.
|
||||
NotifyCh chan struct{}
|
||||
}
|
||||
|
||||
// Identical to heap.Fix for this heap instance but will properly handle
|
||||
// the edge case where idx == 0 and no heap modification is necessary,
|
||||
// and still notify the NotifyCh.
|
||||
//
|
||||
// This is important for cache expiry since the expiry time may have been
|
||||
// extended and if we don't send a message to the NotifyCh then we'll never
|
||||
// reset the timer and the entry will be evicted early.
|
||||
func (h *expiryHeap) Fix(entry *cacheEntryExpiry) {
|
||||
idx := entry.HeapIndex
|
||||
heap.Fix(h, idx)
|
||||
|
||||
// This is the edge case we handle: if the prev (idx) and current (HeapIndex)
|
||||
// is zero, it means the head-of-line didn't change while the value
|
||||
// changed. Notify to reset our expiry worker.
|
||||
if idx == 0 && entry.HeapIndex == 0 {
|
||||
h.notify()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *expiryHeap) Len() int { return len(h.Entries) }
|
||||
|
||||
func (h *expiryHeap) Swap(i, j int) {
|
||||
h.Entries[i], h.Entries[j] = h.Entries[j], h.Entries[i]
|
||||
h.Entries[i].HeapIndex = i
|
||||
h.Entries[j].HeapIndex = j
|
||||
|
||||
// If we're moving the 0 index, update the channel since we need
|
||||
// to re-update the timer we're waiting on for the soonest expiring
|
||||
// value.
|
||||
if i == 0 || j == 0 {
|
||||
h.notify()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *expiryHeap) Less(i, j int) bool {
|
||||
// The usage of Before here is important (despite being obvious):
|
||||
// this function uses the monotonic time that should be available
|
||||
// on the time.Time value so the heap is immune to wall clock changes.
|
||||
return h.Entries[i].Expires.Before(h.Entries[j].Expires)
|
||||
}
|
||||
|
||||
// heap.Interface, this isn't expected to be called directly.
|
||||
func (h *expiryHeap) Push(x interface{}) {
|
||||
entry := x.(*cacheEntryExpiry)
|
||||
|
||||
// Set initial heap index, if we're going to the end then Swap
|
||||
// won't be called so we need to initialize
|
||||
entry.HeapIndex = len(h.Entries)
|
||||
|
||||
// For the first entry, we need to trigger a channel send because
|
||||
// Swap won't be called; nothing to swap! We can call it right away
|
||||
// because all heap operations are within a lock.
|
||||
if len(h.Entries) == 0 {
|
||||
h.notify()
|
||||
}
|
||||
|
||||
h.Entries = append(h.Entries, entry)
|
||||
}
|
||||
|
||||
// heap.Interface, this isn't expected to be called directly.
|
||||
func (h *expiryHeap) Pop() interface{} {
|
||||
old := h.Entries
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
h.Entries = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
func (h *expiryHeap) notify() {
|
||||
select {
|
||||
case h.NotifyCh <- struct{}{}:
|
||||
// Good
|
||||
|
||||
default:
|
||||
// If the send would've blocked, we just ignore it. The reason this
|
||||
// is safe is because NotifyCh should always be a buffered channel.
|
||||
// If this blocks, it means that there is a pending message anyways
|
||||
// so the receiver will restart regardless.
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExpiryHeap_impl(t *testing.T) {
|
||||
var _ heap.Interface = new(expiryHeap)
|
||||
}
|
||||
|
||||
func TestExpiryHeap(t *testing.T) {
|
||||
require := require.New(t)
|
||||
now := time.Now()
|
||||
ch := make(chan struct{}, 10) // buffered to prevent blocking in tests
|
||||
h := &expiryHeap{NotifyCh: ch}
|
||||
|
||||
// Init, shouldn't trigger anything
|
||||
heap.Init(h)
|
||||
testNoMessage(t, ch)
|
||||
|
||||
// Push an initial value, expect one message
|
||||
entry := &cacheEntryExpiry{Key: "foo", HeapIndex: -1, Expires: now.Add(100)}
|
||||
heap.Push(h, entry)
|
||||
require.Equal(0, entry.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testNoMessage(t, ch) // exactly one asserted above
|
||||
|
||||
// Push another that goes earlier than entry
|
||||
entry2 := &cacheEntryExpiry{Key: "bar", HeapIndex: -1, Expires: now.Add(50)}
|
||||
heap.Push(h, entry2)
|
||||
require.Equal(0, entry2.HeapIndex)
|
||||
require.Equal(1, entry.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testNoMessage(t, ch) // exactly one asserted above
|
||||
|
||||
// Push another that goes at the end
|
||||
entry3 := &cacheEntryExpiry{Key: "bar", HeapIndex: -1, Expires: now.Add(1000)}
|
||||
heap.Push(h, entry3)
|
||||
require.Equal(2, entry3.HeapIndex)
|
||||
testNoMessage(t, ch) // no notify cause index 0 stayed the same
|
||||
|
||||
// Remove the first entry (not Pop, since we don't use Pop, but that works too)
|
||||
remove := h.Entries[0]
|
||||
heap.Remove(h, remove.HeapIndex)
|
||||
require.Equal(0, entry.HeapIndex)
|
||||
require.Equal(1, entry3.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testMessage(t, ch) // we have two because two swaps happen
|
||||
testNoMessage(t, ch)
|
||||
|
||||
// Let's change entry 3 to be early, and fix it
|
||||
entry3.Expires = now.Add(10)
|
||||
h.Fix(entry3)
|
||||
require.Equal(1, entry.HeapIndex)
|
||||
require.Equal(0, entry3.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testNoMessage(t, ch)
|
||||
|
||||
// Let's change entry 3 again, this is an edge case where if the 0th
|
||||
// element changed, we didn't trigger the channel. Our Fix func should.
|
||||
entry.Expires = now.Add(20)
|
||||
h.Fix(entry3)
|
||||
require.Equal(1, entry.HeapIndex) // no move
|
||||
require.Equal(0, entry3.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testNoMessage(t, ch) // one message
|
||||
}
|
||||
|
||||
func testNoMessage(t *testing.T, ch <-chan struct{}) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("should not have a message")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func testMessage(t *testing.T, ch <-chan struct{}) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
t.Fatal("should have a message")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Code generated by mockery v1.0.0
|
||||
package cache
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockRequest is an autogenerated mock type for the Request type
|
||||
type MockRequest struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// CacheInfo provides a mock function with given fields:
|
||||
func (_m *MockRequest) CacheInfo() RequestInfo {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 RequestInfo
|
||||
if rf, ok := ret.Get(0).(func() RequestInfo); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(RequestInfo)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
// Code generated by mockery v1.0.0
|
||||
package cache
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockType is an autogenerated mock type for the Type type
|
||||
type MockType struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Fetch provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockType) Fetch(_a0 FetchOptions, _a1 Request) (FetchResult, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 FetchResult
|
||||
if rf, ok := ret.Get(0).(func(FetchOptions, Request) FetchResult); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
r0 = ret.Get(0).(FetchResult)
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(FetchOptions, Request) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Request is a cacheable request.
|
||||
//
|
||||
// This interface is typically implemented by request structures in
|
||||
// the agent/structs package.
|
||||
type Request interface {
|
||||
// CacheInfo returns information used for caching this request.
|
||||
CacheInfo() RequestInfo
|
||||
}
|
||||
|
||||
// RequestInfo represents cache information for a request. The caching
|
||||
// framework uses this to control the behavior of caching and to determine
|
||||
// cacheability.
|
||||
type RequestInfo struct {
|
||||
// Key is a unique cache key for this request. This key should
|
||||
// be globally unique to identify this request, since any conflicting
|
||||
// cache keys could result in invalid data being returned from the cache.
|
||||
// The Key does not need to include ACL or DC information, since the
|
||||
// cache already partitions by these values prior to using this key.
|
||||
Key string
|
||||
|
||||
// Token is the ACL token associated with this request.
|
||||
//
|
||||
// Datacenter is the datacenter that the request is targeting.
|
||||
//
|
||||
// Both of these values are used to partition the cache. The cache framework
|
||||
// today partitions data on these values to simplify behavior: by
|
||||
// partitioning ACL tokens, the cache doesn't need to be smart about
|
||||
// filtering results. By filtering datacenter results, the cache can
|
||||
// service the multi-DC nature of Consul. This comes at the expense of
|
||||
// working set size, but in general the effect is minimal.
|
||||
Token string
|
||||
Datacenter string
|
||||
|
||||
// MinIndex is the minimum index being queried. This is used to
|
||||
// determine if we already have data satisfying the query or if we need
|
||||
// to block until new data is available. If no index is available, the
|
||||
// default value (zero) is acceptable.
|
||||
MinIndex uint64
|
||||
|
||||
// Timeout is the timeout for waiting on a blocking query. When the
|
||||
// timeout is reached, the last known value is returned (or maybe nil
|
||||
// if there was no prior value). This "last known value" behavior matches
|
||||
// normal Consul blocking queries.
|
||||
Timeout time.Duration
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// TestCache returns a Cache instance configuring for testing.
|
||||
func TestCache(t testing.T) *Cache {
|
||||
// Simple but lets us do some fine-tuning later if we want to.
|
||||
return New(nil)
|
||||
}
|
||||
|
||||
// TestCacheGetCh returns a channel that returns the result of the Get call.
|
||||
// This is useful for testing timing and concurrency with Get calls. Any
|
||||
// error will be logged, so the result value should always be asserted.
|
||||
func TestCacheGetCh(t testing.T, c *Cache, typ string, r Request) <-chan interface{} {
|
||||
resultCh := make(chan interface{})
|
||||
go func() {
|
||||
result, _, err := c.Get(typ, r)
|
||||
if err != nil {
|
||||
t.Logf("Error: %s", err)
|
||||
close(resultCh)
|
||||
return
|
||||
}
|
||||
|
||||
resultCh <- result
|
||||
}()
|
||||
|
||||
return resultCh
|
||||
}
|
||||
|
||||
// TestCacheGetChResult tests that the result from TestCacheGetCh matches
|
||||
// within a reasonable period of time (it expects it to be "immediate" but
|
||||
// waits some milliseconds).
|
||||
func TestCacheGetChResult(t testing.T, ch <-chan interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case result := <-ch:
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected)
|
||||
}
|
||||
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
t.Fatalf("Result not sent on channel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequest returns a Request that returns the given cache key and index.
|
||||
// The Reset method can be called to reset it for custom usage.
|
||||
func TestRequest(t testing.T, info RequestInfo) *MockRequest {
|
||||
req := &MockRequest{}
|
||||
req.On("CacheInfo").Return(info)
|
||||
return req
|
||||
}
|
||||
|
||||
// TestType returns a MockType that can be used to setup expectations
|
||||
// on data fetching.
|
||||
func TestType(t testing.T) *MockType {
|
||||
typ := &MockType{}
|
||||
return typ
|
||||
}
|
||||
|
||||
// A bit weird, but we add methods to the auto-generated structs here so that
|
||||
// they don't get clobbered. The helper methods are conveniences.
|
||||
|
||||
// Static sets a static value to return for a call to Fetch.
|
||||
func (m *MockType) Static(r FetchResult, err error) *mock.Call {
|
||||
return m.Mock.On("Fetch", mock.Anything, mock.Anything).Return(r, err)
|
||||
}
|
||||
|
||||
func (m *MockRequest) Reset() {
|
||||
m.Mock = mock.Mock{}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Type implements the logic to fetch certain types of data.
|
||||
type Type interface {
|
||||
// Fetch fetches a single unique item.
|
||||
//
|
||||
// The FetchOptions contain the index and timeouts for blocking queries.
|
||||
// The MinIndex value on the Request itself should NOT be used
|
||||
// as the blocking index since a request may be reused multiple times
|
||||
// as part of Refresh behavior.
|
||||
//
|
||||
// The return value is a FetchResult which contains information about
|
||||
// the fetch. If an error is given, the FetchResult is ignored. The
|
||||
// cache does not support backends that return partial values.
|
||||
//
|
||||
// On timeout, FetchResult can behave one of two ways. First, it can
|
||||
// return the last known value. This is the default behavior of blocking
|
||||
// RPC calls in Consul so this allows cache types to be implemented with
|
||||
// no extra logic. Second, FetchResult can return an unset value and index.
|
||||
// In this case, the cache will reuse the last value automatically.
|
||||
Fetch(FetchOptions, Request) (FetchResult, error)
|
||||
}
|
||||
|
||||
// FetchOptions are various settable options when a Fetch is called.
|
||||
type FetchOptions struct {
|
||||
// MinIndex is the minimum index to be used for blocking queries.
|
||||
// If blocking queries aren't supported for data being returned,
|
||||
// this value can be ignored.
|
||||
MinIndex uint64
|
||||
|
||||
// Timeout is the maximum time for the query. This must be implemented
|
||||
// in the Fetch itself.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// FetchResult is the result of a Type Fetch operation and contains the
|
||||
// data along with metadata gathered from that operation.
|
||||
type FetchResult struct {
|
||||
// Value is the result of the fetch.
|
||||
Value interface{}
|
||||
|
||||
// Index is the corresponding index value for this data.
|
||||
Index uint64
|
||||
}
|
|
@ -157,12 +157,27 @@ RETRY_ONCE:
|
|||
return out.Services, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) CatalogConnectServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
return s.catalogServiceNodes(resp, req, true)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_service_nodes"}, 1,
|
||||
return s.catalogServiceNodes(resp, req, false)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) catalogServiceNodes(resp http.ResponseWriter, req *http.Request, connect bool) (interface{}, error) {
|
||||
metricsKey := "catalog_service_nodes"
|
||||
pathPrefix := "/v1/catalog/service/"
|
||||
if connect {
|
||||
metricsKey = "catalog_connect_service_nodes"
|
||||
pathPrefix = "/v1/catalog/connect/"
|
||||
}
|
||||
|
||||
metrics.IncrCounterWithLabels([]string{"client", "api", metricsKey}, 1,
|
||||
[]metrics.Label{{Name: "node", Value: s.nodeName()}})
|
||||
|
||||
// Set default DC
|
||||
args := structs.ServiceSpecificRequest{}
|
||||
args := structs.ServiceSpecificRequest{Connect: connect}
|
||||
s.parseSource(req, &args.Source)
|
||||
args.NodeMetaFilters = s.parseMetaFilter(req)
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
|
@ -177,7 +192,7 @@ func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Req
|
|||
}
|
||||
|
||||
// Pull out the service name
|
||||
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/catalog/service/")
|
||||
args.ServiceName = strings.TrimPrefix(req.URL.Path, pathPrefix)
|
||||
if args.ServiceName == "" {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Missing service name")
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCatalogRegister_Service_InvalidAddress(t *testing.T) {
|
||||
|
@ -750,6 +751,60 @@ func TestCatalogServiceNodes_DistanceSort(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test that connect proxies can be queried via /v1/catalog/service/:service
|
||||
// directly and that their results contain the proxy fields.
|
||||
func TestCatalogServiceNodes_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/catalog/service/%s", args.Service.Service), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.CatalogServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
nodes := obj.(structs.ServiceNodes)
|
||||
assert.Len(nodes, 1)
|
||||
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
|
||||
}
|
||||
|
||||
// Test that the Connect-compatible endpoints can be queried for a
|
||||
// service via /v1/catalog/connect/:service.
|
||||
func TestCatalogConnectServiceNodes_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Address = "127.0.0.55"
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/catalog/connect/%s", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.CatalogConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
nodes := obj.(structs.ServiceNodes)
|
||||
assert.Len(nodes, 1)
|
||||
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
|
||||
assert.Equal(args.Service.Address, nodes[0].ServiceAddress)
|
||||
}
|
||||
|
||||
func TestCatalogNodeServices(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
|
@ -785,6 +840,33 @@ func TestCatalogNodeServices(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test that the services on a node contain all the Connect proxies on
|
||||
// the node as well with their fields properly populated.
|
||||
func TestCatalogNodeServices_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/catalog/node/%s", args.Node), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.CatalogNodeServices(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
ns := obj.(*structs.NodeServices)
|
||||
assert.Len(ns.Services, 1)
|
||||
v := ns.Services[args.Service.Service]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
|
||||
}
|
||||
|
||||
func TestCatalogNodeServices_WanTranslation(t *testing.T) {
|
||||
t.Parallel()
|
||||
a1 := NewTestAgent(t.Name(), `
|
||||
|
|
|
@ -14,9 +14,11 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/consul"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/ipaddr"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
|
@ -340,6 +342,12 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
|||
serverPort := b.portVal("ports.server", c.Ports.Server)
|
||||
serfPortLAN := b.portVal("ports.serf_lan", c.Ports.SerfLAN)
|
||||
serfPortWAN := b.portVal("ports.serf_wan", c.Ports.SerfWAN)
|
||||
proxyMinPort := b.portVal("ports.proxy_min_port", c.Ports.ProxyMinPort)
|
||||
proxyMaxPort := b.portVal("ports.proxy_max_port", c.Ports.ProxyMaxPort)
|
||||
if proxyMaxPort < proxyMinPort {
|
||||
return RuntimeConfig{}, fmt.Errorf(
|
||||
"proxy_min_port must be less than proxy_max_port. To disable, set both to zero.")
|
||||
}
|
||||
|
||||
// determine the default bind and advertise address
|
||||
//
|
||||
|
@ -520,6 +528,30 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
|||
consulRaftHeartbeatTimeout := b.durationVal("consul.raft.heartbeat_timeout", c.Consul.Raft.HeartbeatTimeout) * time.Duration(performanceRaftMultiplier)
|
||||
consulRaftLeaderLeaseTimeout := b.durationVal("consul.raft.leader_lease_timeout", c.Consul.Raft.LeaderLeaseTimeout) * time.Duration(performanceRaftMultiplier)
|
||||
|
||||
// Connect proxy defaults.
|
||||
connectEnabled := b.boolVal(c.Connect.Enabled)
|
||||
connectCAProvider := b.stringVal(c.Connect.CAProvider)
|
||||
connectCAConfig := c.Connect.CAConfig
|
||||
if connectCAConfig != nil {
|
||||
TranslateKeys(connectCAConfig, map[string]string{
|
||||
// Consul CA config
|
||||
"private_key": "PrivateKey",
|
||||
"root_cert": "RootCert",
|
||||
"rotation_period": "RotationPeriod",
|
||||
|
||||
// Vault CA config
|
||||
"address": "Address",
|
||||
"token": "Token",
|
||||
"root_pki_path": "RootPKIPath",
|
||||
"intermediate_pki_path": "IntermediatePKIPath",
|
||||
})
|
||||
}
|
||||
|
||||
proxyDefaultExecMode := b.stringVal(c.Connect.ProxyDefaults.ExecMode)
|
||||
proxyDefaultDaemonCommand := c.Connect.ProxyDefaults.DaemonCommand
|
||||
proxyDefaultScriptCommand := c.Connect.ProxyDefaults.ScriptCommand
|
||||
proxyDefaultConfig := c.Connect.ProxyDefaults.Config
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
// build runtime config
|
||||
//
|
||||
|
@ -592,6 +624,7 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
|||
DNSRecursors: dnsRecursors,
|
||||
DNSServiceTTL: dnsServiceTTL,
|
||||
DNSUDPAnswerLimit: b.intVal(c.DNS.UDPAnswerLimit),
|
||||
DNSNodeMetaTXT: b.boolValWithDefault(c.DNS.NodeMetaTXT, true),
|
||||
|
||||
// HTTP
|
||||
HTTPPort: httpPort,
|
||||
|
@ -602,120 +635,133 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
|||
HTTPResponseHeaders: c.HTTPConfig.ResponseHeaders,
|
||||
|
||||
// Telemetry
|
||||
TelemetryCirconusAPIApp: b.stringVal(c.Telemetry.CirconusAPIApp),
|
||||
TelemetryCirconusAPIToken: b.stringVal(c.Telemetry.CirconusAPIToken),
|
||||
TelemetryCirconusAPIURL: b.stringVal(c.Telemetry.CirconusAPIURL),
|
||||
TelemetryCirconusBrokerID: b.stringVal(c.Telemetry.CirconusBrokerID),
|
||||
TelemetryCirconusBrokerSelectTag: b.stringVal(c.Telemetry.CirconusBrokerSelectTag),
|
||||
TelemetryCirconusCheckDisplayName: b.stringVal(c.Telemetry.CirconusCheckDisplayName),
|
||||
TelemetryCirconusCheckForceMetricActivation: b.stringVal(c.Telemetry.CirconusCheckForceMetricActivation),
|
||||
TelemetryCirconusCheckID: b.stringVal(c.Telemetry.CirconusCheckID),
|
||||
TelemetryCirconusCheckInstanceID: b.stringVal(c.Telemetry.CirconusCheckInstanceID),
|
||||
TelemetryCirconusCheckSearchTag: b.stringVal(c.Telemetry.CirconusCheckSearchTag),
|
||||
TelemetryCirconusCheckTags: b.stringVal(c.Telemetry.CirconusCheckTags),
|
||||
TelemetryCirconusSubmissionInterval: b.stringVal(c.Telemetry.CirconusSubmissionInterval),
|
||||
TelemetryCirconusSubmissionURL: b.stringVal(c.Telemetry.CirconusSubmissionURL),
|
||||
TelemetryDisableHostname: b.boolVal(c.Telemetry.DisableHostname),
|
||||
TelemetryDogstatsdAddr: b.stringVal(c.Telemetry.DogstatsdAddr),
|
||||
TelemetryDogstatsdTags: c.Telemetry.DogstatsdTags,
|
||||
TelemetryPrometheusRetentionTime: b.durationVal("prometheus_retention_time", c.Telemetry.PrometheusRetentionTime),
|
||||
TelemetryFilterDefault: b.boolVal(c.Telemetry.FilterDefault),
|
||||
TelemetryAllowedPrefixes: telemetryAllowedPrefixes,
|
||||
TelemetryBlockedPrefixes: telemetryBlockedPrefixes,
|
||||
TelemetryMetricsPrefix: b.stringVal(c.Telemetry.MetricsPrefix),
|
||||
TelemetryStatsdAddr: b.stringVal(c.Telemetry.StatsdAddr),
|
||||
TelemetryStatsiteAddr: b.stringVal(c.Telemetry.StatsiteAddr),
|
||||
Telemetry: lib.TelemetryConfig{
|
||||
CirconusAPIApp: b.stringVal(c.Telemetry.CirconusAPIApp),
|
||||
CirconusAPIToken: b.stringVal(c.Telemetry.CirconusAPIToken),
|
||||
CirconusAPIURL: b.stringVal(c.Telemetry.CirconusAPIURL),
|
||||
CirconusBrokerID: b.stringVal(c.Telemetry.CirconusBrokerID),
|
||||
CirconusBrokerSelectTag: b.stringVal(c.Telemetry.CirconusBrokerSelectTag),
|
||||
CirconusCheckDisplayName: b.stringVal(c.Telemetry.CirconusCheckDisplayName),
|
||||
CirconusCheckForceMetricActivation: b.stringVal(c.Telemetry.CirconusCheckForceMetricActivation),
|
||||
CirconusCheckID: b.stringVal(c.Telemetry.CirconusCheckID),
|
||||
CirconusCheckInstanceID: b.stringVal(c.Telemetry.CirconusCheckInstanceID),
|
||||
CirconusCheckSearchTag: b.stringVal(c.Telemetry.CirconusCheckSearchTag),
|
||||
CirconusCheckTags: b.stringVal(c.Telemetry.CirconusCheckTags),
|
||||
CirconusSubmissionInterval: b.stringVal(c.Telemetry.CirconusSubmissionInterval),
|
||||
CirconusSubmissionURL: b.stringVal(c.Telemetry.CirconusSubmissionURL),
|
||||
DisableHostname: b.boolVal(c.Telemetry.DisableHostname),
|
||||
DogstatsdAddr: b.stringVal(c.Telemetry.DogstatsdAddr),
|
||||
DogstatsdTags: c.Telemetry.DogstatsdTags,
|
||||
PrometheusRetentionTime: b.durationVal("prometheus_retention_time", c.Telemetry.PrometheusRetentionTime),
|
||||
FilterDefault: b.boolVal(c.Telemetry.FilterDefault),
|
||||
AllowedPrefixes: telemetryAllowedPrefixes,
|
||||
BlockedPrefixes: telemetryBlockedPrefixes,
|
||||
MetricsPrefix: b.stringVal(c.Telemetry.MetricsPrefix),
|
||||
StatsdAddr: b.stringVal(c.Telemetry.StatsdAddr),
|
||||
StatsiteAddr: b.stringVal(c.Telemetry.StatsiteAddr),
|
||||
},
|
||||
|
||||
// Agent
|
||||
AdvertiseAddrLAN: advertiseAddrLAN,
|
||||
AdvertiseAddrWAN: advertiseAddrWAN,
|
||||
BindAddr: bindAddr,
|
||||
Bootstrap: b.boolVal(c.Bootstrap),
|
||||
BootstrapExpect: b.intVal(c.BootstrapExpect),
|
||||
CAFile: b.stringVal(c.CAFile),
|
||||
CAPath: b.stringVal(c.CAPath),
|
||||
CertFile: b.stringVal(c.CertFile),
|
||||
CheckUpdateInterval: b.durationVal("check_update_interval", c.CheckUpdateInterval),
|
||||
Checks: checks,
|
||||
ClientAddrs: clientAddrs,
|
||||
DataDir: b.stringVal(c.DataDir),
|
||||
Datacenter: strings.ToLower(b.stringVal(c.Datacenter)),
|
||||
DevMode: b.boolVal(b.Flags.DevMode),
|
||||
DisableAnonymousSignature: b.boolVal(c.DisableAnonymousSignature),
|
||||
DisableCoordinates: b.boolVal(c.DisableCoordinates),
|
||||
DisableHostNodeID: b.boolVal(c.DisableHostNodeID),
|
||||
DisableKeyringFile: b.boolVal(c.DisableKeyringFile),
|
||||
DisableRemoteExec: b.boolVal(c.DisableRemoteExec),
|
||||
DisableUpdateCheck: b.boolVal(c.DisableUpdateCheck),
|
||||
DiscardCheckOutput: b.boolVal(c.DiscardCheckOutput),
|
||||
DiscoveryMaxStale: b.durationVal("discovery_max_stale", c.DiscoveryMaxStale),
|
||||
EnableAgentTLSForChecks: b.boolVal(c.EnableAgentTLSForChecks),
|
||||
EnableDebug: b.boolVal(c.EnableDebug),
|
||||
EnableScriptChecks: b.boolVal(c.EnableScriptChecks),
|
||||
EnableSyslog: b.boolVal(c.EnableSyslog),
|
||||
EnableUI: b.boolVal(c.UI),
|
||||
EncryptKey: b.stringVal(c.EncryptKey),
|
||||
EncryptVerifyIncoming: b.boolVal(c.EncryptVerifyIncoming),
|
||||
EncryptVerifyOutgoing: b.boolVal(c.EncryptVerifyOutgoing),
|
||||
KeyFile: b.stringVal(c.KeyFile),
|
||||
LeaveDrainTime: b.durationVal("performance.leave_drain_time", c.Performance.LeaveDrainTime),
|
||||
LeaveOnTerm: leaveOnTerm,
|
||||
LogLevel: b.stringVal(c.LogLevel),
|
||||
NodeID: types.NodeID(b.stringVal(c.NodeID)),
|
||||
NodeMeta: c.NodeMeta,
|
||||
NodeName: b.nodeName(c.NodeName),
|
||||
NonVotingServer: b.boolVal(c.NonVotingServer),
|
||||
PidFile: b.stringVal(c.PidFile),
|
||||
RPCAdvertiseAddr: rpcAdvertiseAddr,
|
||||
RPCBindAddr: rpcBindAddr,
|
||||
RPCHoldTimeout: b.durationVal("performance.rpc_hold_timeout", c.Performance.RPCHoldTimeout),
|
||||
RPCMaxBurst: b.intVal(c.Limits.RPCMaxBurst),
|
||||
RPCProtocol: b.intVal(c.RPCProtocol),
|
||||
RPCRateLimit: rate.Limit(b.float64Val(c.Limits.RPCRate)),
|
||||
RaftProtocol: b.intVal(c.RaftProtocol),
|
||||
RaftSnapshotThreshold: b.intVal(c.RaftSnapshotThreshold),
|
||||
RaftSnapshotInterval: b.durationVal("raft_snapshot_interval", c.RaftSnapshotInterval),
|
||||
ReconnectTimeoutLAN: b.durationVal("reconnect_timeout", c.ReconnectTimeoutLAN),
|
||||
ReconnectTimeoutWAN: b.durationVal("reconnect_timeout_wan", c.ReconnectTimeoutWAN),
|
||||
RejoinAfterLeave: b.boolVal(c.RejoinAfterLeave),
|
||||
RetryJoinIntervalLAN: b.durationVal("retry_interval", c.RetryJoinIntervalLAN),
|
||||
RetryJoinIntervalWAN: b.durationVal("retry_interval_wan", c.RetryJoinIntervalWAN),
|
||||
RetryJoinLAN: b.expandAllOptionalAddrs("retry_join", c.RetryJoinLAN),
|
||||
RetryJoinMaxAttemptsLAN: b.intVal(c.RetryJoinMaxAttemptsLAN),
|
||||
RetryJoinMaxAttemptsWAN: b.intVal(c.RetryJoinMaxAttemptsWAN),
|
||||
RetryJoinWAN: b.expandAllOptionalAddrs("retry_join_wan", c.RetryJoinWAN),
|
||||
SegmentName: b.stringVal(c.SegmentName),
|
||||
Segments: segments,
|
||||
SerfAdvertiseAddrLAN: serfAdvertiseAddrLAN,
|
||||
SerfAdvertiseAddrWAN: serfAdvertiseAddrWAN,
|
||||
SerfBindAddrLAN: serfBindAddrLAN,
|
||||
SerfBindAddrWAN: serfBindAddrWAN,
|
||||
SerfPortLAN: serfPortLAN,
|
||||
SerfPortWAN: serfPortWAN,
|
||||
ServerMode: b.boolVal(c.ServerMode),
|
||||
ServerName: b.stringVal(c.ServerName),
|
||||
ServerPort: serverPort,
|
||||
Services: services,
|
||||
SessionTTLMin: b.durationVal("session_ttl_min", c.SessionTTLMin),
|
||||
SkipLeaveOnInt: skipLeaveOnInt,
|
||||
StartJoinAddrsLAN: b.expandAllOptionalAddrs("start_join", c.StartJoinAddrsLAN),
|
||||
StartJoinAddrsWAN: b.expandAllOptionalAddrs("start_join_wan", c.StartJoinAddrsWAN),
|
||||
SyslogFacility: b.stringVal(c.SyslogFacility),
|
||||
TLSCipherSuites: b.tlsCipherSuites("tls_cipher_suites", c.TLSCipherSuites),
|
||||
TLSMinVersion: b.stringVal(c.TLSMinVersion),
|
||||
TLSPreferServerCipherSuites: b.boolVal(c.TLSPreferServerCipherSuites),
|
||||
TaggedAddresses: c.TaggedAddresses,
|
||||
TranslateWANAddrs: b.boolVal(c.TranslateWANAddrs),
|
||||
UIDir: b.stringVal(c.UIDir),
|
||||
UnixSocketGroup: b.stringVal(c.UnixSocket.Group),
|
||||
UnixSocketMode: b.stringVal(c.UnixSocket.Mode),
|
||||
UnixSocketUser: b.stringVal(c.UnixSocket.User),
|
||||
VerifyIncoming: b.boolVal(c.VerifyIncoming),
|
||||
VerifyIncomingHTTPS: b.boolVal(c.VerifyIncomingHTTPS),
|
||||
VerifyIncomingRPC: b.boolVal(c.VerifyIncomingRPC),
|
||||
VerifyOutgoing: b.boolVal(c.VerifyOutgoing),
|
||||
VerifyServerHostname: b.boolVal(c.VerifyServerHostname),
|
||||
Watches: c.Watches,
|
||||
AdvertiseAddrLAN: advertiseAddrLAN,
|
||||
AdvertiseAddrWAN: advertiseAddrWAN,
|
||||
BindAddr: bindAddr,
|
||||
Bootstrap: b.boolVal(c.Bootstrap),
|
||||
BootstrapExpect: b.intVal(c.BootstrapExpect),
|
||||
CAFile: b.stringVal(c.CAFile),
|
||||
CAPath: b.stringVal(c.CAPath),
|
||||
CertFile: b.stringVal(c.CertFile),
|
||||
CheckUpdateInterval: b.durationVal("check_update_interval", c.CheckUpdateInterval),
|
||||
Checks: checks,
|
||||
ClientAddrs: clientAddrs,
|
||||
ConnectEnabled: connectEnabled,
|
||||
ConnectCAProvider: connectCAProvider,
|
||||
ConnectCAConfig: connectCAConfig,
|
||||
ConnectProxyAllowManagedRoot: b.boolVal(c.Connect.Proxy.AllowManagedRoot),
|
||||
ConnectProxyAllowManagedAPIRegistration: b.boolVal(c.Connect.Proxy.AllowManagedAPIRegistration),
|
||||
ConnectProxyBindMinPort: proxyMinPort,
|
||||
ConnectProxyBindMaxPort: proxyMaxPort,
|
||||
ConnectProxyDefaultExecMode: proxyDefaultExecMode,
|
||||
ConnectProxyDefaultDaemonCommand: proxyDefaultDaemonCommand,
|
||||
ConnectProxyDefaultScriptCommand: proxyDefaultScriptCommand,
|
||||
ConnectProxyDefaultConfig: proxyDefaultConfig,
|
||||
DataDir: b.stringVal(c.DataDir),
|
||||
Datacenter: strings.ToLower(b.stringVal(c.Datacenter)),
|
||||
DevMode: b.boolVal(b.Flags.DevMode),
|
||||
DisableAnonymousSignature: b.boolVal(c.DisableAnonymousSignature),
|
||||
DisableCoordinates: b.boolVal(c.DisableCoordinates),
|
||||
DisableHostNodeID: b.boolVal(c.DisableHostNodeID),
|
||||
DisableKeyringFile: b.boolVal(c.DisableKeyringFile),
|
||||
DisableRemoteExec: b.boolVal(c.DisableRemoteExec),
|
||||
DisableUpdateCheck: b.boolVal(c.DisableUpdateCheck),
|
||||
DiscardCheckOutput: b.boolVal(c.DiscardCheckOutput),
|
||||
DiscoveryMaxStale: b.durationVal("discovery_max_stale", c.DiscoveryMaxStale),
|
||||
EnableAgentTLSForChecks: b.boolVal(c.EnableAgentTLSForChecks),
|
||||
EnableDebug: b.boolVal(c.EnableDebug),
|
||||
EnableScriptChecks: b.boolVal(c.EnableScriptChecks),
|
||||
EnableSyslog: b.boolVal(c.EnableSyslog),
|
||||
EnableUI: b.boolVal(c.UI),
|
||||
EncryptKey: b.stringVal(c.EncryptKey),
|
||||
EncryptVerifyIncoming: b.boolVal(c.EncryptVerifyIncoming),
|
||||
EncryptVerifyOutgoing: b.boolVal(c.EncryptVerifyOutgoing),
|
||||
KeyFile: b.stringVal(c.KeyFile),
|
||||
LeaveDrainTime: b.durationVal("performance.leave_drain_time", c.Performance.LeaveDrainTime),
|
||||
LeaveOnTerm: leaveOnTerm,
|
||||
LogLevel: b.stringVal(c.LogLevel),
|
||||
NodeID: types.NodeID(b.stringVal(c.NodeID)),
|
||||
NodeMeta: c.NodeMeta,
|
||||
NodeName: b.nodeName(c.NodeName),
|
||||
NonVotingServer: b.boolVal(c.NonVotingServer),
|
||||
PidFile: b.stringVal(c.PidFile),
|
||||
RPCAdvertiseAddr: rpcAdvertiseAddr,
|
||||
RPCBindAddr: rpcBindAddr,
|
||||
RPCHoldTimeout: b.durationVal("performance.rpc_hold_timeout", c.Performance.RPCHoldTimeout),
|
||||
RPCMaxBurst: b.intVal(c.Limits.RPCMaxBurst),
|
||||
RPCProtocol: b.intVal(c.RPCProtocol),
|
||||
RPCRateLimit: rate.Limit(b.float64Val(c.Limits.RPCRate)),
|
||||
RaftProtocol: b.intVal(c.RaftProtocol),
|
||||
RaftSnapshotThreshold: b.intVal(c.RaftSnapshotThreshold),
|
||||
RaftSnapshotInterval: b.durationVal("raft_snapshot_interval", c.RaftSnapshotInterval),
|
||||
ReconnectTimeoutLAN: b.durationVal("reconnect_timeout", c.ReconnectTimeoutLAN),
|
||||
ReconnectTimeoutWAN: b.durationVal("reconnect_timeout_wan", c.ReconnectTimeoutWAN),
|
||||
RejoinAfterLeave: b.boolVal(c.RejoinAfterLeave),
|
||||
RetryJoinIntervalLAN: b.durationVal("retry_interval", c.RetryJoinIntervalLAN),
|
||||
RetryJoinIntervalWAN: b.durationVal("retry_interval_wan", c.RetryJoinIntervalWAN),
|
||||
RetryJoinLAN: b.expandAllOptionalAddrs("retry_join", c.RetryJoinLAN),
|
||||
RetryJoinMaxAttemptsLAN: b.intVal(c.RetryJoinMaxAttemptsLAN),
|
||||
RetryJoinMaxAttemptsWAN: b.intVal(c.RetryJoinMaxAttemptsWAN),
|
||||
RetryJoinWAN: b.expandAllOptionalAddrs("retry_join_wan", c.RetryJoinWAN),
|
||||
SegmentName: b.stringVal(c.SegmentName),
|
||||
Segments: segments,
|
||||
SerfAdvertiseAddrLAN: serfAdvertiseAddrLAN,
|
||||
SerfAdvertiseAddrWAN: serfAdvertiseAddrWAN,
|
||||
SerfBindAddrLAN: serfBindAddrLAN,
|
||||
SerfBindAddrWAN: serfBindAddrWAN,
|
||||
SerfPortLAN: serfPortLAN,
|
||||
SerfPortWAN: serfPortWAN,
|
||||
ServerMode: b.boolVal(c.ServerMode),
|
||||
ServerName: b.stringVal(c.ServerName),
|
||||
ServerPort: serverPort,
|
||||
Services: services,
|
||||
SessionTTLMin: b.durationVal("session_ttl_min", c.SessionTTLMin),
|
||||
SkipLeaveOnInt: skipLeaveOnInt,
|
||||
StartJoinAddrsLAN: b.expandAllOptionalAddrs("start_join", c.StartJoinAddrsLAN),
|
||||
StartJoinAddrsWAN: b.expandAllOptionalAddrs("start_join_wan", c.StartJoinAddrsWAN),
|
||||
SyslogFacility: b.stringVal(c.SyslogFacility),
|
||||
TLSCipherSuites: b.tlsCipherSuites("tls_cipher_suites", c.TLSCipherSuites),
|
||||
TLSMinVersion: b.stringVal(c.TLSMinVersion),
|
||||
TLSPreferServerCipherSuites: b.boolVal(c.TLSPreferServerCipherSuites),
|
||||
TaggedAddresses: c.TaggedAddresses,
|
||||
TranslateWANAddrs: b.boolVal(c.TranslateWANAddrs),
|
||||
UIDir: b.stringVal(c.UIDir),
|
||||
UnixSocketGroup: b.stringVal(c.UnixSocket.Group),
|
||||
UnixSocketMode: b.stringVal(c.UnixSocket.Mode),
|
||||
UnixSocketUser: b.stringVal(c.UnixSocket.User),
|
||||
VerifyIncoming: b.boolVal(c.VerifyIncoming),
|
||||
VerifyIncomingHTTPS: b.boolVal(c.VerifyIncomingHTTPS),
|
||||
VerifyIncomingRPC: b.boolVal(c.VerifyIncomingRPC),
|
||||
VerifyOutgoing: b.boolVal(c.VerifyOutgoing),
|
||||
VerifyServerHostname: b.boolVal(c.VerifyServerHostname),
|
||||
Watches: c.Watches,
|
||||
}
|
||||
|
||||
if rt.BootstrapExpect == 1 {
|
||||
|
@ -881,6 +927,34 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
|
|||
return b.err
|
||||
}
|
||||
|
||||
// Check for errors in the service definitions
|
||||
for _, s := range rt.Services {
|
||||
if err := s.Validate(); err != nil {
|
||||
return fmt.Errorf("service %q: %s", s.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the given Connect CA provider config
|
||||
validCAProviders := map[string]bool{
|
||||
"": true,
|
||||
structs.ConsulCAProvider: true,
|
||||
structs.VaultCAProvider: true,
|
||||
}
|
||||
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
|
||||
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)
|
||||
} else {
|
||||
switch rt.ConnectCAProvider {
|
||||
case structs.ConsulCAProvider:
|
||||
if _, err := ca.ParseConsulCAConfig(rt.ConnectCAConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
case structs.VaultCAProvider:
|
||||
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
// warnings
|
||||
//
|
||||
|
@ -1007,16 +1081,41 @@ func (b *Builder) serviceVal(v *ServiceDefinition) *structs.ServiceDefinition {
|
|||
Token: b.stringVal(v.Token),
|
||||
EnableTagOverride: b.boolVal(v.EnableTagOverride),
|
||||
Checks: checks,
|
||||
Connect: b.serviceConnectVal(v.Connect),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Builder) boolVal(v *bool) bool {
|
||||
func (b *Builder) serviceConnectVal(v *ServiceConnect) *structs.ServiceConnect {
|
||||
if v == nil {
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
var proxy *structs.ServiceDefinitionConnectProxy
|
||||
if v.Proxy != nil {
|
||||
proxy = &structs.ServiceDefinitionConnectProxy{
|
||||
ExecMode: b.stringVal(v.Proxy.ExecMode),
|
||||
Command: v.Proxy.Command,
|
||||
Config: v.Proxy.Config,
|
||||
}
|
||||
}
|
||||
|
||||
return &structs.ServiceConnect{
|
||||
Proxy: proxy,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Builder) boolValWithDefault(v *bool, default_val bool) bool {
|
||||
if v == nil {
|
||||
return default_val
|
||||
}
|
||||
|
||||
return *v
|
||||
}
|
||||
|
||||
func (b *Builder) boolVal(v *bool) bool {
|
||||
return b.boolValWithDefault(v, false)
|
||||
}
|
||||
|
||||
func (b *Builder) durationVal(name string, v *string) (d time.Duration) {
|
||||
if v == nil {
|
||||
return 0
|
||||
|
|
|
@ -84,6 +84,7 @@ func Parse(data string, format string) (c Config, err error) {
|
|||
"services",
|
||||
"services.checks",
|
||||
"watches",
|
||||
"service.connect.proxy.config.upstreams",
|
||||
})
|
||||
|
||||
// There is a difference of representation of some fields depending on
|
||||
|
@ -159,6 +160,7 @@ type Config struct {
|
|||
CheckUpdateInterval *string `json:"check_update_interval,omitempty" hcl:"check_update_interval" mapstructure:"check_update_interval"`
|
||||
Checks []CheckDefinition `json:"checks,omitempty" hcl:"checks" mapstructure:"checks"`
|
||||
ClientAddr *string `json:"client_addr,omitempty" hcl:"client_addr" mapstructure:"client_addr"`
|
||||
Connect Connect `json:"connect,omitempty" hcl:"connect" mapstructure:"connect"`
|
||||
DNS DNS `json:"dns_config,omitempty" hcl:"dns_config" mapstructure:"dns_config"`
|
||||
DNSDomain *string `json:"domain,omitempty" hcl:"domain" mapstructure:"domain"`
|
||||
DNSRecursors []string `json:"recursors,omitempty" hcl:"recursors" mapstructure:"recursors"`
|
||||
|
@ -324,6 +326,7 @@ type ServiceDefinition struct {
|
|||
Checks []CheckDefinition `json:"checks,omitempty" hcl:"checks" mapstructure:"checks"`
|
||||
Token *string `json:"token,omitempty" hcl:"token" mapstructure:"token"`
|
||||
EnableTagOverride *bool `json:"enable_tag_override,omitempty" hcl:"enable_tag_override" mapstructure:"enable_tag_override"`
|
||||
Connect *ServiceConnect `json:"connect,omitempty" hcl:"connect" mapstructure:"connect"`
|
||||
}
|
||||
|
||||
type CheckDefinition struct {
|
||||
|
@ -349,6 +352,58 @@ type CheckDefinition struct {
|
|||
DeregisterCriticalServiceAfter *string `json:"deregister_critical_service_after,omitempty" hcl:"deregister_critical_service_after" mapstructure:"deregister_critical_service_after"`
|
||||
}
|
||||
|
||||
// ServiceConnect is the connect block within a service registration
|
||||
type ServiceConnect struct {
|
||||
// TODO(banks) add way to specify that the app is connect-native
|
||||
// Proxy configures a connect proxy instance for the service
|
||||
Proxy *ServiceConnectProxy `json:"proxy,omitempty" hcl:"proxy" mapstructure:"proxy"`
|
||||
}
|
||||
|
||||
type ServiceConnectProxy struct {
|
||||
Command []string `json:"command,omitempty" hcl:"command" mapstructure:"command"`
|
||||
ExecMode *string `json:"exec_mode,omitempty" hcl:"exec_mode" mapstructure:"exec_mode"`
|
||||
Config map[string]interface{} `json:"config,omitempty" hcl:"config" mapstructure:"config"`
|
||||
}
|
||||
|
||||
// Connect is the agent-global connect configuration.
|
||||
type Connect struct {
|
||||
// Enabled opts the agent into connect. It should be set on all clients and
|
||||
// servers in a cluster for correct connect operation.
|
||||
Enabled *bool `json:"enabled,omitempty" hcl:"enabled" mapstructure:"enabled"`
|
||||
Proxy ConnectProxy `json:"proxy,omitempty" hcl:"proxy" mapstructure:"proxy"`
|
||||
ProxyDefaults ConnectProxyDefaults `json:"proxy_defaults,omitempty" hcl:"proxy_defaults" mapstructure:"proxy_defaults"`
|
||||
CAProvider *string `json:"ca_provider,omitempty" hcl:"ca_provider" mapstructure:"ca_provider"`
|
||||
CAConfig map[string]interface{} `json:"ca_config,omitempty" hcl:"ca_config" mapstructure:"ca_config"`
|
||||
}
|
||||
|
||||
// ConnectProxy is the agent-global connect proxy configuration.
|
||||
type ConnectProxy struct {
|
||||
// Consul will not execute managed proxies if its EUID is 0 (root).
|
||||
// If this is true, then Consul will execute proxies if Consul is
|
||||
// running as root. This is not recommended.
|
||||
AllowManagedRoot *bool `json:"allow_managed_root" hcl:"allow_managed_root" mapstructure:"allow_managed_root"`
|
||||
|
||||
// AllowManagedAPIRegistration enables managed proxy registration
|
||||
// via the agent HTTP API. If this is false, only file configurations
|
||||
// can be used.
|
||||
AllowManagedAPIRegistration *bool `json:"allow_managed_api_registration" hcl:"allow_managed_api_registration" mapstructure:"allow_managed_api_registration"`
|
||||
}
|
||||
|
||||
// ConnectProxyDefaults is the agent-global defaults for managed Connect proxies.
|
||||
type ConnectProxyDefaults struct {
|
||||
// ExecMode is used where a registration doesn't include an exec_mode.
|
||||
// Defaults to daemon.
|
||||
ExecMode *string `json:"exec_mode,omitempty" hcl:"exec_mode" mapstructure:"exec_mode"`
|
||||
// DaemonCommand is used to start proxy in exec_mode = daemon if not specified
|
||||
// at registration time.
|
||||
DaemonCommand []string `json:"daemon_command,omitempty" hcl:"daemon_command" mapstructure:"daemon_command"`
|
||||
// ScriptCommand is used to start proxy in exec_mode = script if not specified
|
||||
// at registration time.
|
||||
ScriptCommand []string `json:"script_command,omitempty" hcl:"script_command" mapstructure:"script_command"`
|
||||
// Config is merged into an Config specified at registration time.
|
||||
Config map[string]interface{} `json:"config,omitempty" hcl:"config" mapstructure:"config"`
|
||||
}
|
||||
|
||||
type DNS struct {
|
||||
AllowStale *bool `json:"allow_stale,omitempty" hcl:"allow_stale" mapstructure:"allow_stale"`
|
||||
ARecordLimit *int `json:"a_record_limit,omitempty" hcl:"a_record_limit" mapstructure:"a_record_limit"`
|
||||
|
@ -360,6 +415,7 @@ type DNS struct {
|
|||
RecursorTimeout *string `json:"recursor_timeout,omitempty" hcl:"recursor_timeout" mapstructure:"recursor_timeout"`
|
||||
ServiceTTL map[string]string `json:"service_ttl,omitempty" hcl:"service_ttl" mapstructure:"service_ttl"`
|
||||
UDPAnswerLimit *int `json:"udp_answer_limit,omitempty" hcl:"udp_answer_limit" mapstructure:"udp_answer_limit"`
|
||||
NodeMetaTXT *bool `json:"enable_additional_node_meta_txt,omitempty" hcl:"enable_additional_node_meta_txt" mapstructure:"enable_additional_node_meta_txt"`
|
||||
}
|
||||
|
||||
type HTTPConfig struct {
|
||||
|
@ -399,12 +455,14 @@ type Telemetry struct {
|
|||
}
|
||||
|
||||
type Ports struct {
|
||||
DNS *int `json:"dns,omitempty" hcl:"dns" mapstructure:"dns"`
|
||||
HTTP *int `json:"http,omitempty" hcl:"http" mapstructure:"http"`
|
||||
HTTPS *int `json:"https,omitempty" hcl:"https" mapstructure:"https"`
|
||||
SerfLAN *int `json:"serf_lan,omitempty" hcl:"serf_lan" mapstructure:"serf_lan"`
|
||||
SerfWAN *int `json:"serf_wan,omitempty" hcl:"serf_wan" mapstructure:"serf_wan"`
|
||||
Server *int `json:"server,omitempty" hcl:"server" mapstructure:"server"`
|
||||
DNS *int `json:"dns,omitempty" hcl:"dns" mapstructure:"dns"`
|
||||
HTTP *int `json:"http,omitempty" hcl:"http" mapstructure:"http"`
|
||||
HTTPS *int `json:"https,omitempty" hcl:"https" mapstructure:"https"`
|
||||
SerfLAN *int `json:"serf_lan,omitempty" hcl:"serf_lan" mapstructure:"serf_lan"`
|
||||
SerfWAN *int `json:"serf_wan,omitempty" hcl:"serf_wan" mapstructure:"serf_wan"`
|
||||
Server *int `json:"server,omitempty" hcl:"server" mapstructure:"server"`
|
||||
ProxyMinPort *int `json:"proxy_min_port,omitempty" hcl:"proxy_min_port" mapstructure:"proxy_min_port"`
|
||||
ProxyMaxPort *int `json:"proxy_max_port,omitempty" hcl:"proxy_max_port" mapstructure:"proxy_max_port"`
|
||||
}
|
||||
|
||||
type UnixSocket struct {
|
||||
|
|
|
@ -85,6 +85,8 @@ func DefaultSource() Source {
|
|||
serf_lan = ` + strconv.Itoa(consul.DefaultLANSerfPort) + `
|
||||
serf_wan = ` + strconv.Itoa(consul.DefaultWANSerfPort) + `
|
||||
server = ` + strconv.Itoa(consul.DefaultRPCPort) + `
|
||||
proxy_min_port = 20000
|
||||
proxy_max_port = 20255
|
||||
}
|
||||
telemetry = {
|
||||
metrics_prefix = "consul"
|
||||
|
@ -108,6 +110,10 @@ func DevSource() Source {
|
|||
ui = true
|
||||
log_level = "DEBUG"
|
||||
server = true
|
||||
|
||||
connect = {
|
||||
enabled = true
|
||||
}
|
||||
performance = {
|
||||
raft_multiplier = 1
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"golang.org/x/time/rate"
|
||||
|
@ -281,6 +282,11 @@ type RuntimeConfig struct {
|
|||
// hcl: dns_config { udp_answer_limit = int }
|
||||
DNSUDPAnswerLimit int
|
||||
|
||||
// DNSNodeMetaTXT controls whether DNS queries will synthesize
|
||||
// TXT records for the node metadata and add them when not specifically
|
||||
// request (query type = TXT). If unset this will default to true
|
||||
DNSNodeMetaTXT bool
|
||||
|
||||
// DNSRecursors can be set to allow the DNS servers to recursively
|
||||
// resolve non-consul domains.
|
||||
//
|
||||
|
@ -299,177 +305,8 @@ type RuntimeConfig struct {
|
|||
// hcl: http_config { response_headers = map[string]string }
|
||||
HTTPResponseHeaders map[string]string
|
||||
|
||||
// TelemetryCirconus*: see https://github.com/circonus-labs/circonus-gometrics
|
||||
// for more details on the various configuration options.
|
||||
// Valid configuration combinations:
|
||||
// - CirconusAPIToken
|
||||
// metric management enabled (search for existing check or create a new one)
|
||||
// - CirconusSubmissionUrl
|
||||
// metric management disabled (use check with specified submission_url,
|
||||
// broker must be using a public SSL certificate)
|
||||
// - CirconusAPIToken + CirconusCheckSubmissionURL
|
||||
// metric management enabled (use check with specified submission_url)
|
||||
// - CirconusAPIToken + CirconusCheckID
|
||||
// metric management enabled (use check with specified id)
|
||||
|
||||
// TelemetryCirconusAPIApp is an app name associated with API token.
|
||||
// Default: "consul"
|
||||
//
|
||||
// hcl: telemetry { circonus_api_app = string }
|
||||
TelemetryCirconusAPIApp string
|
||||
|
||||
// TelemetryCirconusAPIToken is a valid API Token used to create/manage check. If provided,
|
||||
// metric management is enabled.
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_api_token = string }
|
||||
TelemetryCirconusAPIToken string
|
||||
|
||||
// TelemetryCirconusAPIURL is the base URL to use for contacting the Circonus API.
|
||||
// Default: "https://api.circonus.com/v2"
|
||||
//
|
||||
// hcl: telemetry { circonus_api_url = string }
|
||||
TelemetryCirconusAPIURL string
|
||||
|
||||
// TelemetryCirconusBrokerID is an explicit broker to use when creating a new check. The numeric portion
|
||||
// of broker._cid. If metric management is enabled and neither a Submission URL nor Check ID
|
||||
// is provided, an attempt will be made to search for an existing check using Instance ID and
|
||||
// Search Tag. If one is not found, a new HTTPTRAP check will be created.
|
||||
// Default: use Select Tag if provided, otherwise, a random Enterprise Broker associated
|
||||
// with the specified API token or the default Circonus Broker.
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_broker_id = string }
|
||||
TelemetryCirconusBrokerID string
|
||||
|
||||
// TelemetryCirconusBrokerSelectTag is a special tag which will be used to select a broker when
|
||||
// a Broker ID is not provided. The best use of this is to as a hint for which broker
|
||||
// should be used based on *where* this particular instance is running.
|
||||
// (e.g. a specific geo location or datacenter, dc:sfo)
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_broker_select_tag = string }
|
||||
TelemetryCirconusBrokerSelectTag string
|
||||
|
||||
// TelemetryCirconusCheckDisplayName is the name for the check which will be displayed in the Circonus UI.
|
||||
// Default: value of CirconusCheckInstanceID
|
||||
//
|
||||
// hcl: telemetry { circonus_check_display_name = string }
|
||||
TelemetryCirconusCheckDisplayName string
|
||||
|
||||
// TelemetryCirconusCheckForceMetricActivation will force enabling metrics, as they are encountered,
|
||||
// if the metric already exists and is NOT active. If check management is enabled, the default
|
||||
// behavior is to add new metrics as they are encountered. If the metric already exists in the
|
||||
// check, it will *NOT* be activated. This setting overrides that behavior.
|
||||
// Default: "false"
|
||||
//
|
||||
// hcl: telemetry { circonus_check_metrics_activation = (true|false)
|
||||
TelemetryCirconusCheckForceMetricActivation string
|
||||
|
||||
// TelemetryCirconusCheckID is the check id (not check bundle id) from a previously created
|
||||
// HTTPTRAP check. The numeric portion of the check._cid field.
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_check_id = string }
|
||||
TelemetryCirconusCheckID string
|
||||
|
||||
// TelemetryCirconusCheckInstanceID serves to uniquely identify the metrics coming from this "instance".
|
||||
// It can be used to maintain metric continuity with transient or ephemeral instances as
|
||||
// they move around within an infrastructure.
|
||||
// Default: hostname:app
|
||||
//
|
||||
// hcl: telemetry { circonus_check_instance_id = string }
|
||||
TelemetryCirconusCheckInstanceID string
|
||||
|
||||
// TelemetryCirconusCheckSearchTag is a special tag which, when coupled with the instance id, helps to
|
||||
// narrow down the search results when neither a Submission URL or Check ID is provided.
|
||||
// Default: service:app (e.g. service:consul)
|
||||
//
|
||||
// hcl: telemetry { circonus_check_search_tag = string }
|
||||
TelemetryCirconusCheckSearchTag string
|
||||
|
||||
// TelemetryCirconusCheckSearchTag is a special tag which, when coupled with the instance id, helps to
|
||||
// narrow down the search results when neither a Submission URL or Check ID is provided.
|
||||
// Default: service:app (e.g. service:consul)
|
||||
//
|
||||
// hcl: telemetry { circonus_check_tags = string }
|
||||
TelemetryCirconusCheckTags string
|
||||
|
||||
// TelemetryCirconusSubmissionInterval is the interval at which metrics are submitted to Circonus.
|
||||
// Default: 10s
|
||||
//
|
||||
// hcl: telemetry { circonus_submission_interval = "duration" }
|
||||
TelemetryCirconusSubmissionInterval string
|
||||
|
||||
// TelemetryCirconusCheckSubmissionURL is the check.config.submission_url field from a
|
||||
// previously created HTTPTRAP check.
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_submission_url = string }
|
||||
TelemetryCirconusSubmissionURL string
|
||||
|
||||
// DisableHostname will disable hostname prefixing for all metrics.
|
||||
//
|
||||
// hcl: telemetry { disable_hostname = (true|false)
|
||||
TelemetryDisableHostname bool
|
||||
|
||||
// TelemetryDogStatsdAddr is the address of a dogstatsd instance. If provided,
|
||||
// metrics will be sent to that instance
|
||||
//
|
||||
// hcl: telemetry { dogstatsd_addr = string }
|
||||
TelemetryDogstatsdAddr string
|
||||
|
||||
// TelemetryDogStatsdTags are the global tags that should be sent with each packet to dogstatsd
|
||||
// It is a list of strings, where each string looks like "my_tag_name:my_tag_value"
|
||||
//
|
||||
// hcl: telemetry { dogstatsd_tags = []string }
|
||||
TelemetryDogstatsdTags []string
|
||||
|
||||
// PrometheusRetentionTime is the retention time for prometheus metrics if greater than 0.
|
||||
// A value of 0 disable Prometheus support. Regarding Prometheus, it is considered a good
|
||||
// practice to put large values here (such as a few days), and at least the interval between
|
||||
// prometheus requests.
|
||||
//
|
||||
// hcl: telemetry { prometheus_retention_time = "duration" }
|
||||
TelemetryPrometheusRetentionTime time.Duration
|
||||
|
||||
// TelemetryFilterDefault is the default for whether to allow a metric that's not
|
||||
// covered by the filter.
|
||||
//
|
||||
// hcl: telemetry { filter_default = (true|false) }
|
||||
TelemetryFilterDefault bool
|
||||
|
||||
// TelemetryAllowedPrefixes is a list of filter rules to apply for allowing metrics
|
||||
// by prefix. Use the 'prefix_filter' option and prefix rules with '+' to be
|
||||
// included.
|
||||
//
|
||||
// hcl: telemetry { prefix_filter = []string{"+<expr>", "+<expr>", ...} }
|
||||
TelemetryAllowedPrefixes []string
|
||||
|
||||
// TelemetryBlockedPrefixes is a list of filter rules to apply for blocking metrics
|
||||
// by prefix. Use the 'prefix_filter' option and prefix rules with '-' to be
|
||||
// excluded.
|
||||
//
|
||||
// hcl: telemetry { prefix_filter = []string{"-<expr>", "-<expr>", ...} }
|
||||
TelemetryBlockedPrefixes []string
|
||||
|
||||
// TelemetryMetricsPrefix is the prefix used to write stats values to.
|
||||
// Default: "consul."
|
||||
//
|
||||
// hcl: telemetry { metrics_prefix = string }
|
||||
TelemetryMetricsPrefix string
|
||||
|
||||
// TelemetryStatsdAddr is the address of a statsd instance. If provided,
|
||||
// metrics will be sent to that instance.
|
||||
//
|
||||
// hcl: telemetry { statsd_addr = string }
|
||||
TelemetryStatsdAddr string
|
||||
|
||||
// TelemetryStatsiteAddr is the address of a statsite instance. If provided,
|
||||
// metrics will be streamed to that instance.
|
||||
//
|
||||
// hcl: telemetry { statsite_addr = string }
|
||||
TelemetryStatsiteAddr string
|
||||
// Embed Telemetry Config
|
||||
Telemetry lib.TelemetryConfig
|
||||
|
||||
// Datacenter is the datacenter this node is in. Defaults to "dc1".
|
||||
//
|
||||
|
@ -616,6 +453,61 @@ type RuntimeConfig struct {
|
|||
// flag: -client string
|
||||
ClientAddrs []*net.IPAddr
|
||||
|
||||
// ConnectEnabled opts the agent into connect. It should be set on all clients
|
||||
// and servers in a cluster for correct connect operation.
|
||||
ConnectEnabled bool
|
||||
|
||||
// ConnectProxyBindMinPort is the inclusive start of the range of ports
|
||||
// allocated to the agent for starting proxy listeners on where no explicit
|
||||
// port is specified.
|
||||
ConnectProxyBindMinPort int
|
||||
|
||||
// ConnectProxyBindMaxPort is the inclusive end of the range of ports
|
||||
// allocated to the agent for starting proxy listeners on where no explicit
|
||||
// port is specified.
|
||||
ConnectProxyBindMaxPort int
|
||||
|
||||
// ConnectProxyAllowManagedRoot is true if Consul can execute managed
|
||||
// proxies when running as root (EUID == 0).
|
||||
ConnectProxyAllowManagedRoot bool
|
||||
|
||||
// ConnectProxyAllowManagedAPIRegistration enables managed proxy registration
|
||||
// via the agent HTTP API. If this is false, only file configurations
|
||||
// can be used.
|
||||
ConnectProxyAllowManagedAPIRegistration bool
|
||||
|
||||
// ConnectProxyDefaultExecMode is used where a registration doesn't include an
|
||||
// exec_mode. Defaults to daemon.
|
||||
ConnectProxyDefaultExecMode string
|
||||
|
||||
// ConnectProxyDefaultDaemonCommand is used to start proxy in exec_mode =
|
||||
// daemon if not specified at registration time.
|
||||
ConnectProxyDefaultDaemonCommand []string
|
||||
|
||||
// ConnectProxyDefaultScriptCommand is used to start proxy in exec_mode =
|
||||
// script if not specified at registration time.
|
||||
ConnectProxyDefaultScriptCommand []string
|
||||
|
||||
// ConnectProxyDefaultConfig is merged with any config specified at
|
||||
// registration time to allow global control of defaults.
|
||||
ConnectProxyDefaultConfig map[string]interface{}
|
||||
|
||||
// ConnectCAProvider is the type of CA provider to use with Connect.
|
||||
ConnectCAProvider string
|
||||
|
||||
// ConnectCAConfig is the config to use for the CA provider.
|
||||
ConnectCAConfig map[string]interface{}
|
||||
|
||||
// ConnectTestDisableManagedProxies is not exposed to public config but us
|
||||
// used by TestAgent to prevent self-executing the test binary in the
|
||||
// background if a managed proxy is created for a test. The only place we
|
||||
// actually want to test processes really being spun up and managed is in
|
||||
// `agent/proxy` and it does it at a lower level. Note that this still allows
|
||||
// registering managed proxies via API and other methods, and still creates
|
||||
// all the agent state for them, just doesn't actually start external
|
||||
// processes up.
|
||||
ConnectTestDisableManagedProxies bool
|
||||
|
||||
// DNSAddrs contains the list of TCP and UDP addresses the DNS server will
|
||||
// bind to. If the DNS endpoint is disabled (ports.dns <= 0) the list is
|
||||
// empty.
|
||||
|
|
|
@ -19,10 +19,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type configTest struct {
|
||||
|
@ -31,6 +32,7 @@ type configTest struct {
|
|||
pre, post func()
|
||||
json, jsontail []string
|
||||
hcl, hcltail []string
|
||||
skipformat bool
|
||||
privatev4 func() ([]*net.IPAddr, error)
|
||||
publicv6 func() ([]*net.IPAddr, error)
|
||||
patch func(rt *RuntimeConfig)
|
||||
|
@ -263,6 +265,7 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
|||
rt.AdvertiseAddrLAN = ipAddr("127.0.0.1")
|
||||
rt.AdvertiseAddrWAN = ipAddr("127.0.0.1")
|
||||
rt.BindAddr = ipAddr("127.0.0.1")
|
||||
rt.ConnectEnabled = true
|
||||
rt.DevMode = true
|
||||
rt.DisableAnonymousSignature = true
|
||||
rt.DisableKeyringFile = true
|
||||
|
@ -1850,8 +1853,8 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
|||
`},
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.TelemetryAllowedPrefixes = []string{"foo"}
|
||||
rt.TelemetryBlockedPrefixes = []string{"bar"}
|
||||
rt.Telemetry.AllowedPrefixes = []string{"foo"}
|
||||
rt.Telemetry.BlockedPrefixes = []string{"bar"}
|
||||
},
|
||||
warns: []string{`Filter rule must begin with either '+' or '-': "nix"`},
|
||||
},
|
||||
|
@ -2068,6 +2071,127 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
|||
rt.DataDir = dataDir
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
desc: "HCL service managed proxy 'upstreams'",
|
||||
args: []string{
|
||||
`-data-dir=` + dataDir,
|
||||
},
|
||||
hcl: []string{
|
||||
`service {
|
||||
name = "web"
|
||||
port = 8080
|
||||
connect {
|
||||
proxy {
|
||||
config {
|
||||
upstreams {
|
||||
local_bind_port = 1234
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
skipformat: true, // skipping JSON cause we get slightly diff types (okay)
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.Services = []*structs.ServiceDefinition{
|
||||
&structs.ServiceDefinition{
|
||||
Name: "web",
|
||||
Port: 8080,
|
||||
Connect: &structs.ServiceConnect{
|
||||
Proxy: &structs.ServiceDefinitionConnectProxy{
|
||||
Config: map[string]interface{}{
|
||||
"upstreams": []map[string]interface{}{
|
||||
map[string]interface{}{
|
||||
"local_bind_port": 1234,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "JSON service managed proxy 'upstreams'",
|
||||
args: []string{
|
||||
`-data-dir=` + dataDir,
|
||||
},
|
||||
json: []string{
|
||||
`{
|
||||
"service": {
|
||||
"name": "web",
|
||||
"port": 8080,
|
||||
"connect": {
|
||||
"proxy": {
|
||||
"config": {
|
||||
"upstreams": [{
|
||||
"local_bind_port": 1234
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
skipformat: true, // skipping HCL cause we get slightly diff types (okay)
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.Services = []*structs.ServiceDefinition{
|
||||
&structs.ServiceDefinition{
|
||||
Name: "web",
|
||||
Port: 8080,
|
||||
Connect: &structs.ServiceConnect{
|
||||
Proxy: &structs.ServiceDefinitionConnectProxy{
|
||||
Config: map[string]interface{}{
|
||||
"upstreams": []interface{}{
|
||||
map[string]interface{}{
|
||||
"local_bind_port": float64(1234),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
desc: "enabling Connect allow_managed_root",
|
||||
args: []string{
|
||||
`-data-dir=` + dataDir,
|
||||
},
|
||||
json: []string{
|
||||
`{ "connect": { "proxy": { "allow_managed_root": true } } }`,
|
||||
},
|
||||
hcl: []string{
|
||||
`connect { proxy { allow_managed_root = true } }`,
|
||||
},
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.ConnectProxyAllowManagedRoot = true
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
desc: "enabling Connect allow_managed_api_registration",
|
||||
args: []string{
|
||||
`-data-dir=` + dataDir,
|
||||
},
|
||||
json: []string{
|
||||
`{ "connect": { "proxy": { "allow_managed_api_registration": true } } }`,
|
||||
},
|
||||
hcl: []string{
|
||||
`connect { proxy { allow_managed_api_registration = true } }`,
|
||||
},
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.ConnectProxyAllowManagedAPIRegistration = true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testConfig(t, tests, dataDir)
|
||||
|
@ -2089,7 +2213,7 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
|
|||
|
||||
// json and hcl sources need to be in sync
|
||||
// to make sure we're generating the same config
|
||||
if len(tt.json) != len(tt.hcl) {
|
||||
if len(tt.json) != len(tt.hcl) && !tt.skipformat {
|
||||
t.Fatal(tt.desc, ": JSON and HCL test case out of sync")
|
||||
}
|
||||
|
||||
|
@ -2099,6 +2223,12 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
|
|||
srcs, tails = tt.hcl, tt.hcltail
|
||||
}
|
||||
|
||||
// If we're skipping a format and the current format is empty,
|
||||
// then skip it!
|
||||
if tt.skipformat && len(srcs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// build the description
|
||||
var desc []string
|
||||
if !flagsOnly {
|
||||
|
@ -2353,6 +2483,23 @@ func TestFullConfig(t *testing.T) {
|
|||
],
|
||||
"check_update_interval": "16507s",
|
||||
"client_addr": "93.83.18.19",
|
||||
"connect": {
|
||||
"ca_provider": "consul",
|
||||
"ca_config": {
|
||||
"RotationPeriod": "90h"
|
||||
},
|
||||
"enabled": true,
|
||||
"proxy_defaults": {
|
||||
"exec_mode": "script",
|
||||
"daemon_command": ["consul", "connect", "proxy"],
|
||||
"script_command": ["proxyctl.sh"],
|
||||
"config": {
|
||||
"foo": "bar",
|
||||
"connect_timeout_ms": 1000,
|
||||
"pedantic_mode": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"data_dir": "` + dataDir + `",
|
||||
"datacenter": "rzo029wg",
|
||||
"disable_anonymous_signature": true,
|
||||
|
@ -2417,7 +2564,9 @@ func TestFullConfig(t *testing.T) {
|
|||
"dns": 7001,
|
||||
"http": 7999,
|
||||
"https": 15127,
|
||||
"server": 3757
|
||||
"server": 3757,
|
||||
"proxy_min_port": 2000,
|
||||
"proxy_max_port": 3000
|
||||
},
|
||||
"protocol": 30793,
|
||||
"raft_protocol": 19016,
|
||||
|
@ -2613,7 +2762,16 @@ func TestFullConfig(t *testing.T) {
|
|||
"ttl": "11222s",
|
||||
"deregister_critical_service_after": "68482s"
|
||||
}
|
||||
]
|
||||
],
|
||||
"connect": {
|
||||
"proxy": {
|
||||
"exec_mode": "daemon",
|
||||
"command": ["awesome-proxy"],
|
||||
"config": {
|
||||
"foo": "qux"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"session_ttl_min": "26627s",
|
||||
|
@ -2786,6 +2944,25 @@ func TestFullConfig(t *testing.T) {
|
|||
]
|
||||
check_update_interval = "16507s"
|
||||
client_addr = "93.83.18.19"
|
||||
connect {
|
||||
ca_provider = "consul"
|
||||
ca_config {
|
||||
"RotationPeriod" = "90h"
|
||||
}
|
||||
enabled = true
|
||||
proxy_defaults {
|
||||
exec_mode = "script"
|
||||
daemon_command = ["consul", "connect", "proxy"]
|
||||
script_command = ["proxyctl.sh"]
|
||||
config = {
|
||||
foo = "bar"
|
||||
# hack float since json parses numbers as float and we have to
|
||||
# assert against the same thing
|
||||
connect_timeout_ms = 1000.0
|
||||
pedantic_mode = true
|
||||
}
|
||||
}
|
||||
}
|
||||
data_dir = "` + dataDir + `"
|
||||
datacenter = "rzo029wg"
|
||||
disable_anonymous_signature = true
|
||||
|
@ -2851,6 +3028,8 @@ func TestFullConfig(t *testing.T) {
|
|||
http = 7999,
|
||||
https = 15127
|
||||
server = 3757
|
||||
proxy_min_port = 2000
|
||||
proxy_max_port = 3000
|
||||
}
|
||||
protocol = 30793
|
||||
raft_protocol = 19016
|
||||
|
@ -3047,6 +3226,15 @@ func TestFullConfig(t *testing.T) {
|
|||
deregister_critical_service_after = "68482s"
|
||||
}
|
||||
]
|
||||
connect {
|
||||
proxy {
|
||||
exec_mode = "daemon"
|
||||
command = ["awesome-proxy"]
|
||||
config = {
|
||||
foo = "qux"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
session_ttl_min = "26627s"
|
||||
|
@ -3355,8 +3543,25 @@ func TestFullConfig(t *testing.T) {
|
|||
DeregisterCriticalServiceAfter: 13209 * time.Second,
|
||||
},
|
||||
},
|
||||
CheckUpdateInterval: 16507 * time.Second,
|
||||
ClientAddrs: []*net.IPAddr{ipAddr("93.83.18.19")},
|
||||
CheckUpdateInterval: 16507 * time.Second,
|
||||
ClientAddrs: []*net.IPAddr{ipAddr("93.83.18.19")},
|
||||
ConnectEnabled: true,
|
||||
ConnectProxyBindMinPort: 2000,
|
||||
ConnectProxyBindMaxPort: 3000,
|
||||
ConnectCAProvider: "consul",
|
||||
ConnectCAConfig: map[string]interface{}{
|
||||
"RotationPeriod": "90h",
|
||||
},
|
||||
ConnectProxyAllowManagedRoot: false,
|
||||
ConnectProxyAllowManagedAPIRegistration: false,
|
||||
ConnectProxyDefaultExecMode: "script",
|
||||
ConnectProxyDefaultDaemonCommand: []string{"consul", "connect", "proxy"},
|
||||
ConnectProxyDefaultScriptCommand: []string{"proxyctl.sh"},
|
||||
ConnectProxyDefaultConfig: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"connect_timeout_ms": float64(1000),
|
||||
"pedantic_mode": true,
|
||||
},
|
||||
DNSAddrs: []net.Addr{tcpAddr("93.95.95.81:7001"), udpAddr("93.95.95.81:7001")},
|
||||
DNSARecordLimit: 29907,
|
||||
DNSAllowStale: true,
|
||||
|
@ -3371,6 +3576,7 @@ func TestFullConfig(t *testing.T) {
|
|||
DNSRecursors: []string{"63.38.39.58", "92.49.18.18"},
|
||||
DNSServiceTTL: map[string]time.Duration{"*": 32030 * time.Second},
|
||||
DNSUDPAnswerLimit: 29909,
|
||||
DNSNodeMetaTXT: true,
|
||||
DataDir: dataDir,
|
||||
Datacenter: "rzo029wg",
|
||||
DevMode: true,
|
||||
|
@ -3529,6 +3735,15 @@ func TestFullConfig(t *testing.T) {
|
|||
DeregisterCriticalServiceAfter: 68482 * time.Second,
|
||||
},
|
||||
},
|
||||
Connect: &structs.ServiceConnect{
|
||||
Proxy: &structs.ServiceDefinitionConnectProxy{
|
||||
ExecMode: "daemon",
|
||||
Command: []string{"awesome-proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "qux",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "dLOXpSCI",
|
||||
|
@ -3606,41 +3821,43 @@ func TestFullConfig(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
SerfAdvertiseAddrLAN: tcpAddr("17.99.29.16:8301"),
|
||||
SerfAdvertiseAddrWAN: tcpAddr("78.63.37.19:8302"),
|
||||
SerfBindAddrLAN: tcpAddr("99.43.63.15:8301"),
|
||||
SerfBindAddrWAN: tcpAddr("67.88.33.19:8302"),
|
||||
SessionTTLMin: 26627 * time.Second,
|
||||
SkipLeaveOnInt: true,
|
||||
StartJoinAddrsLAN: []string{"LR3hGDoG", "MwVpZ4Up"},
|
||||
StartJoinAddrsWAN: []string{"EbFSc3nA", "kwXTh623"},
|
||||
SyslogFacility: "hHv79Uia",
|
||||
TelemetryCirconusAPIApp: "p4QOTe9j",
|
||||
TelemetryCirconusAPIToken: "E3j35V23",
|
||||
TelemetryCirconusAPIURL: "mEMjHpGg",
|
||||
TelemetryCirconusBrokerID: "BHlxUhed",
|
||||
TelemetryCirconusBrokerSelectTag: "13xy1gHm",
|
||||
TelemetryCirconusCheckDisplayName: "DRSlQR6n",
|
||||
TelemetryCirconusCheckForceMetricActivation: "Ua5FGVYf",
|
||||
TelemetryCirconusCheckID: "kGorutad",
|
||||
TelemetryCirconusCheckInstanceID: "rwoOL6R4",
|
||||
TelemetryCirconusCheckSearchTag: "ovT4hT4f",
|
||||
TelemetryCirconusCheckTags: "prvO4uBl",
|
||||
TelemetryCirconusSubmissionInterval: "DolzaflP",
|
||||
TelemetryCirconusSubmissionURL: "gTcbS93G",
|
||||
TelemetryDisableHostname: true,
|
||||
TelemetryDogstatsdAddr: "0wSndumK",
|
||||
TelemetryDogstatsdTags: []string{"3N81zSUB", "Xtj8AnXZ"},
|
||||
TelemetryFilterDefault: true,
|
||||
TelemetryAllowedPrefixes: []string{"oJotS8XJ"},
|
||||
TelemetryBlockedPrefixes: []string{"cazlEhGn"},
|
||||
TelemetryMetricsPrefix: "ftO6DySn",
|
||||
TelemetryPrometheusRetentionTime: 15 * time.Second,
|
||||
TelemetryStatsdAddr: "drce87cy",
|
||||
TelemetryStatsiteAddr: "HpFwKB8R",
|
||||
TLSCipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
|
||||
TLSMinVersion: "pAOWafkR",
|
||||
TLSPreferServerCipherSuites: true,
|
||||
SerfAdvertiseAddrLAN: tcpAddr("17.99.29.16:8301"),
|
||||
SerfAdvertiseAddrWAN: tcpAddr("78.63.37.19:8302"),
|
||||
SerfBindAddrLAN: tcpAddr("99.43.63.15:8301"),
|
||||
SerfBindAddrWAN: tcpAddr("67.88.33.19:8302"),
|
||||
SessionTTLMin: 26627 * time.Second,
|
||||
SkipLeaveOnInt: true,
|
||||
StartJoinAddrsLAN: []string{"LR3hGDoG", "MwVpZ4Up"},
|
||||
StartJoinAddrsWAN: []string{"EbFSc3nA", "kwXTh623"},
|
||||
SyslogFacility: "hHv79Uia",
|
||||
Telemetry: lib.TelemetryConfig{
|
||||
CirconusAPIApp: "p4QOTe9j",
|
||||
CirconusAPIToken: "E3j35V23",
|
||||
CirconusAPIURL: "mEMjHpGg",
|
||||
CirconusBrokerID: "BHlxUhed",
|
||||
CirconusBrokerSelectTag: "13xy1gHm",
|
||||
CirconusCheckDisplayName: "DRSlQR6n",
|
||||
CirconusCheckForceMetricActivation: "Ua5FGVYf",
|
||||
CirconusCheckID: "kGorutad",
|
||||
CirconusCheckInstanceID: "rwoOL6R4",
|
||||
CirconusCheckSearchTag: "ovT4hT4f",
|
||||
CirconusCheckTags: "prvO4uBl",
|
||||
CirconusSubmissionInterval: "DolzaflP",
|
||||
CirconusSubmissionURL: "gTcbS93G",
|
||||
DisableHostname: true,
|
||||
DogstatsdAddr: "0wSndumK",
|
||||
DogstatsdTags: []string{"3N81zSUB", "Xtj8AnXZ"},
|
||||
FilterDefault: true,
|
||||
AllowedPrefixes: []string{"oJotS8XJ"},
|
||||
BlockedPrefixes: []string{"cazlEhGn"},
|
||||
MetricsPrefix: "ftO6DySn",
|
||||
PrometheusRetentionTime: 15 * time.Second,
|
||||
StatsdAddr: "drce87cy",
|
||||
StatsiteAddr: "HpFwKB8R",
|
||||
},
|
||||
TLSCipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
|
||||
TLSMinVersion: "pAOWafkR",
|
||||
TLSPreferServerCipherSuites: true,
|
||||
TaggedAddresses: map[string]string{
|
||||
"7MYgHrYH": "dALJAhLD",
|
||||
"h6DdBy6K": "ebrr9zZ8",
|
||||
|
@ -4018,6 +4235,18 @@ func TestSanitize(t *testing.T) {
|
|||
}
|
||||
],
|
||||
"ClientAddrs": [],
|
||||
"ConnectCAConfig": {},
|
||||
"ConnectCAProvider": "",
|
||||
"ConnectEnabled": false,
|
||||
"ConnectProxyAllowManagedAPIRegistration": false,
|
||||
"ConnectProxyAllowManagedRoot": false,
|
||||
"ConnectProxyBindMaxPort": 0,
|
||||
"ConnectProxyBindMinPort": 0,
|
||||
"ConnectProxyDefaultConfig": {},
|
||||
"ConnectProxyDefaultDaemonCommand": [],
|
||||
"ConnectProxyDefaultExecMode": "",
|
||||
"ConnectProxyDefaultScriptCommand": [],
|
||||
"ConnectTestDisableManagedProxies": false,
|
||||
"ConsulCoordinateUpdateBatchSize": 0,
|
||||
"ConsulCoordinateUpdateMaxBatches": 0,
|
||||
"ConsulCoordinateUpdatePeriod": "15s",
|
||||
|
@ -4043,6 +4272,7 @@ func TestSanitize(t *testing.T) {
|
|||
"DNSDomain": "",
|
||||
"DNSEnableTruncate": false,
|
||||
"DNSMaxStale": "0s",
|
||||
"DNSNodeMetaTXT": false,
|
||||
"DNSNodeTTL": "0s",
|
||||
"DNSOnlyPassing": false,
|
||||
"DNSPort": 0,
|
||||
|
@ -4148,11 +4378,14 @@ func TestSanitize(t *testing.T) {
|
|||
"Timeout": "0s"
|
||||
},
|
||||
"Checks": [],
|
||||
"Connect": null,
|
||||
"EnableTagOverride": false,
|
||||
"ID": "",
|
||||
"Kind": "",
|
||||
"Meta": {},
|
||||
"Name": "foo",
|
||||
"Port": 0,
|
||||
"ProxyDestination": "",
|
||||
"Tags": [],
|
||||
"Token": "hidden"
|
||||
}
|
||||
|
@ -4168,29 +4401,31 @@ func TestSanitize(t *testing.T) {
|
|||
"TLSMinVersion": "",
|
||||
"TLSPreferServerCipherSuites": false,
|
||||
"TaggedAddresses": {},
|
||||
"TelemetryAllowedPrefixes": [],
|
||||
"TelemetryBlockedPrefixes": [],
|
||||
"TelemetryCirconusAPIApp": "",
|
||||
"TelemetryCirconusAPIToken": "hidden",
|
||||
"TelemetryCirconusAPIURL": "",
|
||||
"TelemetryCirconusBrokerID": "",
|
||||
"TelemetryCirconusBrokerSelectTag": "",
|
||||
"TelemetryCirconusCheckDisplayName": "",
|
||||
"TelemetryCirconusCheckForceMetricActivation": "",
|
||||
"TelemetryCirconusCheckID": "",
|
||||
"TelemetryCirconusCheckInstanceID": "",
|
||||
"TelemetryCirconusCheckSearchTag": "",
|
||||
"TelemetryCirconusCheckTags": "",
|
||||
"TelemetryCirconusSubmissionInterval": "",
|
||||
"TelemetryCirconusSubmissionURL": "",
|
||||
"TelemetryDisableHostname": false,
|
||||
"TelemetryDogstatsdAddr": "",
|
||||
"TelemetryDogstatsdTags": [],
|
||||
"TelemetryFilterDefault": false,
|
||||
"TelemetryMetricsPrefix": "",
|
||||
"TelemetryPrometheusRetentionTime": "0s",
|
||||
"TelemetryStatsdAddr": "",
|
||||
"TelemetryStatsiteAddr": "",
|
||||
"Telemetry":{
|
||||
"AllowedPrefixes": [],
|
||||
"BlockedPrefixes": [],
|
||||
"CirconusAPIApp": "",
|
||||
"CirconusAPIToken": "hidden",
|
||||
"CirconusAPIURL": "",
|
||||
"CirconusBrokerID": "",
|
||||
"CirconusBrokerSelectTag": "",
|
||||
"CirconusCheckDisplayName": "",
|
||||
"CirconusCheckForceMetricActivation": "",
|
||||
"CirconusCheckID": "",
|
||||
"CirconusCheckInstanceID": "",
|
||||
"CirconusCheckSearchTag": "",
|
||||
"CirconusCheckTags": "",
|
||||
"CirconusSubmissionInterval": "",
|
||||
"CirconusSubmissionURL": "",
|
||||
"DisableHostname": false,
|
||||
"DogstatsdAddr": "",
|
||||
"DogstatsdTags": [],
|
||||
"FilterDefault": false,
|
||||
"MetricsPrefix": "",
|
||||
"PrometheusRetentionTime": "0s",
|
||||
"StatsdAddr": "",
|
||||
"StatsiteAddr": ""
|
||||
},
|
||||
"TranslateWANAddrs": false,
|
||||
"UIDir": "",
|
||||
"UnixSocketGroup": "",
|
||||
|
@ -4210,11 +4445,7 @@ func TestSanitize(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := string(b), rtJSON; got != want {
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(want, got, false)
|
||||
t.Fatal(dmp.DiffPrettyText(diffs))
|
||||
}
|
||||
require.JSONEq(t, rtJSON, string(b))
|
||||
}
|
||||
|
||||
func splitIPPort(hostport string) (net.IP, int) {
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
// Provider is the interface for Consul to interact with
|
||||
// an external CA that provides leaf certificate signing for
|
||||
// given SpiffeIDServices.
|
||||
type Provider interface {
|
||||
// Active root returns the currently active root CA for this
|
||||
// provider. This should be a parent of the certificate returned by
|
||||
// ActiveIntermediate()
|
||||
ActiveRoot() (string, error)
|
||||
|
||||
// ActiveIntermediate returns the current signing cert used by this provider
|
||||
// for generating SPIFFE leaf certs. Note that this must not change except
|
||||
// when Consul requests the change via GenerateIntermediate. Changing the
|
||||
// signing cert will break Consul's assumptions about which validation paths
|
||||
// are active.
|
||||
ActiveIntermediate() (string, error)
|
||||
|
||||
// GenerateIntermediate returns a new intermediate signing cert and sets it to
|
||||
// the active intermediate. If multiple intermediates are needed to complete
|
||||
// the chain from the signing certificate back to the active root, they should
|
||||
// all by bundled here.
|
||||
GenerateIntermediate() (string, error)
|
||||
|
||||
// Sign signs a leaf certificate used by Connect proxies from a CSR. The PEM
|
||||
// returned should include only the leaf certificate as all Intermediates
|
||||
// needed to validate it will be added by Consul based on the active
|
||||
// intemediate and any cross-signed intermediates managed by Consul.
|
||||
Sign(*x509.CertificateRequest) (string, error)
|
||||
|
||||
// CrossSignCA must accept a CA certificate from another CA provider
|
||||
// and cross sign it exactly as it is such that it forms a chain back the the
|
||||
// CAProvider's current root. Specifically, the Distinguished Name, Subject
|
||||
// Alternative Name, SubjectKeyID and other relevant extensions must be kept.
|
||||
// The resulting certificate must have a distinct Serial Number and the
|
||||
// AuthorityKeyID set to the CAProvider's current signing key as well as the
|
||||
// Issuer related fields changed as necessary. The resulting certificate is
|
||||
// returned as a PEM formatted string.
|
||||
CrossSignCA(*x509.Certificate) (string, error)
|
||||
|
||||
// Cleanup performs any necessary cleanup that should happen when the provider
|
||||
// is shut down permanently, such as removing a temporary PKI backend in Vault
|
||||
// created for an intermediate CA.
|
||||
Cleanup() error
|
||||
}
|
|
@ -0,0 +1,379 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
type ConsulProvider struct {
|
||||
config *structs.ConsulCAProviderConfig
|
||||
id string
|
||||
delegate ConsulProviderStateDelegate
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
type ConsulProviderStateDelegate interface {
|
||||
State() *state.Store
|
||||
ApplyCARequest(*structs.CARequest) error
|
||||
}
|
||||
|
||||
// NewConsulProvider returns a new instance of the Consul CA provider,
|
||||
// bootstrapping its state in the state store necessary
|
||||
func NewConsulProvider(rawConfig map[string]interface{}, delegate ConsulProviderStateDelegate) (*ConsulProvider, error) {
|
||||
conf, err := ParseConsulCAConfig(rawConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider := &ConsulProvider{
|
||||
config: conf,
|
||||
delegate: delegate,
|
||||
id: fmt.Sprintf("%s,%s", conf.PrivateKey, conf.RootCert),
|
||||
}
|
||||
|
||||
// Check if this configuration of the provider has already been
|
||||
// initialized in the state store.
|
||||
state := delegate.State()
|
||||
_, providerState, err := state.CAProviderState(provider.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exit early if the state store has already been populated for this config.
|
||||
if providerState != nil {
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
newState := structs.CAConsulProviderState{
|
||||
ID: provider.id,
|
||||
}
|
||||
|
||||
// Write the initial provider state to get the index to use for the
|
||||
// CA serial number.
|
||||
{
|
||||
args := &structs.CARequest{
|
||||
Op: structs.CAOpSetProviderState,
|
||||
ProviderState: &newState,
|
||||
}
|
||||
if err := delegate.ApplyCARequest(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
idx, _, err := state.CAProviderState(provider.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a private key if needed
|
||||
if conf.PrivateKey == "" {
|
||||
_, pk, err := connect.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newState.PrivateKey = pk
|
||||
} else {
|
||||
newState.PrivateKey = conf.PrivateKey
|
||||
}
|
||||
|
||||
// Generate the root CA if necessary
|
||||
if conf.RootCert == "" {
|
||||
ca, err := provider.generateCA(newState.PrivateKey, idx+1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating CA: %v", err)
|
||||
}
|
||||
newState.RootCert = ca
|
||||
} else {
|
||||
newState.RootCert = conf.RootCert
|
||||
}
|
||||
|
||||
// Write the provider state
|
||||
args := &structs.CARequest{
|
||||
Op: structs.CAOpSetProviderState,
|
||||
ProviderState: &newState,
|
||||
}
|
||||
if err := delegate.ApplyCARequest(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// Return the active root CA and generate a new one if needed
|
||||
func (c *ConsulProvider) ActiveRoot() (string, error) {
|
||||
state := c.delegate.State()
|
||||
_, providerState, err := state.CAProviderState(c.id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return providerState.RootCert, nil
|
||||
}
|
||||
|
||||
// We aren't maintaining separate root/intermediate CAs for the builtin
|
||||
// provider, so just return the root.
|
||||
func (c *ConsulProvider) ActiveIntermediate() (string, error) {
|
||||
return c.ActiveRoot()
|
||||
}
|
||||
|
||||
// We aren't maintaining separate root/intermediate CAs for the builtin
|
||||
// provider, so just return the root.
|
||||
func (c *ConsulProvider) GenerateIntermediate() (string, error) {
|
||||
return c.ActiveIntermediate()
|
||||
}
|
||||
|
||||
// Remove the state store entry for this provider instance.
|
||||
func (c *ConsulProvider) Cleanup() error {
|
||||
args := &structs.CARequest{
|
||||
Op: structs.CAOpDeleteProviderState,
|
||||
ProviderState: &structs.CAConsulProviderState{ID: c.id},
|
||||
}
|
||||
if err := c.delegate.ApplyCARequest(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign returns a new certificate valid for the given SpiffeIDService
|
||||
// using the current CA.
|
||||
func (c *ConsulProvider) Sign(csr *x509.CertificateRequest) (string, error) {
|
||||
// Lock during the signing so we don't use the same index twice
|
||||
// for different cert serial numbers.
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Get the provider state
|
||||
state := c.delegate.State()
|
||||
idx, providerState, err := state.CAProviderState(c.id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the keyId for the cert from the signing private key.
|
||||
signer, err := connect.ParseSigner(providerState.PrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if signer == nil {
|
||||
return "", fmt.Errorf("error signing cert: Consul CA not initialized yet")
|
||||
}
|
||||
keyId, err := connect.KeyId(signer.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse the SPIFFE ID
|
||||
spiffeId, err := connect.ParseCertURI(csr.URIs[0])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
serviceId, ok := spiffeId.(*connect.SpiffeIDService)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("SPIFFE ID in CSR must be a service ID")
|
||||
}
|
||||
|
||||
// Parse the CA cert
|
||||
caCert, err := connect.ParseCert(providerState.RootCert)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing CA cert: %s", err)
|
||||
}
|
||||
|
||||
// Cert template for generation
|
||||
sn := &big.Int{}
|
||||
sn.SetUint64(idx + 1)
|
||||
// Sign the certificate valid from 1 minute in the past, this helps it be
|
||||
// accepted right away even when nodes are not in close time sync accross the
|
||||
// cluster. A minute is more than enough for typical DC clock drift.
|
||||
effectiveNow := time.Now().Add(-1 * time.Minute)
|
||||
template := x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{CommonName: serviceId.Service},
|
||||
URIs: csr.URIs,
|
||||
Signature: csr.Signature,
|
||||
SignatureAlgorithm: csr.SignatureAlgorithm,
|
||||
PublicKeyAlgorithm: csr.PublicKeyAlgorithm,
|
||||
PublicKey: csr.PublicKey,
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDataEncipherment |
|
||||
x509.KeyUsageKeyAgreement |
|
||||
x509.KeyUsageDigitalSignature |
|
||||
x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
// todo(kyhavlov): add a way to set the cert lifetime here from the CA config
|
||||
NotAfter: effectiveNow.Add(3 * 24 * time.Hour),
|
||||
NotBefore: effectiveNow,
|
||||
AuthorityKeyId: keyId,
|
||||
SubjectKeyId: keyId,
|
||||
}
|
||||
|
||||
// Create the certificate, PEM encode it and return that value.
|
||||
var buf bytes.Buffer
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, caCert, csr.PublicKey, signer)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating certificate: %s", err)
|
||||
}
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding certificate: %s", err)
|
||||
}
|
||||
|
||||
err = c.incrementProviderIndex(providerState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Set the response
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// CrossSignCA returns the given CA cert signed by the current active root.
|
||||
func (c *ConsulProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Get the provider state
|
||||
state := c.delegate.State()
|
||||
idx, providerState, err := state.CAProviderState(c.id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privKey, err := connect.ParseSigner(providerState.PrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing private key %q: %s", providerState.PrivateKey, err)
|
||||
}
|
||||
|
||||
rootCA, err := connect.ParseCert(providerState.RootCert)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyId, err := connect.KeyId(privKey.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the cross-signing template from the existing root CA
|
||||
serialNum := &big.Int{}
|
||||
serialNum.SetUint64(idx + 1)
|
||||
template := *cert
|
||||
template.SerialNumber = serialNum
|
||||
template.SignatureAlgorithm = rootCA.SignatureAlgorithm
|
||||
template.AuthorityKeyId = keyId
|
||||
|
||||
// Sign the certificate valid from 1 minute in the past, this helps it be
|
||||
// accepted right away even when nodes are not in close time sync accross the
|
||||
// cluster. A minute is more than enough for typical DC clock drift.
|
||||
effectiveNow := time.Now().Add(-1 * time.Minute)
|
||||
template.NotBefore = effectiveNow
|
||||
// This cross-signed cert is only needed during rotation, and only while old
|
||||
// leaf certs are still in use. They expire within 3 days currently so 7 is
|
||||
// safe. TODO(banks): make this be based on leaf expiry time when that is
|
||||
// configurable.
|
||||
template.NotAfter = effectiveNow.Add(7 * 24 * time.Hour)
|
||||
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, rootCA, cert.PublicKey, privKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
err = c.incrementProviderIndex(providerState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// incrementProviderIndex does a write to increment the provider state store table index
|
||||
// used for serial numbers when generating certificates.
|
||||
func (c *ConsulProvider) incrementProviderIndex(providerState *structs.CAConsulProviderState) error {
|
||||
newState := *providerState
|
||||
args := &structs.CARequest{
|
||||
Op: structs.CAOpSetProviderState,
|
||||
ProviderState: &newState,
|
||||
}
|
||||
if err := c.delegate.ApplyCARequest(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateCA makes a new root CA using the current private key
|
||||
func (c *ConsulProvider) generateCA(privateKey string, sn uint64) (string, error) {
|
||||
state := c.delegate.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privKey, err := connect.ParseSigner(privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing private key %q: %s", privateKey, err)
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("Consul CA %d", sn)
|
||||
|
||||
// The URI (SPIFFE compatible) for the cert
|
||||
id := connect.SpiffeIDSigningForCluster(config)
|
||||
keyId, err := connect.KeyId(privKey.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the CA cert
|
||||
serialNum := &big.Int{}
|
||||
serialNum.SetUint64(sn)
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNum,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
URIs: []*url.URL{id.URI()},
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageCertSign |
|
||||
x509.KeyUsageCRLSign |
|
||||
x509.KeyUsageDigitalSignature,
|
||||
IsCA: true,
|
||||
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
|
||||
NotBefore: time.Now(),
|
||||
AuthorityKeyId: keyId,
|
||||
SubjectKeyId: keyId,
|
||||
}
|
||||
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, &template, privKey.Public(), privKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func ParseConsulCAConfig(raw map[string]interface{}) (*structs.ConsulCAProviderConfig, error) {
|
||||
var config structs.ConsulCAProviderConfig
|
||||
decodeConf := &mapstructure.DecoderConfig{
|
||||
DecodeHook: ParseDurationFunc(),
|
||||
ErrorUnused: true,
|
||||
Result: &config,
|
||||
WeaklyTypedInput: true,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := decoder.Decode(raw); err != nil {
|
||||
return nil, fmt.Errorf("error decoding config: %s", err)
|
||||
}
|
||||
|
||||
if config.PrivateKey == "" && config.RootCert != "" {
|
||||
return nil, fmt.Errorf("must provide a private key when providing a root cert")
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// ParseDurationFunc is a mapstructure hook for decoding a string or
|
||||
// []uint8 into a time.Duration value.
|
||||
func ParseDurationFunc() mapstructure.DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
var v time.Duration
|
||||
if t != reflect.TypeOf(v) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case f.Kind() == reflect.String:
|
||||
if dur, err := time.ParseDuration(data.(string)); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
v = dur
|
||||
}
|
||||
return v, nil
|
||||
case f == reflect.SliceOf(reflect.TypeOf(uint8(0))):
|
||||
s := Uint8ToString(data.([]uint8))
|
||||
if dur, err := time.ParseDuration(s); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
v = dur
|
||||
}
|
||||
return v, nil
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Uint8ToString(bs []uint8) string {
|
||||
b := make([]byte, len(bs))
|
||||
for i, v := range bs {
|
||||
b[i] = byte(v)
|
||||
}
|
||||
return string(b)
|
||||
}
|
|
@ -0,0 +1,266 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type consulCAMockDelegate struct {
|
||||
state *state.Store
|
||||
}
|
||||
|
||||
func (c *consulCAMockDelegate) State() *state.Store {
|
||||
return c.state
|
||||
}
|
||||
|
||||
func (c *consulCAMockDelegate) ApplyCARequest(req *structs.CARequest) error {
|
||||
idx, _, err := c.state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch req.Op {
|
||||
case structs.CAOpSetProviderState:
|
||||
_, err := c.state.CASetProviderState(idx+1, req.ProviderState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
case structs.CAOpDeleteProviderState:
|
||||
if err := c.state.CADeleteProviderState(req.ProviderState.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Invalid CA operation '%s'", req.Op)
|
||||
}
|
||||
}
|
||||
|
||||
func newMockDelegate(t *testing.T, conf *structs.CAConfiguration) *consulCAMockDelegate {
|
||||
s, err := state.NewStateStore(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatalf("missing state store")
|
||||
}
|
||||
if err := s.CASetConfig(conf.RaftIndex.CreateIndex, conf); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
return &consulCAMockDelegate{s}
|
||||
}
|
||||
|
||||
func testConsulCAConfig() *structs.CAConfiguration {
|
||||
return &structs.CAConfiguration{
|
||||
ClusterID: "asdf",
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsulCAProvider_Bootstrap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
conf := testConsulCAConfig()
|
||||
delegate := newMockDelegate(t, conf)
|
||||
|
||||
provider, err := NewConsulProvider(conf.Config, delegate)
|
||||
assert.NoError(err)
|
||||
|
||||
root, err := provider.ActiveRoot()
|
||||
assert.NoError(err)
|
||||
|
||||
// Intermediate should be the same cert.
|
||||
inter, err := provider.ActiveIntermediate()
|
||||
assert.NoError(err)
|
||||
assert.Equal(root, inter)
|
||||
|
||||
// Should be a valid cert
|
||||
parsed, err := connect.ParseCert(root)
|
||||
assert.NoError(err)
|
||||
assert.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", conf.ClusterID))
|
||||
}
|
||||
|
||||
func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Make sure setting a custom private key/root cert works.
|
||||
assert := assert.New(t)
|
||||
rootCA := connect.TestCA(t, nil)
|
||||
conf := testConsulCAConfig()
|
||||
conf.Config = map[string]interface{}{
|
||||
"PrivateKey": rootCA.SigningKey,
|
||||
"RootCert": rootCA.RootCert,
|
||||
}
|
||||
delegate := newMockDelegate(t, conf)
|
||||
|
||||
provider, err := NewConsulProvider(conf.Config, delegate)
|
||||
assert.NoError(err)
|
||||
|
||||
root, err := provider.ActiveRoot()
|
||||
assert.NoError(err)
|
||||
assert.Equal(root, rootCA.RootCert)
|
||||
}
|
||||
|
||||
func TestConsulCAProvider_SignLeaf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
conf := testConsulCAConfig()
|
||||
delegate := newMockDelegate(t, conf)
|
||||
|
||||
provider, err := NewConsulProvider(conf.Config, delegate)
|
||||
assert.NoError(err)
|
||||
|
||||
spiffeService := &connect.SpiffeIDService{
|
||||
Host: "node1",
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: "foo",
|
||||
}
|
||||
|
||||
// Generate a leaf cert for the service.
|
||||
{
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
csr, err := connect.ParseCSR(raw)
|
||||
assert.NoError(err)
|
||||
|
||||
cert, err := provider.Sign(csr)
|
||||
assert.NoError(err)
|
||||
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
assert.NoError(err)
|
||||
assert.Equal(parsed.URIs[0], spiffeService.URI())
|
||||
assert.Equal(parsed.Subject.CommonName, "foo")
|
||||
assert.Equal(uint64(2), parsed.SerialNumber.Uint64())
|
||||
|
||||
// Ensure the cert is valid now and expires within the correct limit.
|
||||
assert.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
|
||||
assert.True(parsed.NotBefore.Before(time.Now()))
|
||||
}
|
||||
|
||||
// Generate a new cert for another service and make sure
|
||||
// the serial number is incremented.
|
||||
spiffeService.Service = "bar"
|
||||
{
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
csr, err := connect.ParseCSR(raw)
|
||||
assert.NoError(err)
|
||||
|
||||
cert, err := provider.Sign(csr)
|
||||
assert.NoError(err)
|
||||
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
assert.NoError(err)
|
||||
assert.Equal(parsed.URIs[0], spiffeService.URI())
|
||||
assert.Equal(parsed.Subject.CommonName, "bar")
|
||||
assert.Equal(parsed.SerialNumber.Uint64(), uint64(2))
|
||||
|
||||
// Ensure the cert is valid now and expires within the correct limit.
|
||||
assert.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
|
||||
assert.True(parsed.NotBefore.Before(time.Now()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsulCAProvider_CrossSignCA(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf1 := testConsulCAConfig()
|
||||
delegate1 := newMockDelegate(t, conf1)
|
||||
provider1, err := NewConsulProvider(conf1.Config, delegate1)
|
||||
require.NoError(t, err)
|
||||
|
||||
conf2 := testConsulCAConfig()
|
||||
conf2.CreateIndex = 10
|
||||
delegate2 := newMockDelegate(t, conf2)
|
||||
provider2, err := NewConsulProvider(conf2.Config, delegate2)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCrossSignProviders(t, provider1, provider2)
|
||||
}
|
||||
|
||||
func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
|
||||
require := require.New(t)
|
||||
|
||||
// Get the root from the new provider to be cross-signed.
|
||||
newRootPEM, err := provider2.ActiveRoot()
|
||||
require.NoError(err)
|
||||
newRoot, err := connect.ParseCert(newRootPEM)
|
||||
require.NoError(err)
|
||||
oldSubject := newRoot.Subject.CommonName
|
||||
|
||||
newInterPEM, err := provider2.ActiveIntermediate()
|
||||
require.NoError(err)
|
||||
newIntermediate, err := connect.ParseCert(newInterPEM)
|
||||
require.NoError(err)
|
||||
|
||||
// Have provider1 cross sign our new root cert.
|
||||
xcPEM, err := provider1.CrossSignCA(newRoot)
|
||||
require.NoError(err)
|
||||
xc, err := connect.ParseCert(xcPEM)
|
||||
require.NoError(err)
|
||||
|
||||
oldRootPEM, err := provider1.ActiveRoot()
|
||||
require.NoError(err)
|
||||
oldRoot, err := connect.ParseCert(oldRootPEM)
|
||||
require.NoError(err)
|
||||
|
||||
// AuthorityKeyID should now be the signing root's, SubjectKeyId should be kept.
|
||||
require.Equal(oldRoot.AuthorityKeyId, xc.AuthorityKeyId)
|
||||
require.Equal(newRoot.SubjectKeyId, xc.SubjectKeyId)
|
||||
|
||||
// Subject name should not have changed.
|
||||
require.Equal(oldSubject, xc.Subject.CommonName)
|
||||
|
||||
// Issuer should be the signing root.
|
||||
require.Equal(oldRoot.Issuer.CommonName, xc.Issuer.CommonName)
|
||||
|
||||
// Get a leaf cert so we can verify against the cross-signed cert.
|
||||
spiffeService := &connect.SpiffeIDService{
|
||||
Host: "node1",
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: "foo",
|
||||
}
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
leafCsr, err := connect.ParseCSR(raw)
|
||||
require.NoError(err)
|
||||
|
||||
leafPEM, err := provider2.Sign(leafCsr)
|
||||
require.NoError(err)
|
||||
|
||||
cert, err := connect.ParseCert(leafPEM)
|
||||
require.NoError(err)
|
||||
|
||||
// Check that the leaf signed by the new cert can be verified by either root
|
||||
// certificate by using the new intermediate + cross-signed cert.
|
||||
intermediatePool := x509.NewCertPool()
|
||||
intermediatePool.AddCert(newIntermediate)
|
||||
intermediatePool.AddCert(xc)
|
||||
|
||||
for _, root := range []*x509.Certificate{oldRoot, newRoot} {
|
||||
rootPool := x509.NewCertPool()
|
||||
rootPool.AddCert(root)
|
||||
|
||||
_, err = cert.Verify(x509.VerifyOptions{
|
||||
Intermediates: intermediatePool,
|
||||
Roots: rootPool,
|
||||
})
|
||||
require.NoError(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,322 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const VaultCALeafCertRole = "leaf-cert"
|
||||
|
||||
var ErrBackendNotMounted = fmt.Errorf("backend not mounted")
|
||||
var ErrBackendNotInitialized = fmt.Errorf("backend not initialized")
|
||||
|
||||
type VaultProvider struct {
|
||||
config *structs.VaultCAProviderConfig
|
||||
client *vaultapi.Client
|
||||
clusterId string
|
||||
}
|
||||
|
||||
// NewVaultProvider returns a vault provider with its root and intermediate PKI
|
||||
// backends mounted and initialized. If the root backend is not set up already,
|
||||
// it will be mounted/generated as needed, but any existing state will not be
|
||||
// overwritten.
|
||||
func NewVaultProvider(rawConfig map[string]interface{}, clusterId string) (*VaultProvider, error) {
|
||||
conf, err := ParseVaultCAConfig(rawConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// todo(kyhavlov): figure out the right way to pass the TLS config
|
||||
clientConf := &vaultapi.Config{
|
||||
Address: conf.Address,
|
||||
}
|
||||
client, err := vaultapi.NewClient(clientConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.SetToken(conf.Token)
|
||||
|
||||
provider := &VaultProvider{
|
||||
config: conf,
|
||||
client: client,
|
||||
clusterId: clusterId,
|
||||
}
|
||||
|
||||
// Set up the root PKI backend if necessary.
|
||||
_, err = provider.ActiveRoot()
|
||||
switch err {
|
||||
case ErrBackendNotMounted:
|
||||
err := client.Sys().Mount(conf.RootPKIPath, &vaultapi.MountInput{
|
||||
Type: "pki",
|
||||
Description: "root CA backend for Consul Connect",
|
||||
Config: vaultapi.MountConfigInput{
|
||||
MaxLeaseTTL: "8760h",
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fallthrough
|
||||
case ErrBackendNotInitialized:
|
||||
spiffeID := connect.SpiffeIDSigning{ClusterID: clusterId, Domain: "consul"}
|
||||
uuid, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = client.Logical().Write(conf.RootPKIPath+"root/generate/internal", map[string]interface{}{
|
||||
"common_name": fmt.Sprintf("Vault CA Root Authority %s", uuid),
|
||||
"uri_sans": spiffeID.URI().String(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Set up the intermediate backend.
|
||||
if _, err := provider.GenerateIntermediate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (v *VaultProvider) ActiveRoot() (string, error) {
|
||||
return v.getCA(v.config.RootPKIPath)
|
||||
}
|
||||
|
||||
func (v *VaultProvider) ActiveIntermediate() (string, error) {
|
||||
return v.getCA(v.config.IntermediatePKIPath)
|
||||
}
|
||||
|
||||
// getCA returns the raw CA cert for the given endpoint if there is one.
|
||||
// We have to use the raw NewRequest call here instead of Logical().Read
|
||||
// because the endpoint only returns the raw PEM contents of the CA cert
|
||||
// and not the typical format of the secrets endpoints.
|
||||
func (v *VaultProvider) getCA(path string) (string, error) {
|
||||
req := v.client.NewRequest("GET", "/v1/"+path+"/ca/pem")
|
||||
resp, err := v.client.RawRequest(req)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if resp != nil && resp.StatusCode == http.StatusNotFound {
|
||||
return "", ErrBackendNotMounted
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bytes, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
root := string(bytes)
|
||||
if root == "" {
|
||||
return "", ErrBackendNotInitialized
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// GenerateIntermediate mounts the configured intermediate PKI backend if
|
||||
// necessary, then generates and signs a new CA CSR using the root PKI backend
|
||||
// and updates the intermediate backend to use that new certificate.
|
||||
func (v *VaultProvider) GenerateIntermediate() (string, error) {
|
||||
mounts, err := v.client.Sys().ListMounts()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Mount the backend if it isn't mounted already.
|
||||
if _, ok := mounts[v.config.IntermediatePKIPath]; !ok {
|
||||
err := v.client.Sys().Mount(v.config.IntermediatePKIPath, &vaultapi.MountInput{
|
||||
Type: "pki",
|
||||
Description: "intermediate CA backend for Consul Connect",
|
||||
Config: vaultapi.MountConfigInput{
|
||||
MaxLeaseTTL: "2160h",
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Create the role for issuing leaf certs if it doesn't exist yet
|
||||
rolePath := v.config.IntermediatePKIPath + "roles/" + VaultCALeafCertRole
|
||||
role, err := v.client.Logical().Read(rolePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
spiffeID := connect.SpiffeIDSigning{ClusterID: v.clusterId, Domain: "consul"}
|
||||
if role == nil {
|
||||
_, err := v.client.Logical().Write(rolePath, map[string]interface{}{
|
||||
"allow_any_name": true,
|
||||
"allowed_uri_sans": "spiffe://*",
|
||||
"key_type": "any",
|
||||
"max_ttl": "72h",
|
||||
"require_cn": false,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new intermediate CSR for the root to sign.
|
||||
csr, err := v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/generate/internal", map[string]interface{}{
|
||||
"common_name": "Vault CA Intermediate Authority",
|
||||
"uri_sans": spiffeID.URI().String(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if csr == nil || csr.Data["csr"] == "" {
|
||||
return "", fmt.Errorf("got empty value when generating intermediate CSR")
|
||||
}
|
||||
|
||||
// Sign the CSR with the root backend.
|
||||
intermediate, err := v.client.Logical().Write(v.config.RootPKIPath+"root/sign-intermediate", map[string]interface{}{
|
||||
"csr": csr.Data["csr"],
|
||||
"format": "pem_bundle",
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if intermediate == nil || intermediate.Data["certificate"] == "" {
|
||||
return "", fmt.Errorf("got empty value when generating intermediate certificate")
|
||||
}
|
||||
|
||||
// Set the intermediate backend to use the new certificate.
|
||||
_, err = v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/set-signed", map[string]interface{}{
|
||||
"certificate": intermediate.Data["certificate"],
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return v.ActiveIntermediate()
|
||||
}
|
||||
|
||||
// Sign calls the configured role in the intermediate PKI backend to issue
|
||||
// a new leaf certificate based on the provided CSR, with the issuing
|
||||
// intermediate CA cert attached.
|
||||
func (v *VaultProvider) Sign(csr *x509.CertificateRequest) (string, error) {
|
||||
var pemBuf bytes.Buffer
|
||||
if err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Use the leaf cert role to sign a new cert for this CSR.
|
||||
response, err := v.client.Logical().Write(v.config.IntermediatePKIPath+"sign/"+VaultCALeafCertRole, map[string]interface{}{
|
||||
"csr": pemBuf.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error issuing cert: %v", err)
|
||||
}
|
||||
if response == nil || response.Data["certificate"] == "" || response.Data["issuing_ca"] == "" {
|
||||
return "", fmt.Errorf("certificate info returned from Vault was blank")
|
||||
}
|
||||
|
||||
cert, ok := response.Data["certificate"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("certificate was not a string")
|
||||
}
|
||||
ca, ok := response.Data["issuing_ca"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("issuing_ca was not a string")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s\n%s", cert, ca), nil
|
||||
}
|
||||
|
||||
// CrossSignCA takes a CA certificate and cross-signs it to form a trust chain
|
||||
// back to our active root.
|
||||
func (v *VaultProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
|
||||
var pemBuf bytes.Buffer
|
||||
err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Have the root PKI backend sign this cert.
|
||||
response, err := v.client.Logical().Write(v.config.RootPKIPath+"root/sign-self-issued", map[string]interface{}{
|
||||
"certificate": pemBuf.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error having Vault cross-sign cert: %v", err)
|
||||
}
|
||||
if response == nil || response.Data["certificate"] == "" {
|
||||
return "", fmt.Errorf("certificate info returned from Vault was blank")
|
||||
}
|
||||
|
||||
xcCert, ok := response.Data["certificate"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("certificate was not a string")
|
||||
}
|
||||
|
||||
return xcCert, nil
|
||||
}
|
||||
|
||||
// Cleanup unmounts the configured intermediate PKI backend. It's fine to tear
|
||||
// this down and recreate it on small config changes because the intermediate
|
||||
// certs get bundled with the leaf certs, so there's no cost to the CA changing.
|
||||
func (v *VaultProvider) Cleanup() error {
|
||||
return v.client.Sys().Unmount(v.config.IntermediatePKIPath)
|
||||
}
|
||||
|
||||
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
|
||||
var config structs.VaultCAProviderConfig
|
||||
|
||||
decodeConf := &mapstructure.DecoderConfig{
|
||||
ErrorUnused: true,
|
||||
Result: &config,
|
||||
WeaklyTypedInput: true,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := decoder.Decode(raw); err != nil {
|
||||
return nil, fmt.Errorf("error decoding config: %s", err)
|
||||
}
|
||||
|
||||
if config.Token == "" {
|
||||
return nil, fmt.Errorf("must provide a Vault token")
|
||||
}
|
||||
|
||||
if config.RootPKIPath == "" {
|
||||
return nil, fmt.Errorf("must provide a valid path to a root PKI backend")
|
||||
}
|
||||
if !strings.HasSuffix(config.RootPKIPath, "/") {
|
||||
config.RootPKIPath += "/"
|
||||
}
|
||||
|
||||
if config.IntermediatePKIPath == "" {
|
||||
return nil, fmt.Errorf("must provide a valid path for the intermediate PKI backend")
|
||||
}
|
||||
if !strings.HasSuffix(config.IntermediatePKIPath, "/") {
|
||||
config.IntermediatePKIPath += "/"
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/pki"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testVaultCluster(t *testing.T) (*VaultProvider, *vault.Core, net.Listener) {
|
||||
if err := vault.AddTestLogicalBackend("pki", pki.Factory); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
core, _, token := vault.TestCoreUnsealedRaw(t)
|
||||
|
||||
ln, addr := vaulthttp.TestServer(t, core)
|
||||
|
||||
provider, err := NewVaultProvider(map[string]interface{}{
|
||||
"Address": addr,
|
||||
"Token": token,
|
||||
"RootPKIPath": "pki-root/",
|
||||
"IntermediatePKIPath": "pki-intermediate/",
|
||||
}, "asdf")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return provider, core, ln
|
||||
}
|
||||
|
||||
func TestVaultCAProvider_Bootstrap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
provider, core, listener := testVaultCluster(t)
|
||||
defer core.Shutdown()
|
||||
defer listener.Close()
|
||||
client, err := vaultapi.NewClient(&vaultapi.Config{
|
||||
Address: "http://" + listener.Addr().String(),
|
||||
})
|
||||
require.NoError(err)
|
||||
client.SetToken(provider.config.Token)
|
||||
|
||||
cases := []struct {
|
||||
certFunc func() (string, error)
|
||||
backendPath string
|
||||
}{
|
||||
{
|
||||
certFunc: provider.ActiveRoot,
|
||||
backendPath: "pki-root/",
|
||||
},
|
||||
{
|
||||
certFunc: provider.ActiveIntermediate,
|
||||
backendPath: "pki-intermediate/",
|
||||
},
|
||||
}
|
||||
|
||||
// Verify the root and intermediate certs match the ones in the vault backends
|
||||
for _, tc := range cases {
|
||||
cert, err := tc.certFunc()
|
||||
require.NoError(err)
|
||||
req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem")
|
||||
resp, err := client.RawRequest(req)
|
||||
require.NoError(err)
|
||||
bytes, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(err)
|
||||
require.Equal(cert, string(bytes))
|
||||
|
||||
// Should be a valid CA cert
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
require.NoError(err)
|
||||
require.True(parsed.IsCA)
|
||||
require.Len(parsed.URIs, 1)
|
||||
require.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", provider.clusterId))
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultCAProvider_SignLeaf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
provider, core, listener := testVaultCluster(t)
|
||||
defer core.Shutdown()
|
||||
defer listener.Close()
|
||||
client, err := vaultapi.NewClient(&vaultapi.Config{
|
||||
Address: "http://" + listener.Addr().String(),
|
||||
})
|
||||
require.NoError(err)
|
||||
client.SetToken(provider.config.Token)
|
||||
|
||||
spiffeService := &connect.SpiffeIDService{
|
||||
Host: "node1",
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: "foo",
|
||||
}
|
||||
|
||||
// Generate a leaf cert for the service.
|
||||
var firstSerial uint64
|
||||
{
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
csr, err := connect.ParseCSR(raw)
|
||||
require.NoError(err)
|
||||
|
||||
cert, err := provider.Sign(csr)
|
||||
require.NoError(err)
|
||||
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
require.NoError(err)
|
||||
require.Equal(parsed.URIs[0], spiffeService.URI())
|
||||
firstSerial = parsed.SerialNumber.Uint64()
|
||||
|
||||
// Ensure the cert is valid now and expires within the correct limit.
|
||||
require.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
|
||||
require.True(parsed.NotBefore.Before(time.Now()))
|
||||
}
|
||||
|
||||
// Generate a new cert for another service and make sure
|
||||
// the serial number is unique.
|
||||
spiffeService.Service = "bar"
|
||||
{
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
csr, err := connect.ParseCSR(raw)
|
||||
require.NoError(err)
|
||||
|
||||
cert, err := provider.Sign(csr)
|
||||
require.NoError(err)
|
||||
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
require.NoError(err)
|
||||
require.Equal(parsed.URIs[0], spiffeService.URI())
|
||||
require.NotEqual(firstSerial, parsed.SerialNumber.Uint64())
|
||||
|
||||
// Ensure the cert is valid now and expires within the correct limit.
|
||||
require.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
|
||||
require.True(parsed.NotBefore.Before(time.Now()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultCAProvider_CrossSignCA(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider1, core1, listener1 := testVaultCluster(t)
|
||||
defer core1.Shutdown()
|
||||
defer listener1.Close()
|
||||
|
||||
provider2, core2, listener2 := testVaultCluster(t)
|
||||
defer core2.Shutdown()
|
||||
defer listener2.Close()
|
||||
|
||||
testCrossSignProviders(t, provider1, provider2)
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// CreateCSR returns a CSR to sign the given service along with the PEM-encoded
|
||||
// private key for this certificate.
|
||||
func CreateCSR(uri CertURI, privateKey crypto.Signer) (string, error) {
|
||||
template := &x509.CertificateRequest{
|
||||
URIs: []*url.URL{uri.URI()},
|
||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
}
|
||||
|
||||
// Create the CSR itself
|
||||
var csrBuf bytes.Buffer
|
||||
bs, err := x509.CreateCertificateRequest(rand.Reader, template, privateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = pem.Encode(&csrBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return csrBuf.String(), nil
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GeneratePrivateKey generates a new Private key
|
||||
func GeneratePrivateKey() (crypto.Signer, string, error) {
|
||||
var pk *ecdsa.PrivateKey
|
||||
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error generating private key: %s", err)
|
||||
}
|
||||
|
||||
bs, err := x509.MarshalECPrivateKey(pk)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error generating private key: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: bs})
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return pk, buf.String(), nil
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseCert parses the x509 certificate from a PEM-encoded value.
|
||||
func ParseCert(pemValue string) (*x509.Certificate, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE type")
|
||||
}
|
||||
|
||||
return x509.ParseCertificate(block.Bytes)
|
||||
}
|
||||
|
||||
// CalculateCertFingerprint parses the x509 certificate from a PEM-encoded value
|
||||
// and calculates the SHA-1 fingerprint.
|
||||
func CalculateCertFingerprint(pemValue string) (string, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return "", fmt.Errorf("first PEM-block should be CERTIFICATE type")
|
||||
}
|
||||
|
||||
hash := sha1.Sum(block.Bytes)
|
||||
return HexString(hash[:]), nil
|
||||
}
|
||||
|
||||
// ParseSigner parses a crypto.Signer from a PEM-encoded key. The private key
|
||||
// is expected to be the first block in the PEM value.
|
||||
func ParseSigner(pemValue string) (crypto.Signer, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
|
||||
case "PRIVATE KEY":
|
||||
signer, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pk, ok := signer.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("private key is not a valid format")
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown PEM block type for signing key: %s", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseCSR parses a CSR from a PEM-encoded value. The certificate request
|
||||
// must be the the first block in the PEM value.
|
||||
func ParseCSR(pemValue string) (*x509.CertificateRequest, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
if block.Type != "CERTIFICATE REQUEST" {
|
||||
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE REQUEST type")
|
||||
}
|
||||
|
||||
return x509.ParseCertificateRequest(block.Bytes)
|
||||
}
|
||||
|
||||
// KeyId returns a x509 KeyId from the given signing key. The key must be
|
||||
// an *ecdsa.PublicKey currently, but may support more types in the future.
|
||||
func KeyId(raw interface{}) ([]byte, error) {
|
||||
switch raw.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid key type: %T", raw)
|
||||
}
|
||||
|
||||
// This is not standard; RFC allows any unique identifier as long as they
|
||||
// match in subject/authority chains but suggests specific hashing of DER
|
||||
// bytes of public key including DER tags.
|
||||
bs, err := x509.MarshalPKIXPublicKey(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// String formatted
|
||||
kID := sha256.Sum256(bs)
|
||||
return []byte(strings.Replace(fmt.Sprintf("% x", kID), " ", ":", -1)), nil
|
||||
}
|
||||
|
||||
// HexString returns a standard colon-separated hex value for the input
|
||||
// byte slice. This should be used with cert serial numbers and so on.
|
||||
func HexString(input []byte) string {
|
||||
return strings.Replace(fmt.Sprintf("% x", input), " ", ":", -1)
|
||||
}
|
|
@ -0,0 +1,332 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
)
|
||||
|
||||
// TestClusterID is the Consul cluster ID for testing.
|
||||
const TestClusterID = "11111111-2222-3333-4444-555555555555"
|
||||
|
||||
// testCACounter is just an atomically incremented counter for creating
|
||||
// unique names for the CA certs.
|
||||
var testCACounter uint64
|
||||
|
||||
// TestCA creates a test CA certificate and signing key and returns it
|
||||
// in the CARoot structure format. The returned CA will be set as Active = true.
|
||||
//
|
||||
// If xc is non-nil, then the returned certificate will have a signing cert
|
||||
// that is cross-signed with the previous cert, and this will be set as
|
||||
// SigningCert.
|
||||
func TestCA(t testing.T, xc *structs.CARoot) *structs.CARoot {
|
||||
var result structs.CARoot
|
||||
result.Active = true
|
||||
result.Name = fmt.Sprintf("Test CA %d", atomic.AddUint64(&testCACounter, 1))
|
||||
|
||||
// Create the private key we'll use for this CA cert.
|
||||
signer, keyPEM := testPrivateKey(t)
|
||||
result.SigningKey = keyPEM
|
||||
|
||||
// The serial number for the cert
|
||||
sn, err := testSerialNumber()
|
||||
if err != nil {
|
||||
t.Fatalf("error generating serial number: %s", err)
|
||||
}
|
||||
|
||||
// The URI (SPIFFE compatible) for the cert
|
||||
id := &SpiffeIDSigning{ClusterID: TestClusterID, Domain: "consul"}
|
||||
|
||||
// Create the CA cert
|
||||
template := x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{CommonName: result.Name},
|
||||
URIs: []*url.URL{id.URI()},
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageCertSign |
|
||||
x509.KeyUsageCRLSign |
|
||||
x509.KeyUsageDigitalSignature,
|
||||
IsCA: true,
|
||||
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
|
||||
NotBefore: time.Now(),
|
||||
AuthorityKeyId: testKeyID(t, signer.Public()),
|
||||
SubjectKeyId: testKeyID(t, signer.Public()),
|
||||
}
|
||||
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, &template, signer.Public(), signer)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding private key: %s", err)
|
||||
}
|
||||
result.RootCert = buf.String()
|
||||
result.ID, err = CalculateCertFingerprint(result.RootCert)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating CA ID fingerprint: %s", err)
|
||||
}
|
||||
|
||||
// If there is a prior CA to cross-sign with, then we need to create that
|
||||
// and set it as the signing cert.
|
||||
if xc != nil {
|
||||
xccert, err := ParseCert(xc.RootCert)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing CA cert: %s", err)
|
||||
}
|
||||
xcsigner, err := ParseSigner(xc.SigningKey)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing signing key: %s", err)
|
||||
}
|
||||
|
||||
// Set the authority key to be the previous one.
|
||||
// NOTE(mitchellh): From Paul Banks: if we have to cross-sign a cert
|
||||
// that came from outside (e.g. vault) we can't rely on them using the
|
||||
// same KeyID hashing algo we do so we'd need to actually copy this
|
||||
// from the xc cert's subjectKeyIdentifier extension.
|
||||
template.AuthorityKeyId = testKeyID(t, xcsigner.Public())
|
||||
|
||||
// Create the new certificate where the parent is the previous
|
||||
// CA, the public key is the new public key, and the signing private
|
||||
// key is the old private key.
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, xccert, signer.Public(), xcsigner)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding private key: %s", err)
|
||||
}
|
||||
result.SigningCert = buf.String()
|
||||
}
|
||||
|
||||
return &result
|
||||
}
|
||||
|
||||
// TestLeaf returns a valid leaf certificate and it's private key for the named
|
||||
// service with the given CA Root.
|
||||
func TestLeaf(t testing.T, service string, root *structs.CARoot) (string, string) {
|
||||
// Parse the CA cert and signing key from the root
|
||||
cert := root.SigningCert
|
||||
if cert == "" {
|
||||
cert = root.RootCert
|
||||
}
|
||||
caCert, err := ParseCert(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing CA cert: %s", err)
|
||||
}
|
||||
caSigner, err := ParseSigner(root.SigningKey)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing signing key: %s", err)
|
||||
}
|
||||
|
||||
// Build the SPIFFE ID
|
||||
spiffeId := &SpiffeIDService{
|
||||
Host: fmt.Sprintf("%s.consul", TestClusterID),
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: service,
|
||||
}
|
||||
|
||||
// The serial number for the cert
|
||||
sn, err := testSerialNumber()
|
||||
if err != nil {
|
||||
t.Fatalf("error generating serial number: %s", err)
|
||||
}
|
||||
|
||||
// Generate fresh private key
|
||||
pkSigner, pkPEM, err := GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key: %s", err)
|
||||
}
|
||||
|
||||
// Cert template for generation
|
||||
template := x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{CommonName: service},
|
||||
URIs: []*url.URL{spiffeId.URI()},
|
||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDataEncipherment |
|
||||
x509.KeyUsageKeyAgreement |
|
||||
x509.KeyUsageDigitalSignature |
|
||||
x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
|
||||
NotBefore: time.Now(),
|
||||
AuthorityKeyId: testKeyID(t, caSigner.Public()),
|
||||
SubjectKeyId: testKeyID(t, pkSigner.Public()),
|
||||
}
|
||||
|
||||
// Create the certificate, PEM encode it and return that value.
|
||||
var buf bytes.Buffer
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, caCert, pkSigner.Public(), caSigner)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating certificate: %s", err)
|
||||
}
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return buf.String(), pkPEM
|
||||
}
|
||||
|
||||
// TestCSR returns a CSR to sign the given service along with the PEM-encoded
|
||||
// private key for this certificate.
|
||||
func TestCSR(t testing.T, uri CertURI) (string, string) {
|
||||
template := &x509.CertificateRequest{
|
||||
URIs: []*url.URL{uri.URI()},
|
||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
}
|
||||
|
||||
// Create the private key we'll use
|
||||
signer, pkPEM := testPrivateKey(t)
|
||||
|
||||
// Create the CSR itself
|
||||
var csrBuf bytes.Buffer
|
||||
bs, err := x509.CreateCertificateRequest(rand.Reader, template, signer)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating CSR: %s", err)
|
||||
}
|
||||
|
||||
err = pem.Encode(&csrBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding CSR: %s", err)
|
||||
}
|
||||
|
||||
return csrBuf.String(), pkPEM
|
||||
}
|
||||
|
||||
// testKeyID returns a KeyID from the given public key. This just calls
|
||||
// KeyId but handles errors for tests.
|
||||
func testKeyID(t testing.T, raw interface{}) []byte {
|
||||
result, err := KeyId(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("KeyId error: %s", err)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// testPrivateKey creates an ECDSA based private key. Both a crypto.Signer and
|
||||
// the key in PEM form are returned.
|
||||
//
|
||||
// NOTE(banks): this was memoized to save entropy during tests but it turns out
|
||||
// crypto/rand will never block and always reads from /dev/urandom on unix OSes
|
||||
// which does not consume entropy.
|
||||
//
|
||||
// If we find by profiling it's taking a lot of cycles we could optimise/cache
|
||||
// again but we at least need to use different keys for each distinct CA (when
|
||||
// multiple CAs are generated at once e.g. to test cross-signing) and a
|
||||
// different one again for the leafs otherwise we risk tests that have false
|
||||
// positives since signatures from different logical cert's keys are
|
||||
// indistinguishable, but worse we build validation chains using AuthorityKeyID
|
||||
// which will be the same for multiple CAs/Leafs. Also note that our UUID
|
||||
// generator also reads from crypto rand and is called far more often during
|
||||
// tests than this will be.
|
||||
func testPrivateKey(t testing.T) (crypto.Signer, string) {
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating private key: %s", err)
|
||||
}
|
||||
|
||||
bs, err := x509.MarshalECPrivateKey(pk)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating private key: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return pk, buf.String()
|
||||
}
|
||||
|
||||
// testSerialNumber generates a serial number suitable for a certificate.
|
||||
// For testing, this just sets it to a random number.
|
||||
//
|
||||
// This function is taken directly from the Vault implementation.
|
||||
func testSerialNumber() (*big.Int, error) {
|
||||
return rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil))
|
||||
}
|
||||
|
||||
// testUUID generates a UUID for testing.
|
||||
func testUUID(t testing.T) string {
|
||||
ret, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate a UUID, %s", err)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// TestAgentRPC is an interface that an RPC client must implement. This is a
|
||||
// helper interface that is implemented by the agent delegate so that test
|
||||
// helpers can make RPCs without introducing an import cycle on `agent`.
|
||||
type TestAgentRPC interface {
|
||||
RPC(method string, args interface{}, reply interface{}) error
|
||||
}
|
||||
|
||||
// TestCAConfigSet sets a CARoot returned by TestCA into the TestAgent state. It
|
||||
// requires that TestAgent had connect enabled in it's config. If ca is nil, a
|
||||
// new CA is created.
|
||||
//
|
||||
// It returns the CARoot passed or created.
|
||||
//
|
||||
// Note that we have to use an interface for the TestAgent.RPC method since we
|
||||
// can't introduce an import cycle by importing `agent.TestAgent` here directly.
|
||||
// It also means this will work in a few other places we mock that method.
|
||||
func TestCAConfigSet(t testing.T, a TestAgentRPC,
|
||||
ca *structs.CARoot) *structs.CARoot {
|
||||
t.Helper()
|
||||
|
||||
if ca == nil {
|
||||
ca = TestCA(t, nil)
|
||||
}
|
||||
newConfig := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": ca.SigningKey,
|
||||
"RootCert": ca.RootCert,
|
||||
"RotationPeriod": 180 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
args := &structs.CARequest{
|
||||
Datacenter: "dc1",
|
||||
Config: newConfig,
|
||||
}
|
||||
var reply interface{}
|
||||
|
||||
err := a.RPC("ConnectCA.ConfigurationSet", args, &reply)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set test CA config: %s", err)
|
||||
}
|
||||
return ca
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// hasOpenSSL is used to determine if the openssl CLI exists for unit tests.
|
||||
var hasOpenSSL bool
|
||||
|
||||
func init() {
|
||||
_, err := exec.LookPath("openssl")
|
||||
hasOpenSSL = err == nil
|
||||
}
|
||||
|
||||
// Test that the TestCA and TestLeaf functions generate valid certificates.
|
||||
func TestTestCAAndLeaf(t *testing.T) {
|
||||
if !hasOpenSSL {
|
||||
t.Skip("openssl not found")
|
||||
return
|
||||
}
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
// Create the certs
|
||||
ca := TestCA(t, nil)
|
||||
leaf, _ := TestLeaf(t, "web", ca)
|
||||
|
||||
// Create a temporary directory for storing the certs
|
||||
td, err := ioutil.TempDir("", "consul")
|
||||
assert.Nil(err)
|
||||
defer os.RemoveAll(td)
|
||||
|
||||
// Write the cert
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), []byte(ca.RootCert), 0644))
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf.pem"), []byte(leaf), 0644))
|
||||
|
||||
// Use OpenSSL to verify so we have an external, known-working process
|
||||
// that can verify this outside of our own implementations.
|
||||
cmd := exec.Command(
|
||||
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf.pem")
|
||||
cmd.Dir = td
|
||||
output, err := cmd.Output()
|
||||
t.Log(string(output))
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
// Test cross-signing.
|
||||
func TestTestCAAndLeaf_xc(t *testing.T) {
|
||||
if !hasOpenSSL {
|
||||
t.Skip("openssl not found")
|
||||
return
|
||||
}
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
// Create the certs
|
||||
ca1 := TestCA(t, nil)
|
||||
ca2 := TestCA(t, ca1)
|
||||
leaf1, _ := TestLeaf(t, "web", ca1)
|
||||
leaf2, _ := TestLeaf(t, "web", ca2)
|
||||
|
||||
// Create a temporary directory for storing the certs
|
||||
td, err := ioutil.TempDir("", "consul")
|
||||
assert.Nil(err)
|
||||
defer os.RemoveAll(td)
|
||||
|
||||
// Write the cert
|
||||
xcbundle := []byte(ca1.RootCert)
|
||||
xcbundle = append(xcbundle, '\n')
|
||||
xcbundle = append(xcbundle, []byte(ca2.SigningCert)...)
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), xcbundle, 0644))
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf1.pem"), []byte(leaf1), 0644))
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf2.pem"), []byte(leaf2), 0644))
|
||||
|
||||
// OpenSSL verify the cross-signed leaf (leaf2)
|
||||
{
|
||||
cmd := exec.Command(
|
||||
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf2.pem")
|
||||
cmd.Dir = td
|
||||
output, err := cmd.Output()
|
||||
t.Log(string(output))
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
// OpenSSL verify the old leaf (leaf1)
|
||||
{
|
||||
cmd := exec.Command(
|
||||
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf1.pem")
|
||||
cmd.Dir = td
|
||||
output, err := cmd.Output()
|
||||
t.Log(string(output))
|
||||
assert.Nil(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
)
|
||||
|
||||
// TestSpiffeIDService returns a SPIFFE ID representing a service.
|
||||
func TestSpiffeIDService(t testing.T, service string) *SpiffeIDService {
|
||||
return TestSpiffeIDServiceWithHost(t, service, TestClusterID+".consul")
|
||||
}
|
||||
|
||||
// TestSpiffeIDServiceWithHost returns a SPIFFE ID representing a service with
|
||||
// the specified trust domain.
|
||||
func TestSpiffeIDServiceWithHost(t testing.T, service, host string) *SpiffeIDService {
|
||||
return &SpiffeIDService{
|
||||
Host: host,
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: service,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// CertURI represents a Connect-valid URI value for a TLS certificate.
|
||||
// The user should type switch on the various implementations in this
|
||||
// package to determine the type of URI and the data encoded within it.
|
||||
//
|
||||
// Note that the current implementations of this are all also SPIFFE IDs.
|
||||
// However, we anticipate that we may accept URIs that are also not SPIFFE
|
||||
// compliant and therefore the interface is named as such.
|
||||
type CertURI interface {
|
||||
// Authorize tests the authorization for this URI as a client
|
||||
// for the given intention. The return value `auth` is only valid if
|
||||
// the second value `match` is true. If the second value `match` is
|
||||
// false, then the intention doesn't match this client and any
|
||||
// result should be ignored.
|
||||
Authorize(*structs.Intention) (auth bool, match bool)
|
||||
|
||||
// URI is the valid URI value used in the cert.
|
||||
URI() *url.URL
|
||||
}
|
||||
|
||||
var (
|
||||
spiffeIDServiceRegexp = regexp.MustCompile(
|
||||
`^/ns/([^/]+)/dc/([^/]+)/svc/([^/]+)$`)
|
||||
)
|
||||
|
||||
// ParseCertURI parses a the URI value from a TLS certificate.
|
||||
func ParseCertURI(input *url.URL) (CertURI, error) {
|
||||
if input.Scheme != "spiffe" {
|
||||
return nil, fmt.Errorf("SPIFFE ID must have 'spiffe' scheme")
|
||||
}
|
||||
|
||||
// Path is the raw value of the path without url decoding values.
|
||||
// RawPath is empty if there were no encoded values so we must
|
||||
// check both.
|
||||
path := input.Path
|
||||
if input.RawPath != "" {
|
||||
path = input.RawPath
|
||||
}
|
||||
|
||||
// Test for service IDs
|
||||
if v := spiffeIDServiceRegexp.FindStringSubmatch(path); v != nil {
|
||||
// Determine the values. We assume they're sane to save cycles,
|
||||
// but if the raw path is not empty that means that something is
|
||||
// URL encoded so we go to the slow path.
|
||||
ns := v[1]
|
||||
dc := v[2]
|
||||
service := v[3]
|
||||
if input.RawPath != "" {
|
||||
var err error
|
||||
if ns, err = url.PathUnescape(v[1]); err != nil {
|
||||
return nil, fmt.Errorf("Invalid namespace: %s", err)
|
||||
}
|
||||
if dc, err = url.PathUnescape(v[2]); err != nil {
|
||||
return nil, fmt.Errorf("Invalid datacenter: %s", err)
|
||||
}
|
||||
if service, err = url.PathUnescape(v[3]); err != nil {
|
||||
return nil, fmt.Errorf("Invalid service: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &SpiffeIDService{
|
||||
Host: input.Host,
|
||||
Namespace: ns,
|
||||
Datacenter: dc,
|
||||
Service: service,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Test for signing ID
|
||||
if input.Path == "" {
|
||||
idx := strings.Index(input.Host, ".")
|
||||
if idx > 0 {
|
||||
return &SpiffeIDSigning{
|
||||
ClusterID: input.Host[:idx],
|
||||
Domain: input.Host[idx+1:],
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("SPIFFE ID is not in the expected format")
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// SpiffeIDService is the structure to represent the SPIFFE ID for a service.
|
||||
type SpiffeIDService struct {
|
||||
Host string
|
||||
Namespace string
|
||||
Datacenter string
|
||||
Service string
|
||||
}
|
||||
|
||||
// URI returns the *url.URL for this SPIFFE ID.
|
||||
func (id *SpiffeIDService) URI() *url.URL {
|
||||
var result url.URL
|
||||
result.Scheme = "spiffe"
|
||||
result.Host = id.Host
|
||||
result.Path = fmt.Sprintf("/ns/%s/dc/%s/svc/%s",
|
||||
id.Namespace, id.Datacenter, id.Service)
|
||||
return &result
|
||||
}
|
||||
|
||||
// CertURI impl.
|
||||
func (id *SpiffeIDService) Authorize(ixn *structs.Intention) (bool, bool) {
|
||||
if ixn.SourceNS != structs.IntentionWildcard && ixn.SourceNS != id.Namespace {
|
||||
// Non-matching namespace
|
||||
return false, false
|
||||
}
|
||||
|
||||
if ixn.SourceName != structs.IntentionWildcard && ixn.SourceName != id.Service {
|
||||
// Non-matching name
|
||||
return false, false
|
||||
}
|
||||
|
||||
// Match, return allow value
|
||||
return ixn.Action == structs.IntentionActionAllow, true
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSpiffeIDServiceAuthorize(t *testing.T) {
|
||||
ns := structs.IntentionDefaultNamespace
|
||||
serviceWeb := &SpiffeIDService{
|
||||
Host: "1234.consul",
|
||||
Namespace: structs.IntentionDefaultNamespace,
|
||||
Datacenter: "dc01",
|
||||
Service: "web",
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
URI *SpiffeIDService
|
||||
Ixn *structs.Intention
|
||||
Auth bool
|
||||
Match bool
|
||||
}{
|
||||
{
|
||||
"exact source, not matching namespace",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: "different",
|
||||
SourceName: "db",
|
||||
},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"exact source, not matching name",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: ns,
|
||||
SourceName: "db",
|
||||
},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"exact source, allow",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: serviceWeb.Namespace,
|
||||
SourceName: serviceWeb.Service,
|
||||
Action: structs.IntentionActionAllow,
|
||||
},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"exact source, deny",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: serviceWeb.Namespace,
|
||||
SourceName: serviceWeb.Service,
|
||||
Action: structs.IntentionActionDeny,
|
||||
},
|
||||
false,
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"exact namespace, wildcard service, deny",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: serviceWeb.Namespace,
|
||||
SourceName: structs.IntentionWildcard,
|
||||
Action: structs.IntentionActionDeny,
|
||||
},
|
||||
false,
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"exact namespace, wildcard service, allow",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: serviceWeb.Namespace,
|
||||
SourceName: structs.IntentionWildcard,
|
||||
Action: structs.IntentionActionAllow,
|
||||
},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
auth, match := tc.URI.Authorize(tc.Ixn)
|
||||
assert.Equal(t, tc.Auth, auth)
|
||||
assert.Equal(t, tc.Match, match)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// SpiffeIDSigning is the structure to represent the SPIFFE ID for a
|
||||
// signing certificate (not a leaf service).
|
||||
type SpiffeIDSigning struct {
|
||||
ClusterID string // Unique cluster ID
|
||||
Domain string // The domain, usually "consul"
|
||||
}
|
||||
|
||||
// URI returns the *url.URL for this SPIFFE ID.
|
||||
func (id *SpiffeIDSigning) URI() *url.URL {
|
||||
var result url.URL
|
||||
result.Scheme = "spiffe"
|
||||
result.Host = id.Host()
|
||||
return &result
|
||||
}
|
||||
|
||||
// Host is the canonical representation as a DNS-compatible hostname.
|
||||
func (id *SpiffeIDSigning) Host() string {
|
||||
return strings.ToLower(fmt.Sprintf("%s.%s", id.ClusterID, id.Domain))
|
||||
}
|
||||
|
||||
// CertURI impl.
|
||||
func (id *SpiffeIDSigning) Authorize(ixn *structs.Intention) (bool, bool) {
|
||||
// Never authorize as a client.
|
||||
return false, true
|
||||
}
|
||||
|
||||
// CanSign takes any CertURI and returns whether or not this signing entity is
|
||||
// allowed to sign CSRs for that entity (i.e. represents the trust domain for
|
||||
// that entity).
|
||||
//
|
||||
// I choose to make this a fixed centralised method here for now rather than a
|
||||
// method on CertURI interface since we don't intend this to be extensible
|
||||
// outside and it's easier to reason about the security properties when they are
|
||||
// all in one place with "whitelist" semantics.
|
||||
func (id *SpiffeIDSigning) CanSign(cu CertURI) bool {
|
||||
switch other := cu.(type) {
|
||||
case *SpiffeIDSigning:
|
||||
// We can only sign other CA certificates for the same trust domain. Note
|
||||
// that we could open this up later for example to support external
|
||||
// federation of roots and cross-signing external roots that have different
|
||||
// URI structure but it's simpler to start off restrictive.
|
||||
return id == other
|
||||
case *SpiffeIDService:
|
||||
// The host component of the service must be an exact match for now under
|
||||
// ascii case folding (since hostnames are case-insensitive). Later we might
|
||||
// worry about Unicode domains if we start allowing customisation beyond the
|
||||
// built-in cluster ids.
|
||||
return strings.ToLower(other.Host) == id.Host()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SpiffeIDSigningForCluster returns the SPIFFE signing identifier (trust
|
||||
// domain) representation of the given CA config. If config is nil this function
|
||||
// will panic.
|
||||
//
|
||||
// NOTE(banks): we intentionally fix the tld `.consul` for now rather than tie
|
||||
// this to the `domain` config used for DNS because changing DNS domain can't
|
||||
// break all certificate validation. That does mean that DNS prefix might not
|
||||
// match the identity URIs and so the trust domain might not actually resolve
|
||||
// which we would like but don't actually need.
|
||||
func SpiffeIDSigningForCluster(config *structs.CAConfiguration) *SpiffeIDSigning {
|
||||
return &SpiffeIDSigning{ClusterID: config.ClusterID, Domain: "consul"}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Signing ID should never authorize
|
||||
func TestSpiffeIDSigningAuthorize(t *testing.T) {
|
||||
var id SpiffeIDSigning
|
||||
auth, ok := id.Authorize(nil)
|
||||
assert.False(t, auth)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestSpiffeIDSigningForCluster(t *testing.T) {
|
||||
// For now it should just append .consul to the ID.
|
||||
config := &structs.CAConfiguration{
|
||||
ClusterID: TestClusterID,
|
||||
}
|
||||
id := SpiffeIDSigningForCluster(config)
|
||||
assert.Equal(t, id.URI().String(), "spiffe://"+TestClusterID+".consul")
|
||||
}
|
||||
|
||||
// fakeCertURI is a CertURI implementation that our implementation doesn't know
|
||||
// about
|
||||
type fakeCertURI string
|
||||
|
||||
func (f fakeCertURI) Authorize(*structs.Intention) (auth bool, match bool) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (f fakeCertURI) URI() *url.URL {
|
||||
u, _ := url.Parse(string(f))
|
||||
return u
|
||||
}
|
||||
func TestSpiffeIDSigning_CanSign(t *testing.T) {
|
||||
|
||||
testSigning := &SpiffeIDSigning{
|
||||
ClusterID: TestClusterID,
|
||||
Domain: "consul",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id *SpiffeIDSigning
|
||||
input CertURI
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "same signing ID",
|
||||
id: testSigning,
|
||||
input: testSigning,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other signing ID",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDSigning{
|
||||
ClusterID: "fakedomain",
|
||||
Domain: "consul",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "different TLD signing ID",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDSigning{
|
||||
ClusterID: TestClusterID,
|
||||
Domain: "evil",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
id: testSigning,
|
||||
input: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unrecognised CertURI implementation",
|
||||
id: testSigning,
|
||||
input: fakeCertURI("spiffe://foo.bar/baz"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "service - good",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDService{TestClusterID + ".consul", "default", "dc1", "web"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "service - good midex case",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDService{strings.ToUpper(TestClusterID) + ".CONsuL", "defAUlt", "dc1", "WEB"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "service - different cluster",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDService{"55555555-4444-3333-2222-111111111111.consul", "default", "dc1", "web"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "service - different TLD",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDService{TestClusterID + ".fake", "default", "dc1", "web"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.id.CanSign(tt.input)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// testCertURICases contains the test cases for parsing and encoding
|
||||
// the SPIFFE IDs. This is a global since it is used in multiple test functions.
|
||||
var testCertURICases = []struct {
|
||||
Name string
|
||||
URI string
|
||||
Struct interface{}
|
||||
ParseError string
|
||||
}{
|
||||
{
|
||||
"invalid scheme",
|
||||
"http://google.com/",
|
||||
nil,
|
||||
"scheme",
|
||||
},
|
||||
|
||||
{
|
||||
"basic service ID",
|
||||
"spiffe://1234.consul/ns/default/dc/dc01/svc/web",
|
||||
&SpiffeIDService{
|
||||
Host: "1234.consul",
|
||||
Namespace: "default",
|
||||
Datacenter: "dc01",
|
||||
Service: "web",
|
||||
},
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"service with URL-encoded values",
|
||||
"spiffe://1234.consul/ns/foo%2Fbar/dc/bar%2Fbaz/svc/baz%2Fqux",
|
||||
&SpiffeIDService{
|
||||
Host: "1234.consul",
|
||||
Namespace: "foo/bar",
|
||||
Datacenter: "bar/baz",
|
||||
Service: "baz/qux",
|
||||
},
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"signing ID",
|
||||
"spiffe://1234.consul",
|
||||
&SpiffeIDSigning{
|
||||
ClusterID: "1234",
|
||||
Domain: "consul",
|
||||
},
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
func TestParseCertURI(t *testing.T) {
|
||||
for _, tc := range testCertURICases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Parse the URI, should always be valid
|
||||
uri, err := url.Parse(tc.URI)
|
||||
assert.Nil(err)
|
||||
|
||||
// Parse the ID and check the error/return value
|
||||
actual, err := ParseCertURI(uri)
|
||||
if err != nil {
|
||||
t.Logf("parse error: %s", err.Error())
|
||||
}
|
||||
assert.Equal(tc.ParseError != "", err != nil, "error value")
|
||||
if err != nil {
|
||||
assert.Contains(err.Error(), tc.ParseError)
|
||||
return
|
||||
}
|
||||
assert.Equal(tc.Struct, actual)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// GET /v1/connect/ca/roots
|
||||
func (s *HTTPServer) ConnectCARoots(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
var args structs.DCSpecificRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply structs.IndexedCARoots
|
||||
defer setMeta(resp, &reply.QueryMeta)
|
||||
if err := s.agent.RPC("ConnectCA.Roots", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// /v1/connect/ca/configuration
|
||||
func (s *HTTPServer) ConnectCAConfiguration(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
return s.ConnectCAConfigurationGet(resp, req)
|
||||
|
||||
case "PUT":
|
||||
return s.ConnectCAConfigurationSet(resp, req)
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "POST"}}
|
||||
}
|
||||
}
|
||||
|
||||
// GEt /v1/connect/ca/configuration
|
||||
func (s *HTTPServer) ConnectCAConfigurationGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in ConnectCAConfiguration
|
||||
var args structs.DCSpecificRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply structs.CAConfiguration
|
||||
err := s.agent.RPC("ConnectCA.ConfigurationGet", &args, &reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fixupConfig(&reply)
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// PUT /v1/connect/ca/configuration
|
||||
func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in ConnectCAConfiguration
|
||||
|
||||
var args structs.CARequest
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
if err := decodeBody(req, &args.Config, nil); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Request decode failed: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply interface{}
|
||||
err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// A hack to fix up the config types inside of the map[string]interface{}
|
||||
// so that they get formatted correctly during json.Marshal. Without this,
|
||||
// string values that get converted to []uint8 end up getting output back
|
||||
// to the user in base64-encoded form.
|
||||
func fixupConfig(conf *structs.CAConfiguration) {
|
||||
for k, v := range conf.Config {
|
||||
if raw, ok := v.([]uint8); ok {
|
||||
strVal := ca.Uint8ToString(raw)
|
||||
conf.Config[k] = strVal
|
||||
switch conf.Provider {
|
||||
case structs.ConsulCAProvider:
|
||||
if k == "PrivateKey" && strVal != "" {
|
||||
conf.Config["PrivateKey"] = "hidden"
|
||||
}
|
||||
case structs.VaultCAProvider:
|
||||
if k == "Token" && strVal != "" {
|
||||
conf.Config["Token"] = "hidden"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConnectCARoots_empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "connect { enabled = false }")
|
||||
defer a.Shutdown()
|
||||
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ConnectCARoots(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(structs.IndexedCARoots)
|
||||
assert.Equal(value.ActiveRootID, "")
|
||||
assert.Len(value.Roots, 0)
|
||||
}
|
||||
|
||||
func TestConnectCARoots_list(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Set some CAs. Note that NewTestAgent already bootstraps one CA so this just
|
||||
// adds a second and makes it active.
|
||||
ca2 := connect.TestCAConfigSet(t, a, nil)
|
||||
|
||||
// List
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ConnectCARoots(resp, req)
|
||||
assert.NoError(err)
|
||||
|
||||
value := obj.(structs.IndexedCARoots)
|
||||
assert.Equal(value.ActiveRootID, ca2.ID)
|
||||
assert.Len(value.Roots, 2)
|
||||
|
||||
// We should never have the secret information
|
||||
for _, r := range value.Roots {
|
||||
assert.Equal("", r.SigningCert)
|
||||
assert.Equal("", r.SigningKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectCAConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
expected := &structs.ConsulCAProviderConfig{
|
||||
RotationPeriod: 90 * 24 * time.Hour,
|
||||
}
|
||||
|
||||
// Get the initial config.
|
||||
{
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/ca/configuration", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ConnectCAConfiguration(resp, req)
|
||||
assert.NoError(err)
|
||||
|
||||
value := obj.(structs.CAConfiguration)
|
||||
parsed, err := ca.ParseConsulCAConfig(value.Config)
|
||||
assert.NoError(err)
|
||||
assert.Equal("consul", value.Provider)
|
||||
assert.Equal(expected, parsed)
|
||||
}
|
||||
|
||||
// Set the config.
|
||||
{
|
||||
body := bytes.NewBuffer([]byte(`
|
||||
{
|
||||
"Provider": "consul",
|
||||
"Config": {
|
||||
"RotationPeriod": 3600000000000
|
||||
}
|
||||
}`))
|
||||
req, _ := http.NewRequest("PUT", "/v1/connect/ca/configuration", body)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ConnectCAConfiguration(resp, req)
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
// The config should be updated now.
|
||||
{
|
||||
expected.RotationPeriod = time.Hour
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/ca/configuration", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ConnectCAConfiguration(resp, req)
|
||||
assert.NoError(err)
|
||||
|
||||
value := obj.(structs.CAConfiguration)
|
||||
parsed, err := ca.ParseConsulCAConfig(value.Config)
|
||||
assert.NoError(err)
|
||||
assert.Equal("consul", value.Provider)
|
||||
assert.Equal(expected, parsed)
|
||||
}
|
||||
}
|
|
@ -454,6 +454,33 @@ func (f *aclFilter) filterCoordinates(coords *structs.Coordinates) {
|
|||
*coords = c
|
||||
}
|
||||
|
||||
// filterIntentions is used to filter intentions based on ACL rules.
|
||||
// We prune entries the user doesn't have access to, and we redact any tokens
|
||||
// if the user doesn't have a management token.
|
||||
func (f *aclFilter) filterIntentions(ixns *structs.Intentions) {
|
||||
// Management tokens can see everything with no filtering.
|
||||
if f.acl.ACLList() {
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, we need to see what the token has access to.
|
||||
ret := make(structs.Intentions, 0, len(*ixns))
|
||||
for _, ixn := range *ixns {
|
||||
// If no prefix ACL applies to this then filter it, since
|
||||
// we know at this point the user doesn't have a management
|
||||
// token, otherwise see what the policy says.
|
||||
prefix, ok := ixn.GetACLPrefix()
|
||||
if !ok || !f.acl.IntentionRead(prefix) {
|
||||
f.logger.Printf("[DEBUG] consul: dropping intention %q from result due to ACLs", ixn.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
ret = append(ret, ixn)
|
||||
}
|
||||
|
||||
*ixns = ret
|
||||
}
|
||||
|
||||
// filterNodeDump is used to filter through all parts of a node dump and
|
||||
// remove elements the provided ACL token cannot access.
|
||||
func (f *aclFilter) filterNodeDump(dump *structs.NodeDump) {
|
||||
|
@ -598,6 +625,9 @@ func (s *Server) filterACL(token string, subj interface{}) error {
|
|||
case *structs.IndexedHealthChecks:
|
||||
filt.filterHealthChecks(&v.HealthChecks)
|
||||
|
||||
case *structs.IndexedIntentions:
|
||||
filt.filterIntentions(&v.Intentions)
|
||||
|
||||
case *structs.IndexedNodeDump:
|
||||
filt.filterNodeDump(&v.Dump)
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testACLPolicy = `
|
||||
|
@ -847,6 +848,58 @@ node "node1" {
|
|||
}
|
||||
}
|
||||
|
||||
func TestACL_filterIntentions(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert := assert.New(t)
|
||||
|
||||
fill := func() structs.Intentions {
|
||||
return structs.Intentions{
|
||||
&structs.Intention{
|
||||
ID: "f004177f-2c28-83b7-4229-eacc25fe55d1",
|
||||
DestinationName: "bar",
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: "f004177f-2c28-83b7-4229-eacc25fe55d2",
|
||||
DestinationName: "foo",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Try permissive filtering.
|
||||
{
|
||||
ixns := fill()
|
||||
filt := newACLFilter(acl.AllowAll(), nil, false)
|
||||
filt.filterIntentions(&ixns)
|
||||
assert.Len(ixns, 2)
|
||||
}
|
||||
|
||||
// Try restrictive filtering.
|
||||
{
|
||||
ixns := fill()
|
||||
filt := newACLFilter(acl.DenyAll(), nil, false)
|
||||
filt.filterIntentions(&ixns)
|
||||
assert.Len(ixns, 0)
|
||||
}
|
||||
|
||||
// Policy to see one
|
||||
policy, err := acl.Parse(`
|
||||
service "foo" {
|
||||
policy = "read"
|
||||
}
|
||||
`, nil)
|
||||
assert.Nil(err)
|
||||
perms, err := acl.New(acl.DenyAll(), policy, nil)
|
||||
assert.Nil(err)
|
||||
|
||||
// Filter
|
||||
{
|
||||
ixns := fill()
|
||||
filt := newACLFilter(perms, nil, false)
|
||||
filt.filterIntentions(&ixns)
|
||||
assert.Len(ixns, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACL_filterServices(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create some services
|
||||
|
|
|
@ -47,6 +47,13 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
|
|||
|
||||
// Handle a service registration.
|
||||
if args.Service != nil {
|
||||
// Validate the service. This is in addition to the below since
|
||||
// the above just hasn't been moved over yet. We should move it over
|
||||
// in time.
|
||||
if err := args.Service.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If no service id, but service name, use default
|
||||
if args.Service.ID == "" && args.Service.Service != "" {
|
||||
args.Service.ID = args.Service.Service
|
||||
|
@ -73,6 +80,13 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
|
|||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// Proxies must have write permission on their destination
|
||||
if args.Service.Kind == structs.ServiceKindConnectProxy {
|
||||
if rule != nil && !rule.ServiceWrite(args.Service.ProxyDestination, nil) {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Move the old format single check into the slice, and fixup IDs.
|
||||
|
@ -244,24 +258,52 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
|
|||
return fmt.Errorf("Must provide service name")
|
||||
}
|
||||
|
||||
// Determine the function we'll call
|
||||
var f func(memdb.WatchSet, *state.Store) (uint64, structs.ServiceNodes, error)
|
||||
switch {
|
||||
case args.Connect:
|
||||
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
|
||||
return s.ConnectServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
|
||||
default:
|
||||
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
|
||||
if args.ServiceAddress != "" {
|
||||
return s.ServiceAddressNodes(ws, args.ServiceAddress)
|
||||
}
|
||||
|
||||
if args.TagFilter {
|
||||
return s.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
|
||||
}
|
||||
|
||||
return s.ServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
}
|
||||
|
||||
// If we're doing a connect query, we need read access to the service
|
||||
// we're trying to find proxies for, so check that.
|
||||
if args.Connect {
|
||||
// Fetch the ACL token, if any.
|
||||
rule, err := c.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule != nil && !rule.ServiceRead(args.ServiceName) {
|
||||
// Just return nil, which will return an empty response (tested)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
err := c.srv.blockingQuery(
|
||||
&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
var index uint64
|
||||
var services structs.ServiceNodes
|
||||
var err error
|
||||
if args.TagFilter {
|
||||
index, services, err = state.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
|
||||
} else {
|
||||
index, services, err = state.ServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
if args.ServiceAddress != "" {
|
||||
index, services, err = state.ServiceAddressNodes(ws, args.ServiceAddress)
|
||||
}
|
||||
index, services, err := f(ws, state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.ServiceNodes = index, services
|
||||
if len(args.NodeMetaFilters) > 0 {
|
||||
var filtered structs.ServiceNodes
|
||||
|
@ -280,17 +322,24 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
|
|||
|
||||
// Provide some metrics
|
||||
if err == nil {
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", "service", "query"}, 1,
|
||||
// For metrics, we separate Connect-based lookups from non-Connect
|
||||
key := "service"
|
||||
if args.Connect {
|
||||
key = "connect"
|
||||
}
|
||||
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", key, "query"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
|
||||
if args.ServiceTag != "" {
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", "service", "query-tag"}, 1,
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", key, "query-tag"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}})
|
||||
}
|
||||
if len(reply.ServiceNodes) == 0 {
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", "service", "not-found"}, 1,
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", key, "not-found"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,8 @@ import (
|
|||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCatalog_Register(t *testing.T) {
|
||||
|
@ -332,6 +334,147 @@ func TestCatalog_Register_ForwardDC(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCatalog_Register_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
|
||||
// Register
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
|
||||
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
|
||||
}
|
||||
|
||||
// Test an invalid ConnectProxy. We don't need to exhaustively test because
|
||||
// this is all tested in structs on the Validate method.
|
||||
func TestCatalog_Register_ConnectProxy_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Service.ProxyDestination = ""
|
||||
|
||||
// Register
|
||||
var out struct{}
|
||||
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "ProxyDestination")
|
||||
}
|
||||
|
||||
// Test that write is required for the proxy destination to register a proxy.
|
||||
func TestCatalog_Register_ConnectProxy_ACLProxyDestination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLMasterToken = "root"
|
||||
c.ACLDefaultPolicy = "deny"
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Create the ACL.
|
||||
arg := structs.ACLRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.ACLSet,
|
||||
ACL: structs.ACL{
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: `
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
}
|
||||
`,
|
||||
},
|
||||
WriteRequest: structs.WriteRequest{Token: "root"},
|
||||
}
|
||||
var token string
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &token))
|
||||
|
||||
// Register should fail because we don't have permission on the destination
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "foo"
|
||||
args.Service.ProxyDestination = "bar"
|
||||
args.WriteRequest.Token = token
|
||||
var out struct{}
|
||||
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
|
||||
assert.True(acl.IsErrPermissionDenied(err))
|
||||
|
||||
// Register should fail with the right destination but wrong name
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "bar"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.WriteRequest.Token = token
|
||||
err = msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
|
||||
assert.True(acl.IsErrPermissionDenied(err))
|
||||
|
||||
// Register should work with the right destination
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "foo"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.WriteRequest.Token = token
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
}
|
||||
|
||||
func TestCatalog_Register_ConnectNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
args := structs.TestRegisterRequest(t)
|
||||
args.Service.Connect.Native = true
|
||||
|
||||
// Register
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(structs.ServiceKindTypical, v.ServiceKind)
|
||||
assert.True(v.ServiceConnect.Native)
|
||||
}
|
||||
|
||||
func TestCatalog_Deregister(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir1, s1 := testServer(t)
|
||||
|
@ -1599,6 +1742,246 @@ func TestCatalog_ListServiceNodes_DistanceSort(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the service
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
TagFilter: false,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
|
||||
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
|
||||
}
|
||||
|
||||
func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the proxy service
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// Register the service
|
||||
{
|
||||
dst := args.Service.ProxyDestination
|
||||
args := structs.TestRegisterRequest(t)
|
||||
args.Service.Service = dst
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
}
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.ProxyDestination,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
|
||||
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
|
||||
|
||||
// List by non-Connect
|
||||
req = structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.ProxyDestination,
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v = resp.ServiceNodes[0]
|
||||
assert.Equal(args.Service.ProxyDestination, v.ServiceName)
|
||||
assert.Equal("", v.ServiceProxyDestination)
|
||||
}
|
||||
|
||||
// Test that calling ServiceNodes with Connect: true will return
|
||||
// Connect native services.
|
||||
func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the native service
|
||||
args := structs.TestRegisterRequest(t)
|
||||
args.Service.Connect.Native = true
|
||||
var out struct{}
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
require.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
require.Equal(args.Service.Service, v.ServiceName)
|
||||
|
||||
// List by non-Connect
|
||||
req = structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
}
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
require.Len(resp.ServiceNodes, 1)
|
||||
v = resp.ServiceNodes[0]
|
||||
require.Equal(args.Service.Service, v.ServiceName)
|
||||
}
|
||||
|
||||
func TestCatalog_ListServiceNodes_ConnectProxy_ACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLMasterToken = "root"
|
||||
c.ACLDefaultPolicy = "deny"
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Create the ACL.
|
||||
arg := structs.ACLRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.ACLSet,
|
||||
ACL: structs.ACL{
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: `
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
}
|
||||
`,
|
||||
},
|
||||
WriteRequest: structs.WriteRequest{Token: "root"},
|
||||
}
|
||||
var token string
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &token))
|
||||
|
||||
{
|
||||
// Register a proxy
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "foo-proxy"
|
||||
args.Service.ProxyDestination = "bar"
|
||||
args.WriteRequest.Token = "root"
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// Register a proxy
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "foo-proxy"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.WriteRequest.Token = "root"
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// Register a proxy
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "another-proxy"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.WriteRequest.Token = "root"
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
}
|
||||
|
||||
// List w/ token. This should disallow because we don't have permission
|
||||
// to read "bar"
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: "bar",
|
||||
QueryOptions: structs.QueryOptions{Token: token},
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 0)
|
||||
|
||||
// List w/ token. This should work since we're requesting "foo", but should
|
||||
// also only contain the proxies with names that adhere to our ACL.
|
||||
req = structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: "foo",
|
||||
QueryOptions: structs.QueryOptions{Token: token},
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal("foo-proxy", v.ServiceName)
|
||||
}
|
||||
|
||||
func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the service
|
||||
args := structs.TestRegisterRequest(t)
|
||||
args.Service.Connect.Native = true
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
TagFilter: false,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(args.Service.Connect.Native, v.ServiceConnect.Native)
|
||||
}
|
||||
|
||||
func TestCatalog_NodeServices(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir1, s1 := testServer(t)
|
||||
|
@ -1649,6 +2032,67 @@ func TestCatalog_NodeServices(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the service
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.NodeSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: args.Node,
|
||||
}
|
||||
var resp structs.IndexedNodeServices
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
|
||||
|
||||
assert.Len(resp.NodeServices.Services, 1)
|
||||
v := resp.NodeServices.Services[args.Service.Service]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
|
||||
assert.Equal(args.Service.ProxyDestination, v.ProxyDestination)
|
||||
}
|
||||
|
||||
func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the service
|
||||
args := structs.TestRegisterRequest(t)
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.NodeSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: args.Node,
|
||||
}
|
||||
var resp structs.IndexedNodeServices
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
|
||||
|
||||
assert.Len(resp.NodeServices.Services, 1)
|
||||
v := resp.NodeServices.Services[args.Service.Service]
|
||||
assert.Equal(args.Service.Connect.Native, v.Connect.Native)
|
||||
}
|
||||
|
||||
// Used to check for a regression against a known bug
|
||||
func TestCatalog_Register_FailedCase1(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
|
@ -56,7 +57,7 @@ type Client struct {
|
|||
|
||||
// rpcLimiter is used to rate limit the total number of RPCs initiated
|
||||
// from an agent.
|
||||
rpcLimiter *rate.Limiter
|
||||
rpcLimiter atomic.Value
|
||||
|
||||
// eventCh is used to receive events from the
|
||||
// serf cluster in the datacenter
|
||||
|
@ -128,12 +129,13 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) {
|
|||
c := &Client{
|
||||
config: config,
|
||||
connPool: connPool,
|
||||
rpcLimiter: rate.NewLimiter(config.RPCRate, config.RPCMaxBurst),
|
||||
eventCh: make(chan serf.Event, serfEventBacklog),
|
||||
logger: logger,
|
||||
shutdownCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
c.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
|
||||
|
||||
if err := c.initEnterprise(); err != nil {
|
||||
c.Shutdown()
|
||||
return nil, err
|
||||
|
@ -263,7 +265,7 @@ TRY:
|
|||
|
||||
// Enforce the RPC limit.
|
||||
metrics.IncrCounter([]string{"client", "rpc"}, 1)
|
||||
if !c.rpcLimiter.Allow() {
|
||||
if !c.rpcLimiter.Load().(*rate.Limiter).Allow() {
|
||||
metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
|
||||
return structs.ErrRPCRateExceeded
|
||||
}
|
||||
|
@ -306,7 +308,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
|
|||
|
||||
// Enforce the RPC limit.
|
||||
metrics.IncrCounter([]string{"client", "rpc"}, 1)
|
||||
if !c.rpcLimiter.Allow() {
|
||||
if !c.rpcLimiter.Load().(*rate.Limiter).Allow() {
|
||||
metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
|
||||
return structs.ErrRPCRateExceeded
|
||||
}
|
||||
|
@ -381,3 +383,10 @@ func (c *Client) GetLANCoordinate() (lib.CoordinateSet, error) {
|
|||
cs := lib.CoordinateSet{c.config.Segment: lan}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
// ReloadConfig is used to have the Client do an online reload of
|
||||
// relevant configuration information
|
||||
func (c *Client) ReloadConfig(config *Config) error {
|
||||
c.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@ import (
|
|||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func testClientConfig(t *testing.T) (string, *Config) {
|
||||
|
@ -665,3 +667,25 @@ func TestClient_Encrypted(t *testing.T) {
|
|||
t.Fatalf("should be encrypted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Reload(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir1, c := testClientWithConfig(t, func(c *Config) {
|
||||
c.RPCRate = 500
|
||||
c.RPCMaxBurst = 5000
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer c.Shutdown()
|
||||
|
||||
limiter := c.rpcLimiter.Load().(*rate.Limiter)
|
||||
require.Equal(t, rate.Limit(500), limiter.Limit())
|
||||
require.Equal(t, 5000, limiter.Burst())
|
||||
|
||||
c.config.RPCRate = 1000
|
||||
c.config.RPCMaxBurst = 10000
|
||||
|
||||
require.NoError(t, c.ReloadConfig(c.config))
|
||||
limiter = c.rpcLimiter.Load().(*rate.Limiter)
|
||||
require.Equal(t, rate.Limit(1000), limiter.Limit())
|
||||
require.Equal(t, 10000, limiter.Burst())
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
|
@ -346,6 +347,13 @@ type Config struct {
|
|||
// autopilot tasks, such as promoting eligible non-voters and removing
|
||||
// dead servers.
|
||||
AutopilotInterval time.Duration
|
||||
|
||||
// ConnectEnabled is whether to enable Connect features such as the CA.
|
||||
ConnectEnabled bool
|
||||
|
||||
// CAConfig is used to apply the initial Connect CA configuration when
|
||||
// bootstrapping.
|
||||
CAConfig *structs.CAConfiguration
|
||||
}
|
||||
|
||||
// CheckProtocolVersion validates the protocol version.
|
||||
|
@ -425,6 +433,13 @@ func DefaultConfig() *Config {
|
|||
ServerStabilizationTime: 10 * time.Second,
|
||||
},
|
||||
|
||||
CAConfig: &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"RotationPeriod": "2160h",
|
||||
},
|
||||
},
|
||||
|
||||
ServerHealthInterval: 2 * time.Second,
|
||||
AutopilotInterval: 10 * time.Second,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,393 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
var ErrConnectNotEnabled = errors.New("Connect must be enabled in order to use this endpoint")
|
||||
|
||||
// ConnectCA manages the Connect CA.
|
||||
type ConnectCA struct {
|
||||
// srv is a pointer back to the server.
|
||||
srv *Server
|
||||
}
|
||||
|
||||
// ConfigurationGet returns the configuration for the CA.
|
||||
func (s *ConnectCA) ConfigurationGet(
|
||||
args *structs.DCSpecificRequest,
|
||||
reply *structs.CAConfiguration) error {
|
||||
// Exit early if Connect hasn't been enabled.
|
||||
if !s.srv.config.ConnectEnabled {
|
||||
return ErrConnectNotEnabled
|
||||
}
|
||||
|
||||
if done, err := s.srv.forward("ConnectCA.ConfigurationGet", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// This action requires operator read access.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rule != nil && !rule.OperatorRead() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
state := s.srv.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*reply = *config
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigurationSet updates the configuration for the CA.
|
||||
func (s *ConnectCA) ConfigurationSet(
|
||||
args *structs.CARequest,
|
||||
reply *interface{}) error {
|
||||
// Exit early if Connect hasn't been enabled.
|
||||
if !s.srv.config.ConnectEnabled {
|
||||
return ErrConnectNotEnabled
|
||||
}
|
||||
|
||||
if done, err := s.srv.forward("ConnectCA.ConfigurationSet", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// This action requires operator write access.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rule != nil && !rule.OperatorWrite() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Exit early if it's a no-op change
|
||||
state := s.srv.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
args.Config.ClusterID = config.ClusterID
|
||||
if args.Config.Provider == config.Provider && reflect.DeepEqual(args.Config.Config, config.Config) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create a new instance of the provider described by the config
|
||||
// and get the current active root CA. This acts as a good validation
|
||||
// of the config and makes sure the provider is functioning correctly
|
||||
// before we commit any changes to Raft.
|
||||
newProvider, err := s.srv.createCAProvider(args.Config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not initialize provider: %v", err)
|
||||
}
|
||||
|
||||
newRootPEM, err := newProvider.ActiveRoot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newActiveRoot, err := parseCARoot(newRootPEM, args.Config.Provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compare the new provider's root CA ID to the current one. If they
|
||||
// match, just update the existing provider with the new config.
|
||||
// If they don't match, begin the root rotation process.
|
||||
_, root, err := state.CARootActive(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if root != nil && root.ID == newActiveRoot.ID {
|
||||
args.Op = structs.CAOpSetConfig
|
||||
resp, err := s.srv.raftApply(structs.ConnectCARequestType, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
// If the config has been committed, update the local provider instance
|
||||
s.srv.setCAProvider(newProvider, newActiveRoot)
|
||||
|
||||
s.srv.logger.Printf("[INFO] connect: CA provider config updated")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// At this point, we know the config change has trigged a root rotation,
|
||||
// either by swapping the provider type or changing the provider's config
|
||||
// to use a different root certificate.
|
||||
|
||||
// If it's a config change that would trigger a rotation (different provider/root):
|
||||
// 1. Get the root from the new provider.
|
||||
// 2. Call CrossSignCA on the old provider to sign the new root with the old one to
|
||||
// get a cross-signed certificate.
|
||||
// 3. Take the active root for the new provider and append the intermediate from step 2
|
||||
// to its list of intermediates.
|
||||
newRoot, err := connect.ParseCert(newRootPEM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Have the old provider cross-sign the new intermediate
|
||||
oldProvider, _ := s.srv.getCAProvider()
|
||||
if oldProvider == nil {
|
||||
return fmt.Errorf("internal error: CA provider is nil")
|
||||
}
|
||||
xcCert, err := oldProvider.CrossSignCA(newRoot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the cross signed cert to the new root's intermediates.
|
||||
newActiveRoot.IntermediateCerts = []string{xcCert}
|
||||
intermediate, err := newProvider.GenerateIntermediate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if intermediate != newRootPEM {
|
||||
newActiveRoot.IntermediateCerts = append(newActiveRoot.IntermediateCerts, intermediate)
|
||||
}
|
||||
|
||||
// Update the roots and CA config in the state store at the same time
|
||||
idx, roots, err := state.CARoots(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var newRoots structs.CARoots
|
||||
for _, r := range roots {
|
||||
newRoot := *r
|
||||
if newRoot.Active {
|
||||
newRoot.Active = false
|
||||
}
|
||||
newRoots = append(newRoots, &newRoot)
|
||||
}
|
||||
newRoots = append(newRoots, newActiveRoot)
|
||||
|
||||
args.Op = structs.CAOpSetRootsAndConfig
|
||||
args.Index = idx
|
||||
args.Roots = newRoots
|
||||
resp, err := s.srv.raftApply(structs.ConnectCARequestType, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
// If the config has been committed, update the local provider instance
|
||||
// and call teardown on the old provider
|
||||
s.srv.setCAProvider(newProvider, newActiveRoot)
|
||||
|
||||
if err := oldProvider.Cleanup(); err != nil {
|
||||
s.srv.logger.Printf("[WARN] connect: failed to clean up old provider %q", config.Provider)
|
||||
}
|
||||
|
||||
s.srv.logger.Printf("[INFO] connect: CA rotated to new root under provider %q", args.Config.Provider)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Roots returns the currently trusted root certificates.
|
||||
func (s *ConnectCA) Roots(
|
||||
args *structs.DCSpecificRequest,
|
||||
reply *structs.IndexedCARoots) error {
|
||||
// Forward if necessary
|
||||
if done, err := s.srv.forward("ConnectCA.Roots", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load the ClusterID to generate TrustDomain. We do this outside the loop
|
||||
// since by definition this value should be immutable once set for lifetime of
|
||||
// the cluster so we don't need to look it up more than once. We also don't
|
||||
// have to worry about non-atomicity between the config fetch transaction and
|
||||
// the CARoots transaction below since this field must remain immutable. Do
|
||||
// not re-use this state/config for other logic that might care about changes
|
||||
// of config during the blocking query below.
|
||||
{
|
||||
state := s.srv.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check CA is actually bootstrapped...
|
||||
if config != nil {
|
||||
// Build TrustDomain based on the ClusterID stored.
|
||||
signingID := connect.SpiffeIDSigningForCluster(config)
|
||||
if signingID == nil {
|
||||
// If CA is bootstrapped at all then this should never happen but be
|
||||
// defensive.
|
||||
return errors.New("no cluster trust domain setup")
|
||||
}
|
||||
reply.TrustDomain = signingID.Host()
|
||||
}
|
||||
}
|
||||
|
||||
return s.srv.blockingQuery(
|
||||
&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, roots, err := state.CARoots(ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.Roots = index, roots
|
||||
if reply.Roots == nil {
|
||||
reply.Roots = make(structs.CARoots, 0)
|
||||
}
|
||||
|
||||
// The API response must NEVER contain the secret information
|
||||
// such as keys and so on. We use a whitelist below to copy the
|
||||
// specific fields we want to expose.
|
||||
for i, r := range reply.Roots {
|
||||
// IMPORTANT: r must NEVER be modified, since it is a pointer
|
||||
// directly to the structure in the memdb store.
|
||||
|
||||
reply.Roots[i] = &structs.CARoot{
|
||||
ID: r.ID,
|
||||
Name: r.Name,
|
||||
SerialNumber: r.SerialNumber,
|
||||
SigningKeyID: r.SigningKeyID,
|
||||
NotBefore: r.NotBefore,
|
||||
NotAfter: r.NotAfter,
|
||||
RootCert: r.RootCert,
|
||||
IntermediateCerts: r.IntermediateCerts,
|
||||
RaftIndex: r.RaftIndex,
|
||||
Active: r.Active,
|
||||
}
|
||||
|
||||
if r.Active {
|
||||
reply.ActiveRootID = r.ID
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Sign signs a certificate for a service.
|
||||
func (s *ConnectCA) Sign(
|
||||
args *structs.CASignRequest,
|
||||
reply *structs.IssuedCert) error {
|
||||
// Exit early if Connect hasn't been enabled.
|
||||
if !s.srv.config.ConnectEnabled {
|
||||
return ErrConnectNotEnabled
|
||||
}
|
||||
|
||||
if done, err := s.srv.forward("ConnectCA.Sign", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the CSR
|
||||
csr, err := connect.ParseCSR(args.CSR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the SPIFFE ID
|
||||
spiffeID, err := connect.ParseCertURI(csr.URIs[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serviceID, ok := spiffeID.(*connect.SpiffeIDService)
|
||||
if !ok {
|
||||
return fmt.Errorf("SPIFFE ID in CSR must be a service ID")
|
||||
}
|
||||
|
||||
provider, caRoot := s.srv.getCAProvider()
|
||||
if provider == nil {
|
||||
return fmt.Errorf("internal error: CA provider is nil")
|
||||
}
|
||||
|
||||
// Verify that the CSR entity is in the cluster's trust domain
|
||||
state := s.srv.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
signingID := connect.SpiffeIDSigningForCluster(config)
|
||||
if !signingID.CanSign(serviceID) {
|
||||
return fmt.Errorf("SPIFFE ID in CSR from a different trust domain: %s, "+
|
||||
"we are %s", serviceID.Host, signingID.Host())
|
||||
}
|
||||
|
||||
// Verify that the ACL token provided has permission to act as this service
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rule != nil && !rule.ServiceWrite(serviceID.Service, nil) {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Verify that the DC in the service URI matches us. We might relax this
|
||||
// requirement later but being restrictive for now is safer.
|
||||
if serviceID.Datacenter != s.srv.config.Datacenter {
|
||||
return fmt.Errorf("SPIFFE ID in CSR from a different datacenter: %s, "+
|
||||
"we are %s", serviceID.Datacenter, s.srv.config.Datacenter)
|
||||
}
|
||||
|
||||
// All seems to be in order, actually sign it.
|
||||
pem, err := provider.Sign(csr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Append any intermediates needed by this root.
|
||||
for _, p := range caRoot.IntermediateCerts {
|
||||
pem = strings.TrimSpace(pem) + "\n" + p
|
||||
}
|
||||
|
||||
// TODO(banks): when we implement IssuedCerts table we can use the insert to
|
||||
// that as the raft index to return in response. Right now we can rely on only
|
||||
// the built-in provider being supported and the implementation detail that we
|
||||
// have to write a SerialIndex update to the provider config table for every
|
||||
// cert issued so in all cases this index will be higher than any previous
|
||||
// sign response. This has to be reloaded after the provider.Sign call to
|
||||
// observe the index update.
|
||||
state = s.srv.fsm.State()
|
||||
modIdx, _, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := connect.ParseCert(pem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the response
|
||||
*reply = structs.IssuedCert{
|
||||
SerialNumber: connect.HexString(cert.SerialNumber.Bytes()),
|
||||
CertPEM: pem,
|
||||
Service: serviceID.Service,
|
||||
ServiceURI: cert.URIs[0].String(),
|
||||
ValidAfter: cert.NotBefore,
|
||||
ValidBefore: cert.NotAfter,
|
||||
RaftIndex: structs.RaftIndex{
|
||||
ModifyIndex: modIdx,
|
||||
CreateIndex: modIdx,
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,434 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testParseCert(t *testing.T, pemValue string) *x509.Certificate {
|
||||
cert, err := connect.ParseCert(pemValue)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
// Test listing root CAs.
|
||||
func TestConnectCARoots(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Insert some CAs
|
||||
state := s1.fsm.State()
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
ca2.Active = false
|
||||
idx, _, err := state.CARoots(nil)
|
||||
require.NoError(err)
|
||||
ok, err := state.CARootSetCAS(idx, idx, []*structs.CARoot{ca1, ca2})
|
||||
assert.True(ok)
|
||||
require.NoError(err)
|
||||
_, caCfg, err := state.CAConfig()
|
||||
require.NoError(err)
|
||||
|
||||
// Request
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.IndexedCARoots
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
|
||||
|
||||
// Verify
|
||||
assert.Equal(ca1.ID, reply.ActiveRootID)
|
||||
assert.Len(reply.Roots, 2)
|
||||
for _, r := range reply.Roots {
|
||||
// These must never be set, for security
|
||||
assert.Equal("", r.SigningCert)
|
||||
assert.Equal("", r.SigningKey)
|
||||
}
|
||||
assert.Equal(fmt.Sprintf("%s.consul", caCfg.ClusterID), reply.TrustDomain)
|
||||
}
|
||||
|
||||
func TestConnectCAConfig_GetSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Get the starting config
|
||||
{
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.CAConfiguration
|
||||
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
|
||||
|
||||
actual, err := ca.ParseConsulCAConfig(reply.Config)
|
||||
assert.NoError(err)
|
||||
expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config)
|
||||
assert.NoError(err)
|
||||
assert.Equal(reply.Provider, s1.config.CAConfig.Provider)
|
||||
assert.Equal(actual, expected)
|
||||
}
|
||||
|
||||
// Update a config value
|
||||
newConfig := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "",
|
||||
"RootCert": "",
|
||||
"RotationPeriod": 180 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
{
|
||||
args := &structs.CARequest{
|
||||
Datacenter: "dc1",
|
||||
Config: newConfig,
|
||||
}
|
||||
var reply interface{}
|
||||
|
||||
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
|
||||
}
|
||||
|
||||
// Verify the new config was set
|
||||
{
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.CAConfiguration
|
||||
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
|
||||
|
||||
actual, err := ca.ParseConsulCAConfig(reply.Config)
|
||||
assert.NoError(err)
|
||||
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
|
||||
assert.NoError(err)
|
||||
assert.Equal(reply.Provider, newConfig.Provider)
|
||||
assert.Equal(actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectCAConfig_TriggerRotation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Store the current root
|
||||
rootReq := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var rootList structs.IndexedCARoots
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
|
||||
assert.Len(rootList.Roots, 1)
|
||||
oldRoot := rootList.Roots[0]
|
||||
|
||||
// Update the provider config to use a new private key, which should
|
||||
// cause a rotation.
|
||||
_, newKey, err := connect.GeneratePrivateKey()
|
||||
assert.NoError(err)
|
||||
newConfig := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": newKey,
|
||||
"RootCert": "",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
{
|
||||
args := &structs.CARequest{
|
||||
Datacenter: "dc1",
|
||||
Config: newConfig,
|
||||
}
|
||||
var reply interface{}
|
||||
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
|
||||
}
|
||||
|
||||
// Make sure the new root has been added along with an intermediate
|
||||
// cross-signed by the old root.
|
||||
var newRootPEM string
|
||||
{
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.IndexedCARoots
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
|
||||
assert.Len(reply.Roots, 2)
|
||||
|
||||
for _, r := range reply.Roots {
|
||||
if r.ID == oldRoot.ID {
|
||||
// The old root should no longer be marked as the active root,
|
||||
// and none of its other fields should have changed.
|
||||
assert.False(r.Active)
|
||||
assert.Equal(r.Name, oldRoot.Name)
|
||||
assert.Equal(r.RootCert, oldRoot.RootCert)
|
||||
assert.Equal(r.SigningCert, oldRoot.SigningCert)
|
||||
assert.Equal(r.IntermediateCerts, oldRoot.IntermediateCerts)
|
||||
} else {
|
||||
newRootPEM = r.RootCert
|
||||
// The new root should have a valid cross-signed cert from the old
|
||||
// root as an intermediate.
|
||||
assert.True(r.Active)
|
||||
assert.Len(r.IntermediateCerts, 1)
|
||||
|
||||
xc := testParseCert(t, r.IntermediateCerts[0])
|
||||
oldRootCert := testParseCert(t, oldRoot.RootCert)
|
||||
newRootCert := testParseCert(t, r.RootCert)
|
||||
|
||||
// Should have the authority key ID and signature algo of the
|
||||
// (old) signing CA.
|
||||
assert.Equal(xc.AuthorityKeyId, oldRootCert.AuthorityKeyId)
|
||||
assert.NotEqual(xc.SubjectKeyId, oldRootCert.SubjectKeyId)
|
||||
assert.Equal(xc.SignatureAlgorithm, oldRootCert.SignatureAlgorithm)
|
||||
|
||||
// The common name and SAN should not have changed.
|
||||
assert.Equal(xc.Subject.CommonName, newRootCert.Subject.CommonName)
|
||||
assert.Equal(xc.URIs, newRootCert.URIs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the new config was set.
|
||||
{
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.CAConfiguration
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
|
||||
|
||||
actual, err := ca.ParseConsulCAConfig(reply.Config)
|
||||
require.NoError(err)
|
||||
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
|
||||
require.NoError(err)
|
||||
assert.Equal(reply.Provider, newConfig.Provider)
|
||||
assert.Equal(actual, expected)
|
||||
}
|
||||
|
||||
// Verify that new leaf certs get the cross-signed intermediate bundled
|
||||
{
|
||||
// Generate a CSR and request signing
|
||||
spiffeId := connect.TestSpiffeIDService(t, "web")
|
||||
csr, _ := connect.TestCSR(t, spiffeId)
|
||||
args := &structs.CASignRequest{
|
||||
Datacenter: "dc1",
|
||||
CSR: csr,
|
||||
}
|
||||
var reply structs.IssuedCert
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
|
||||
|
||||
// Verify that the cert is signed by the new CA
|
||||
{
|
||||
roots := x509.NewCertPool()
|
||||
require.True(roots.AppendCertsFromPEM([]byte(newRootPEM)))
|
||||
leaf, err := connect.ParseCert(reply.CertPEM)
|
||||
require.NoError(err)
|
||||
_, err = leaf.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
})
|
||||
require.NoError(err)
|
||||
}
|
||||
|
||||
// And that it validates via the intermediate
|
||||
{
|
||||
roots := x509.NewCertPool()
|
||||
assert.True(roots.AppendCertsFromPEM([]byte(oldRoot.RootCert)))
|
||||
leaf, err := connect.ParseCert(reply.CertPEM)
|
||||
require.NoError(err)
|
||||
|
||||
// Make sure the intermediate was returned as well as leaf
|
||||
_, rest := pem.Decode([]byte(reply.CertPEM))
|
||||
require.NotEmpty(rest)
|
||||
|
||||
intermediates := x509.NewCertPool()
|
||||
require.True(intermediates.AppendCertsFromPEM(rest))
|
||||
|
||||
_, err = leaf.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
Intermediates: intermediates,
|
||||
})
|
||||
require.NoError(err)
|
||||
}
|
||||
|
||||
// Verify other fields
|
||||
assert.Equal("web", reply.Service)
|
||||
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CA signing
|
||||
func TestConnectCASign(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Generate a CSR and request signing
|
||||
spiffeId := connect.TestSpiffeIDService(t, "web")
|
||||
csr, _ := connect.TestCSR(t, spiffeId)
|
||||
args := &structs.CASignRequest{
|
||||
Datacenter: "dc1",
|
||||
CSR: csr,
|
||||
}
|
||||
var reply structs.IssuedCert
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
|
||||
|
||||
// Get the current CA
|
||||
state := s1.fsm.State()
|
||||
_, ca, err := state.CARootActive(nil)
|
||||
require.NoError(err)
|
||||
|
||||
// Verify that the cert is signed by the CA
|
||||
roots := x509.NewCertPool()
|
||||
assert.True(roots.AppendCertsFromPEM([]byte(ca.RootCert)))
|
||||
leaf, err := connect.ParseCert(reply.CertPEM)
|
||||
require.NoError(err)
|
||||
_, err = leaf.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
})
|
||||
require.NoError(err)
|
||||
|
||||
// Verify other fields
|
||||
assert.Equal("web", reply.Service)
|
||||
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
|
||||
}
|
||||
|
||||
func TestConnectCASignValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLMasterToken = "root"
|
||||
c.ACLDefaultPolicy = "deny"
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Create an ACL token with service:write for web*
|
||||
var webToken string
|
||||
{
|
||||
arg := structs.ACLRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.ACLSet,
|
||||
ACL: structs.ACL{
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: `
|
||||
service "web" {
|
||||
policy = "write"
|
||||
}`,
|
||||
},
|
||||
WriteRequest: structs.WriteRequest{Token: "root"},
|
||||
}
|
||||
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &webToken))
|
||||
}
|
||||
|
||||
testWebID := connect.TestSpiffeIDService(t, "web")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id connect.CertURI
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "different cluster",
|
||||
id: &connect.SpiffeIDService{
|
||||
Host: "55555555-4444-3333-2222-111111111111.consul",
|
||||
Namespace: testWebID.Namespace,
|
||||
Datacenter: testWebID.Datacenter,
|
||||
Service: testWebID.Service,
|
||||
},
|
||||
wantErr: "different trust domain",
|
||||
},
|
||||
{
|
||||
name: "same cluster should validate",
|
||||
id: testWebID,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "same cluster, CSR for a different DC should NOT validate",
|
||||
id: &connect.SpiffeIDService{
|
||||
Host: testWebID.Host,
|
||||
Namespace: testWebID.Namespace,
|
||||
Datacenter: "dc2",
|
||||
Service: testWebID.Service,
|
||||
},
|
||||
wantErr: "different datacenter",
|
||||
},
|
||||
{
|
||||
name: "same cluster and DC, different service should not have perms",
|
||||
id: &connect.SpiffeIDService{
|
||||
Host: testWebID.Host,
|
||||
Namespace: testWebID.Namespace,
|
||||
Datacenter: testWebID.Datacenter,
|
||||
Service: "db",
|
||||
},
|
||||
wantErr: "Permission denied",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
csr, _ := connect.TestCSR(t, tt.id)
|
||||
args := &structs.CASignRequest{
|
||||
Datacenter: "dc1",
|
||||
CSR: csr,
|
||||
WriteRequest: structs.WriteRequest{Token: webToken},
|
||||
}
|
||||
var reply structs.IssuedCert
|
||||
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply)
|
||||
if tt.wantErr == "" {
|
||||
require.NoError(t, err)
|
||||
// No other validation that is handled in different tests
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// consulCADelegate providers callbacks for the Consul CA provider
|
||||
// to use the state store for its operations.
|
||||
type consulCADelegate struct {
|
||||
srv *Server
|
||||
}
|
||||
|
||||
func (c *consulCADelegate) State() *state.Store {
|
||||
return c.srv.fsm.State()
|
||||
}
|
||||
|
||||
func (c *consulCADelegate) ApplyCARequest(req *structs.CARequest) error {
|
||||
resp, err := c.srv.raftApply(structs.ConnectCARequestType, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -20,6 +20,8 @@ func init() {
|
|||
registerCommand(structs.PreparedQueryRequestType, (*FSM).applyPreparedQueryOperation)
|
||||
registerCommand(structs.TxnRequestType, (*FSM).applyTxn)
|
||||
registerCommand(structs.AutopilotRequestType, (*FSM).applyAutopilotUpdate)
|
||||
registerCommand(structs.IntentionRequestType, (*FSM).applyIntentionOperation)
|
||||
registerCommand(structs.ConnectCARequestType, (*FSM).applyConnectCAOperation)
|
||||
}
|
||||
|
||||
func (c *FSM) applyRegister(buf []byte, index uint64) interface{} {
|
||||
|
@ -246,3 +248,85 @@ func (c *FSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} {
|
|||
}
|
||||
return c.state.AutopilotSetConfig(index, &req.Config)
|
||||
}
|
||||
|
||||
// applyIntentionOperation applies the given intention operation to the state store.
|
||||
func (c *FSM) applyIntentionOperation(buf []byte, index uint64) interface{} {
|
||||
var req structs.IntentionRequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
|
||||
defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "intention"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "intention"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
|
||||
switch req.Op {
|
||||
case structs.IntentionOpCreate, structs.IntentionOpUpdate:
|
||||
return c.state.IntentionSet(index, req.Intention)
|
||||
case structs.IntentionOpDelete:
|
||||
return c.state.IntentionDelete(index, req.Intention.ID)
|
||||
default:
|
||||
c.logger.Printf("[WARN] consul.fsm: Invalid Intention operation '%s'", req.Op)
|
||||
return fmt.Errorf("Invalid Intention operation '%s'", req.Op)
|
||||
}
|
||||
}
|
||||
|
||||
// applyConnectCAOperation applies the given CA operation to the state store.
|
||||
func (c *FSM) applyConnectCAOperation(buf []byte, index uint64) interface{} {
|
||||
var req structs.CARequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
|
||||
defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "ca"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "ca"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
|
||||
switch req.Op {
|
||||
case structs.CAOpSetConfig:
|
||||
if req.Config.ModifyIndex != 0 {
|
||||
act, err := c.state.CACheckAndSetConfig(index, req.Config.ModifyIndex, req.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return act
|
||||
}
|
||||
|
||||
return c.state.CASetConfig(index, req.Config)
|
||||
case structs.CAOpSetRoots:
|
||||
act, err := c.state.CARootSetCAS(index, req.Index, req.Roots)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return act
|
||||
case structs.CAOpSetProviderState:
|
||||
act, err := c.state.CASetProviderState(index, req.ProviderState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return act
|
||||
case structs.CAOpDeleteProviderState:
|
||||
if err := c.state.CADeleteProviderState(req.ProviderState.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return true
|
||||
case structs.CAOpSetRootsAndConfig:
|
||||
act, err := c.state.CARootSetCAS(index, req.Index, req.Roots)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.state.CASetConfig(index+1, req.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return act
|
||||
default:
|
||||
c.logger.Printf("[WARN] consul.fsm: Invalid CA operation '%s'", req.Op)
|
||||
return fmt.Errorf("Invalid CA operation '%s'", req.Op)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,13 +8,16 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func generateUUID() (ret string) {
|
||||
|
@ -1148,3 +1151,209 @@ func TestFSM_Autopilot(t *testing.T) {
|
|||
t.Fatalf("bad: %v", config.CleanupDeadServers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_Intention_CRUD(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
assert.Nil(err)
|
||||
|
||||
// Create a new intention.
|
||||
ixn := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: structs.TestIntention(t),
|
||||
}
|
||||
ixn.Intention.ID = generateUUID()
|
||||
ixn.Intention.UpdatePrecedence()
|
||||
|
||||
{
|
||||
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
|
||||
assert.Nil(err)
|
||||
assert.Nil(fsm.Apply(makeLog(buf)))
|
||||
}
|
||||
|
||||
// Verify it's in the state store.
|
||||
{
|
||||
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
|
||||
assert.Nil(err)
|
||||
|
||||
actual.CreateIndex, actual.ModifyIndex = 0, 0
|
||||
actual.CreatedAt = ixn.Intention.CreatedAt
|
||||
actual.UpdatedAt = ixn.Intention.UpdatedAt
|
||||
assert.Equal(ixn.Intention, actual)
|
||||
}
|
||||
|
||||
// Make an update
|
||||
ixn.Op = structs.IntentionOpUpdate
|
||||
ixn.Intention.SourceName = "api"
|
||||
{
|
||||
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
|
||||
assert.Nil(err)
|
||||
assert.Nil(fsm.Apply(makeLog(buf)))
|
||||
}
|
||||
|
||||
// Verify the update.
|
||||
{
|
||||
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
|
||||
assert.Nil(err)
|
||||
|
||||
actual.CreateIndex, actual.ModifyIndex = 0, 0
|
||||
actual.CreatedAt = ixn.Intention.CreatedAt
|
||||
actual.UpdatedAt = ixn.Intention.UpdatedAt
|
||||
assert.Equal(ixn.Intention, actual)
|
||||
}
|
||||
|
||||
// Delete
|
||||
ixn.Op = structs.IntentionOpDelete
|
||||
{
|
||||
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
|
||||
assert.Nil(err)
|
||||
assert.Nil(fsm.Apply(makeLog(buf)))
|
||||
}
|
||||
|
||||
// Make sure it's gone.
|
||||
{
|
||||
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
|
||||
assert.Nil(err)
|
||||
assert.Nil(actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_CAConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
assert.Nil(err)
|
||||
|
||||
// Set the autopilot config using a request.
|
||||
req := structs.CARequest{
|
||||
Op: structs.CAOpSetConfig,
|
||||
Config: &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "asdf",
|
||||
"RootCert": "qwer",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
},
|
||||
}
|
||||
buf, err := structs.Encode(structs.ConnectCARequestType, req)
|
||||
assert.Nil(err)
|
||||
resp := fsm.Apply(makeLog(buf))
|
||||
if _, ok := resp.(error); ok {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
|
||||
// Verify key is set directly in the state store.
|
||||
_, config, err := fsm.state.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
var conf *structs.ConsulCAProviderConfig
|
||||
if err := mapstructure.WeakDecode(config.Config, &conf); err != nil {
|
||||
t.Fatalf("error decoding config: %s, %v", err, config.Config)
|
||||
}
|
||||
if got, want := config.Provider, req.Config.Provider; got != want {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := conf.PrivateKey, "asdf"; got != want {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := conf.RootCert, "qwer"; got != want {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := conf.RotationPeriod, 90*24*time.Hour; got != want {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
// Now use CAS and provide an old index
|
||||
req.Config.Provider = "static"
|
||||
req.Config.ModifyIndex = config.ModifyIndex - 1
|
||||
buf, err = structs.Encode(structs.ConnectCARequestType, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp = fsm.Apply(makeLog(buf))
|
||||
if _, ok := resp.(error); ok {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
|
||||
_, config, err = fsm.state.CAConfig()
|
||||
assert.Nil(err)
|
||||
if config.Provider != "static" {
|
||||
t.Fatalf("bad: %v", config.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_CARoots(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
assert.Nil(err)
|
||||
|
||||
// Roots
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
ca2.Active = false
|
||||
|
||||
// Create a new request.
|
||||
req := structs.CARequest{
|
||||
Op: structs.CAOpSetRoots,
|
||||
Roots: []*structs.CARoot{ca1, ca2},
|
||||
}
|
||||
|
||||
{
|
||||
buf, err := structs.Encode(structs.ConnectCARequestType, req)
|
||||
assert.Nil(err)
|
||||
assert.True(fsm.Apply(makeLog(buf)).(bool))
|
||||
}
|
||||
|
||||
// Verify it's in the state store.
|
||||
{
|
||||
_, roots, err := fsm.state.CARoots(nil)
|
||||
assert.Nil(err)
|
||||
assert.Len(roots, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_CABuiltinProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
assert.Nil(err)
|
||||
|
||||
// Provider state.
|
||||
expected := &structs.CAConsulProviderState{
|
||||
ID: "foo",
|
||||
PrivateKey: "a",
|
||||
RootCert: "b",
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 1,
|
||||
ModifyIndex: 1,
|
||||
},
|
||||
}
|
||||
|
||||
// Create a new request.
|
||||
req := structs.CARequest{
|
||||
Op: structs.CAOpSetProviderState,
|
||||
ProviderState: expected,
|
||||
}
|
||||
|
||||
{
|
||||
buf, err := structs.Encode(structs.ConnectCARequestType, req)
|
||||
assert.Nil(err)
|
||||
assert.True(fsm.Apply(makeLog(buf)).(bool))
|
||||
}
|
||||
|
||||
// Verify it's in the state store.
|
||||
{
|
||||
_, state, err := fsm.state.CAProviderState("foo")
|
||||
assert.Nil(err)
|
||||
assert.Equal(expected, state)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@ func init() {
|
|||
registerRestorer(structs.CoordinateBatchUpdateType, restoreCoordinates)
|
||||
registerRestorer(structs.PreparedQueryRequestType, restorePreparedQuery)
|
||||
registerRestorer(structs.AutopilotRequestType, restoreAutopilot)
|
||||
registerRestorer(structs.IntentionRequestType, restoreIntention)
|
||||
registerRestorer(structs.ConnectCARequestType, restoreConnectCA)
|
||||
}
|
||||
|
||||
func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error {
|
||||
|
@ -44,6 +46,12 @@ func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) err
|
|||
if err := s.persistAutopilot(sink, encoder); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.persistIntentions(sink, encoder); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.persistConnectCA(sink, encoder); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -258,6 +266,42 @@ func (s *snapshot) persistAutopilot(sink raft.SnapshotSink,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *snapshot) persistConnectCA(sink raft.SnapshotSink,
|
||||
encoder *codec.Encoder) error {
|
||||
roots, err := s.state.CARoots()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range roots {
|
||||
if _, err := sink.Write([]byte{byte(structs.ConnectCARequestType)}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := encoder.Encode(r); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snapshot) persistIntentions(sink raft.SnapshotSink,
|
||||
encoder *codec.Encoder) error {
|
||||
ixns, err := s.state.Intentions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, ixn := range ixns {
|
||||
if _, err := sink.Write([]byte{byte(structs.IntentionRequestType)}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := encoder.Encode(ixn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreRegistration(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
|
||||
var req structs.RegisterRequest
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
|
@ -364,3 +408,25 @@ func restoreAutopilot(header *snapshotHeader, restore *state.Restore, decoder *c
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreIntention(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
|
||||
var req structs.Intention
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := restore.Intention(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreConnectCA(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
|
||||
var req structs.CARoot
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := restore.CARoot(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,16 +7,20 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
|
@ -98,6 +102,27 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Intentions
|
||||
ixn := structs.TestIntention(t)
|
||||
ixn.ID = generateUUID()
|
||||
ixn.RaftIndex = structs.RaftIndex{
|
||||
CreateIndex: 14,
|
||||
ModifyIndex: 14,
|
||||
}
|
||||
assert.Nil(fsm.state.IntentionSet(14, ixn))
|
||||
|
||||
// CA Roots
|
||||
roots := []*structs.CARoot{
|
||||
connect.TestCA(t, nil),
|
||||
connect.TestCA(t, nil),
|
||||
}
|
||||
for _, r := range roots[1:] {
|
||||
r.Active = false
|
||||
}
|
||||
ok, err := fsm.state.CARootSetCAS(15, 0, roots)
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Snapshot
|
||||
snap, err := fsm.Snapshot()
|
||||
if err != nil {
|
||||
|
@ -260,6 +285,17 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
|||
t.Fatalf("bad: %#v, %#v", restoredConf, autopilotConf)
|
||||
}
|
||||
|
||||
// Verify intentions are restored.
|
||||
_, ixns, err := fsm2.state.Intentions(nil)
|
||||
assert.Nil(err)
|
||||
assert.Len(ixns, 1)
|
||||
assert.Equal(ixn, ixns[0])
|
||||
|
||||
// Verify CA roots are restored.
|
||||
_, roots, err = fsm2.state.CARoots(nil)
|
||||
assert.Nil(err)
|
||||
assert.Len(roots, 2)
|
||||
|
||||
// Snapshot
|
||||
snap, err = fsm2.Snapshot()
|
||||
if err != nil {
|
||||
|
|
|
@ -111,18 +111,37 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
|
|||
return fmt.Errorf("Must provide service name")
|
||||
}
|
||||
|
||||
// Determine the function we'll call
|
||||
var f func(memdb.WatchSet, *state.Store, *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error)
|
||||
switch {
|
||||
case args.Connect:
|
||||
f = h.serviceNodesConnect
|
||||
case args.TagFilter:
|
||||
f = h.serviceNodesTagFilter
|
||||
default:
|
||||
f = h.serviceNodesDefault
|
||||
}
|
||||
|
||||
// If we're doing a connect query, we need read access to the service
|
||||
// we're trying to find proxies for, so check that.
|
||||
if args.Connect {
|
||||
// Fetch the ACL token, if any.
|
||||
rule, err := h.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule != nil && !rule.ServiceRead(args.ServiceName) {
|
||||
// Just return nil, which will return an empty response (tested)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
err := h.srv.blockingQuery(
|
||||
&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
var index uint64
|
||||
var nodes structs.CheckServiceNodes
|
||||
var err error
|
||||
if args.TagFilter {
|
||||
index, nodes, err = state.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
|
||||
} else {
|
||||
index, nodes, err = state.CheckServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
index, nodes, err := f(ws, state, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -139,16 +158,37 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
|
|||
|
||||
// Provide some metrics
|
||||
if err == nil {
|
||||
metrics.IncrCounterWithLabels([]string{"health", "service", "query"}, 1,
|
||||
// For metrics, we separate Connect-based lookups from non-Connect
|
||||
key := "service"
|
||||
if args.Connect {
|
||||
key = "connect"
|
||||
}
|
||||
|
||||
metrics.IncrCounterWithLabels([]string{"health", key, "query"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
|
||||
if args.ServiceTag != "" {
|
||||
metrics.IncrCounterWithLabels([]string{"health", "service", "query-tag"}, 1,
|
||||
metrics.IncrCounterWithLabels([]string{"health", key, "query-tag"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}})
|
||||
}
|
||||
if len(reply.Nodes) == 0 {
|
||||
metrics.IncrCounterWithLabels([]string{"health", "service", "not-found"}, 1,
|
||||
metrics.IncrCounterWithLabels([]string{"health", key, "not-found"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// The serviceNodes* functions below are the various lookup methods that
|
||||
// can be used by the ServiceNodes endpoint.
|
||||
|
||||
func (h *Health) serviceNodesConnect(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.CheckConnectServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
|
||||
func (h *Health) serviceNodesTagFilter(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
|
||||
}
|
||||
|
||||
func (h *Health) serviceNodesDefault(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.CheckServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHealth_ChecksInState(t *testing.T) {
|
||||
|
@ -821,6 +822,106 @@ func TestHealth_ServiceNodes_DistanceSort(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHealth_ServiceNodes_ConnectProxy_ACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLMasterToken = "root"
|
||||
c.ACLDefaultPolicy = "deny"
|
||||
c.ACLEnforceVersion8 = false
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Create the ACL.
|
||||
arg := structs.ACLRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.ACLSet,
|
||||
ACL: structs.ACL{
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: `
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
}
|
||||
`,
|
||||
},
|
||||
WriteRequest: structs.WriteRequest{Token: "root"},
|
||||
}
|
||||
var token string
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", arg, &token))
|
||||
|
||||
{
|
||||
var out struct{}
|
||||
|
||||
// Register a service
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.WriteRequest.Token = "root"
|
||||
args.Service.ID = "foo-proxy-0"
|
||||
args.Service.Service = "foo-proxy"
|
||||
args.Service.ProxyDestination = "bar"
|
||||
args.Check = &structs.HealthCheck{
|
||||
Name: "proxy",
|
||||
Status: api.HealthPassing,
|
||||
ServiceID: args.Service.ID,
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// Register a service
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.WriteRequest.Token = "root"
|
||||
args.Service.Service = "foo-proxy"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.Check = &structs.HealthCheck{
|
||||
Name: "proxy",
|
||||
Status: api.HealthPassing,
|
||||
ServiceID: args.Service.Service,
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// Register a service
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.WriteRequest.Token = "root"
|
||||
args.Service.Service = "another-proxy"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.Check = &structs.HealthCheck{
|
||||
Name: "proxy",
|
||||
Status: api.HealthPassing,
|
||||
ServiceID: args.Service.Service,
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
}
|
||||
|
||||
// List w/ token. This should disallow because we don't have permission
|
||||
// to read "bar"
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: "bar",
|
||||
QueryOptions: structs.QueryOptions{Token: token},
|
||||
}
|
||||
var resp structs.IndexedCheckServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.Nodes, 0)
|
||||
|
||||
// List w/ token. This should work since we're requesting "foo", but should
|
||||
// also only contain the proxies with names that adhere to our ACL.
|
||||
req = structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: "foo",
|
||||
QueryOptions: structs.QueryOptions{Token: token},
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.Nodes, 1)
|
||||
}
|
||||
|
||||
func TestHealth_NodeChecks_FilterACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir, token, srv, codec := testACLFilterServer(t)
|
||||
|
|
|
@ -0,0 +1,358 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrIntentionNotFound is returned if the intention lookup failed.
|
||||
ErrIntentionNotFound = errors.New("Intention not found")
|
||||
)
|
||||
|
||||
// Intention manages the Connect intentions.
|
||||
type Intention struct {
|
||||
// srv is a pointer back to the server.
|
||||
srv *Server
|
||||
}
|
||||
|
||||
// Apply creates or updates an intention in the data store.
|
||||
func (s *Intention) Apply(
|
||||
args *structs.IntentionRequest,
|
||||
reply *string) error {
|
||||
if done, err := s.srv.forward("Intention.Apply", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
defer metrics.MeasureSince([]string{"consul", "intention", "apply"}, time.Now())
|
||||
defer metrics.MeasureSince([]string{"intention", "apply"}, time.Now())
|
||||
|
||||
// Always set a non-nil intention to avoid nil-access below
|
||||
if args.Intention == nil {
|
||||
args.Intention = &structs.Intention{}
|
||||
}
|
||||
|
||||
// If no ID is provided, generate a new ID. This must be done prior to
|
||||
// appending to the Raft log, because the ID is not deterministic. Once
|
||||
// the entry is in the log, the state update MUST be deterministic or
|
||||
// the followers will not converge.
|
||||
if args.Op == structs.IntentionOpCreate {
|
||||
if args.Intention.ID != "" {
|
||||
return fmt.Errorf("ID must be empty when creating a new intention")
|
||||
}
|
||||
|
||||
state := s.srv.fsm.State()
|
||||
for {
|
||||
var err error
|
||||
args.Intention.ID, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
s.srv.logger.Printf("[ERR] consul.intention: UUID generation failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, ixn, err := state.IntentionGet(nil, args.Intention.ID)
|
||||
if err != nil {
|
||||
s.srv.logger.Printf("[ERR] consul.intention: intention lookup failed: %v", err)
|
||||
return err
|
||||
}
|
||||
if ixn == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Set the created at
|
||||
args.Intention.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
*reply = args.Intention.ID
|
||||
|
||||
// Get the ACL token for the request for the checks below.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform the ACL check
|
||||
if prefix, ok := args.Intention.GetACLPrefix(); ok {
|
||||
if rule != nil && !rule.IntentionWrite(prefix) {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention '%s' denied due to ACLs", args.Intention.ID)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// If this is not a create, then we have to verify the ID.
|
||||
if args.Op != structs.IntentionOpCreate {
|
||||
state := s.srv.fsm.State()
|
||||
_, ixn, err := state.IntentionGet(nil, args.Intention.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Intention lookup failed: %v", err)
|
||||
}
|
||||
if ixn == nil {
|
||||
return fmt.Errorf("Cannot modify non-existent intention: '%s'", args.Intention.ID)
|
||||
}
|
||||
|
||||
// Perform the ACL check that we have write to the old prefix too,
|
||||
// which must be true to perform any rename.
|
||||
if prefix, ok := ixn.GetACLPrefix(); ok {
|
||||
if rule != nil && !rule.IntentionWrite(prefix) {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention '%s' denied due to ACLs", args.Intention.ID)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We always update the updatedat field. This has no effect for deletion.
|
||||
args.Intention.UpdatedAt = time.Now().UTC()
|
||||
|
||||
// Default source type
|
||||
if args.Intention.SourceType == "" {
|
||||
args.Intention.SourceType = structs.IntentionSourceConsul
|
||||
}
|
||||
|
||||
// Until we support namespaces, we force all namespaces to be default
|
||||
if args.Intention.SourceNS == "" {
|
||||
args.Intention.SourceNS = structs.IntentionDefaultNamespace
|
||||
}
|
||||
if args.Intention.DestinationNS == "" {
|
||||
args.Intention.DestinationNS = structs.IntentionDefaultNamespace
|
||||
}
|
||||
|
||||
// Validate. We do not validate on delete since it is valid to only
|
||||
// send an ID in that case.
|
||||
if args.Op != structs.IntentionOpDelete {
|
||||
// Set the precedence
|
||||
args.Intention.UpdatePrecedence()
|
||||
|
||||
if err := args.Intention.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit
|
||||
resp, err := s.srv.raftApply(structs.IntentionRequestType, args)
|
||||
if err != nil {
|
||||
s.srv.logger.Printf("[ERR] consul.intention: Apply failed %v", err)
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a single intention by ID.
|
||||
func (s *Intention) Get(
|
||||
args *structs.IntentionQueryRequest,
|
||||
reply *structs.IndexedIntentions) error {
|
||||
// Forward if necessary
|
||||
if done, err := s.srv.forward("Intention.Get", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.srv.blockingQuery(
|
||||
&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, ixn, err := state.IntentionGet(ws, args.IntentionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ixn == nil {
|
||||
return ErrIntentionNotFound
|
||||
}
|
||||
|
||||
reply.Index = index
|
||||
reply.Intentions = structs.Intentions{ixn}
|
||||
|
||||
// Filter
|
||||
if err := s.srv.filterACL(args.Token, reply); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If ACLs prevented any responses, error
|
||||
if len(reply.Intentions) == 0 {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: Request to get intention '%s' denied due to ACLs", args.IntentionID)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// List returns all the intentions.
|
||||
func (s *Intention) List(
|
||||
args *structs.DCSpecificRequest,
|
||||
reply *structs.IndexedIntentions) error {
|
||||
// Forward if necessary
|
||||
if done, err := s.srv.forward("Intention.List", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.srv.blockingQuery(
|
||||
&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, ixns, err := state.Intentions(ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.Intentions = index, ixns
|
||||
if reply.Intentions == nil {
|
||||
reply.Intentions = make(structs.Intentions, 0)
|
||||
}
|
||||
|
||||
return s.srv.filterACL(args.Token, reply)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Match returns the set of intentions that match the given source/destination.
|
||||
func (s *Intention) Match(
|
||||
args *structs.IntentionQueryRequest,
|
||||
reply *structs.IndexedIntentionMatches) error {
|
||||
// Forward if necessary
|
||||
if done, err := s.srv.forward("Intention.Match", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the ACL token for the request for the checks below.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule != nil {
|
||||
// We go through each entry and test the destination to check if it
|
||||
// matches.
|
||||
for _, entry := range args.Match.Entries {
|
||||
if prefix := entry.Name; prefix != "" && !rule.IntentionRead(prefix) {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention prefix '%s' denied due to ACLs", prefix)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.srv.blockingQuery(
|
||||
&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, matches, err := state.IntentionMatch(ws, args.Match)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index = index
|
||||
reply.Matches = matches
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Check tests a source/destination and returns whether it would be allowed
|
||||
// or denied based on the current ACL configuration.
|
||||
//
|
||||
// Note: Whenever the logic for this method is changed, you should take
|
||||
// a look at the agent authorize endpoint (agent/agent_endpoint.go) since
|
||||
// the logic there is similar.
|
||||
func (s *Intention) Check(
|
||||
args *structs.IntentionQueryRequest,
|
||||
reply *structs.IntentionQueryCheckResponse) error {
|
||||
// Forward maybe
|
||||
if done, err := s.srv.forward("Intention.Check", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the test args, and defensively guard against nil
|
||||
query := args.Check
|
||||
if query == nil {
|
||||
return errors.New("Check must be specified on args")
|
||||
}
|
||||
|
||||
// Build the URI
|
||||
var uri connect.CertURI
|
||||
switch query.SourceType {
|
||||
case structs.IntentionSourceConsul:
|
||||
uri = &connect.SpiffeIDService{
|
||||
Namespace: query.SourceNS,
|
||||
Service: query.SourceName,
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported SourceType: %q", query.SourceType)
|
||||
}
|
||||
|
||||
// Get the ACL token for the request for the checks below.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform the ACL check. For Check we only require ServiceRead and
|
||||
// NOT IntentionRead because the Check API only returns pass/fail and
|
||||
// returns no other information about the intentions used.
|
||||
if prefix, ok := query.GetACLPrefix(); ok {
|
||||
if rule != nil && !rule.ServiceRead(prefix) {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: test on intention '%s' denied due to ACLs", prefix)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// Get the matches for this destination
|
||||
state := s.srv.fsm.State()
|
||||
_, matches, err := state.IntentionMatch(nil, &structs.IntentionQueryMatch{
|
||||
Type: structs.IntentionMatchDestination,
|
||||
Entries: []structs.IntentionMatchEntry{
|
||||
structs.IntentionMatchEntry{
|
||||
Namespace: query.DestinationNS,
|
||||
Name: query.DestinationName,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(matches) != 1 {
|
||||
// This should never happen since the documented behavior of the
|
||||
// Match call is that it'll always return exactly the number of results
|
||||
// as entries passed in. But we guard against misbehavior.
|
||||
return errors.New("internal error loading matches")
|
||||
}
|
||||
|
||||
// Check the authorization for each match
|
||||
for _, ixn := range matches[0] {
|
||||
if auth, ok := uri.Authorize(ixn); ok {
|
||||
reply.Allowed = auth
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// No match, we need to determine the default behavior. We do this by
|
||||
// specifying the anonymous token token, which will get that behavior.
|
||||
// The default behavior if ACLs are disabled is to allow connections
|
||||
// to mimic the behavior of Consul itself: everything is allowed if
|
||||
// ACLs are disabled.
|
||||
//
|
||||
// NOTE(mitchellh): This is the same behavior as the agent authorize
|
||||
// endpoint. If this behavior is incorrect, we should also change it there
|
||||
// which is much more important.
|
||||
rule, err = s.srv.resolveToken("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Allowed = true
|
||||
if rule != nil {
|
||||
reply.Allowed = rule.IntentionDefaultAllow()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -4,16 +4,20 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/types"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/hashicorp/raft"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
|
@ -210,6 +214,12 @@ func (s *Server) establishLeadership() error {
|
|||
|
||||
s.getOrCreateAutopilotConfig()
|
||||
s.autopilot.Start()
|
||||
|
||||
// todo(kyhavlov): start a goroutine here for handling periodic CA rotation
|
||||
if err := s.initializeCA(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.setConsistentReadReady()
|
||||
return nil
|
||||
}
|
||||
|
@ -226,6 +236,8 @@ func (s *Server) revokeLeadership() error {
|
|||
return err
|
||||
}
|
||||
|
||||
s.setCAProvider(nil, nil)
|
||||
|
||||
s.resetConsistentReadReady()
|
||||
s.autopilot.Stop()
|
||||
return nil
|
||||
|
@ -359,6 +371,185 @@ func (s *Server) getOrCreateAutopilotConfig() *autopilot.Config {
|
|||
return config
|
||||
}
|
||||
|
||||
// initializeCAConfig is used to initialize the CA config if necessary
|
||||
// when setting up the CA during establishLeadership
|
||||
func (s *Server) initializeCAConfig() (*structs.CAConfiguration, error) {
|
||||
state := s.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if config != nil {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
config = s.config.CAConfig
|
||||
if config.ClusterID == "" {
|
||||
id, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.ClusterID = id
|
||||
}
|
||||
|
||||
req := structs.CARequest{
|
||||
Op: structs.CAOpSetConfig,
|
||||
Config: config,
|
||||
}
|
||||
if _, err = s.raftApply(structs.ConnectCARequestType, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// initializeCA sets up the CA provider when gaining leadership, bootstrapping
|
||||
// the root in the state store if necessary.
|
||||
func (s *Server) initializeCA() error {
|
||||
// Bail if connect isn't enabled.
|
||||
if !s.config.ConnectEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
conf, err := s.initializeCAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the right provider based on the config
|
||||
provider, err := s.createCAProvider(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the active root cert from the CA
|
||||
rootPEM, err := provider.ActiveRoot()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting root cert: %v", err)
|
||||
}
|
||||
|
||||
rootCA, err := parseCARoot(rootPEM, conf.Provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(banks): in the case that we've just gained leadership in an already
|
||||
// configured cluster. We really need to fetch RootCA from state to provide it
|
||||
// in setCAProvider. This matters because if the current active root has
|
||||
// intermediates, parsing the rootCA from only the root cert PEM above will
|
||||
// not include them and so leafs we sign will not bundle the intermediates.
|
||||
|
||||
s.setCAProvider(provider, rootCA)
|
||||
|
||||
// Check if the CA root is already initialized and exit if it is.
|
||||
// Every change to the CA after this initial bootstrapping should
|
||||
// be done through the rotation process.
|
||||
state := s.fsm.State()
|
||||
_, activeRoot, err := state.CARootActive(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if activeRoot != nil {
|
||||
if activeRoot.ID != rootCA.ID {
|
||||
// TODO(banks): this seems like a pretty catastrophic state to get into.
|
||||
// Shouldn't we do something stronger than warn and continue signing with
|
||||
// a key that's not the active CA according to the state?
|
||||
s.logger.Printf("[WARN] connect: CA root %q is not the active root (%q)", rootCA.ID, activeRoot.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the highest index
|
||||
idx, _, err := state.CARoots(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the root cert in raft
|
||||
resp, err := s.raftApply(structs.ConnectCARequestType, &structs.CARequest{
|
||||
Op: structs.CAOpSetRoots,
|
||||
Index: idx,
|
||||
Roots: []*structs.CARoot{rootCA},
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Printf("[ERR] connect: Apply failed %v", err)
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
s.logger.Printf("[INFO] connect: initialized CA with provider %q", conf.Provider)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseCARoot returns a filled-in structs.CARoot from a raw PEM value.
|
||||
func parseCARoot(pemValue, provider string) (*structs.CARoot, error) {
|
||||
id, err := connect.CalculateCertFingerprint(pemValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing root fingerprint: %v", err)
|
||||
}
|
||||
rootCert, err := connect.ParseCert(pemValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing root cert: %v", err)
|
||||
}
|
||||
return &structs.CARoot{
|
||||
ID: id,
|
||||
Name: fmt.Sprintf("%s CA Root Cert", strings.Title(provider)),
|
||||
SerialNumber: rootCert.SerialNumber.Uint64(),
|
||||
SigningKeyID: connect.HexString(rootCert.AuthorityKeyId),
|
||||
NotBefore: rootCert.NotBefore,
|
||||
NotAfter: rootCert.NotAfter,
|
||||
RootCert: pemValue,
|
||||
Active: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createProvider returns a connect CA provider from the given config.
|
||||
func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, error) {
|
||||
switch conf.Provider {
|
||||
case structs.ConsulCAProvider:
|
||||
return ca.NewConsulProvider(conf.Config, &consulCADelegate{s})
|
||||
case structs.VaultCAProvider:
|
||||
return ca.NewVaultProvider(conf.Config, conf.ClusterID)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown CA provider %q", conf.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getCAProvider() (ca.Provider, *structs.CARoot) {
|
||||
retries := 0
|
||||
var result ca.Provider
|
||||
var resultRoot *structs.CARoot
|
||||
for result == nil {
|
||||
s.caProviderLock.RLock()
|
||||
result = s.caProvider
|
||||
resultRoot = s.caProviderRoot
|
||||
s.caProviderLock.RUnlock()
|
||||
|
||||
// In cases where an agent is started with managed proxies, we may ask
|
||||
// for the provider before establishLeadership completes. If we're the
|
||||
// leader, then wait and get the provider again
|
||||
if result == nil && s.IsLeader() && retries < 10 {
|
||||
retries++
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return result, resultRoot
|
||||
}
|
||||
|
||||
func (s *Server) setCAProvider(newProvider ca.Provider, root *structs.CARoot) {
|
||||
s.caProviderLock.Lock()
|
||||
defer s.caProviderLock.Unlock()
|
||||
s.caProvider = newProvider
|
||||
s.caProviderRoot = root
|
||||
}
|
||||
|
||||
// reconcileReaped is used to reconcile nodes that have failed and been reaped
|
||||
// from Serf but remain in the catalog. This is done by looking for unknown nodes with serfHealth checks registered.
|
||||
// We generate a "reap" event to cause the node to be cleaned up.
|
||||
|
|
|
@ -354,7 +354,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
|
|||
}
|
||||
|
||||
// Execute the query for the local DC.
|
||||
if err := p.execute(query, reply); err != nil {
|
||||
if err := p.execute(query, reply, args.Connect); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -450,7 +450,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
|
|||
// by the query setup.
|
||||
if len(reply.Nodes) == 0 {
|
||||
wrapper := &queryServerWrapper{p.srv}
|
||||
if err := queryFailover(wrapper, query, args.Limit, args.QueryOptions, reply); err != nil {
|
||||
if err := queryFailover(wrapper, query, args, reply); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -479,7 +479,7 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe
|
|||
}
|
||||
|
||||
// Run the query locally to see what we can find.
|
||||
if err := p.execute(&args.Query, reply); err != nil {
|
||||
if err := p.execute(&args.Query, reply, args.Connect); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -509,9 +509,18 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe
|
|||
// execute runs a prepared query in the local DC without any failover. We don't
|
||||
// apply any sorting options or ACL checks at this level - it should be done up above.
|
||||
func (p *PreparedQuery) execute(query *structs.PreparedQuery,
|
||||
reply *structs.PreparedQueryExecuteResponse) error {
|
||||
reply *structs.PreparedQueryExecuteResponse,
|
||||
forceConnect bool) error {
|
||||
state := p.srv.fsm.State()
|
||||
_, nodes, err := state.CheckServiceNodes(nil, query.Service.Service)
|
||||
|
||||
// If we're requesting Connect-capable services, then switch the
|
||||
// lookup to be the Connect function.
|
||||
f := state.CheckServiceNodes
|
||||
if query.Service.Connect || forceConnect {
|
||||
f = state.CheckConnectServiceNodes
|
||||
}
|
||||
|
||||
_, nodes, err := f(nil, query.Service.Service)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -651,7 +660,7 @@ func (q *queryServerWrapper) ForwardDC(method, dc string, args interface{}, repl
|
|||
// queryFailover runs an algorithm to determine which DCs to try and then calls
|
||||
// them to try to locate alternative services.
|
||||
func queryFailover(q queryServer, query *structs.PreparedQuery,
|
||||
limit int, options structs.QueryOptions,
|
||||
args *structs.PreparedQueryExecuteRequest,
|
||||
reply *structs.PreparedQueryExecuteResponse) error {
|
||||
|
||||
// Pull the list of other DCs. This is sorted by RTT in case the user
|
||||
|
@ -719,8 +728,9 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
|
|||
remote := &structs.PreparedQueryExecuteRemoteRequest{
|
||||
Datacenter: dc,
|
||||
Query: *query,
|
||||
Limit: limit,
|
||||
QueryOptions: options,
|
||||
Limit: args.Limit,
|
||||
QueryOptions: args.QueryOptions,
|
||||
Connect: args.Connect,
|
||||
}
|
||||
if err := q.ForwardDC("PreparedQuery.ExecuteRemote", dc, remote, reply); err != nil {
|
||||
q.GetLogger().Printf("[WARN] consul.prepared_query: Failed querying for service '%s' in datacenter '%s': %s", query.Service.Service, dc, err)
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPreparedQuery_Apply(t *testing.T) {
|
||||
|
@ -2617,6 +2618,159 @@ func TestPreparedQuery_Execute_ForwardLeader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
// Setup 3 services on 3 nodes: one is non-Connect, one is Connect native,
|
||||
// and one is a proxy to the non-Connect one.
|
||||
for i := 0; i < 3; i++ {
|
||||
req := structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: fmt.Sprintf("node%d", i+1),
|
||||
Address: fmt.Sprintf("127.0.0.%d", i+1),
|
||||
Service: &structs.NodeService{
|
||||
Service: "foo",
|
||||
Port: 8000,
|
||||
},
|
||||
}
|
||||
|
||||
switch i {
|
||||
case 0:
|
||||
// Default do nothing
|
||||
|
||||
case 1:
|
||||
// Connect native
|
||||
req.Service.Connect.Native = true
|
||||
|
||||
case 2:
|
||||
// Connect proxy
|
||||
req.Service.Kind = structs.ServiceKindConnectProxy
|
||||
req.Service.ProxyDestination = req.Service.Service
|
||||
req.Service.Service = "proxy"
|
||||
}
|
||||
|
||||
var reply struct{}
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply))
|
||||
}
|
||||
|
||||
// The query, start with connect disabled
|
||||
query := structs.PreparedQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.PreparedQueryCreate,
|
||||
Query: &structs.PreparedQuery{
|
||||
Name: "test",
|
||||
Service: structs.ServiceQuery{
|
||||
Service: "foo",
|
||||
},
|
||||
DNS: structs.QueryDNSOptions{
|
||||
TTL: "10s",
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
|
||||
|
||||
// In the future we'll run updates
|
||||
query.Op = structs.PreparedQueryUpdate
|
||||
|
||||
// Run the registered query.
|
||||
{
|
||||
req := structs.PreparedQueryExecuteRequest{
|
||||
Datacenter: "dc1",
|
||||
QueryIDOrName: query.Query.ID,
|
||||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Execute", &req, &reply))
|
||||
|
||||
// Result should have two because it omits the proxy whose name
|
||||
// doesn't match the query.
|
||||
require.Len(reply.Nodes, 2)
|
||||
require.Equal(query.Query.Service.Service, reply.Service)
|
||||
require.Equal(query.Query.DNS, reply.DNS)
|
||||
require.True(reply.QueryMeta.KnownLeader, "queried leader")
|
||||
}
|
||||
|
||||
// Run with the Connect setting specified on the request
|
||||
{
|
||||
req := structs.PreparedQueryExecuteRequest{
|
||||
Datacenter: "dc1",
|
||||
QueryIDOrName: query.Query.ID,
|
||||
Connect: true,
|
||||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Execute", &req, &reply))
|
||||
|
||||
// Result should have two because we should get the native AND
|
||||
// the proxy (since the destination matches our service name).
|
||||
require.Len(reply.Nodes, 2)
|
||||
require.Equal(query.Query.Service.Service, reply.Service)
|
||||
require.Equal(query.Query.DNS, reply.DNS)
|
||||
require.True(reply.QueryMeta.KnownLeader, "queried leader")
|
||||
|
||||
// Make sure the native is the first one
|
||||
if !reply.Nodes[0].Service.Connect.Native {
|
||||
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
|
||||
}
|
||||
|
||||
require.True(reply.Nodes[0].Service.Connect.Native, "native")
|
||||
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
|
||||
|
||||
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
|
||||
require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination)
|
||||
}
|
||||
|
||||
// Update the query
|
||||
query.Query.Service.Connect = true
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
|
||||
|
||||
// Run the registered query.
|
||||
{
|
||||
req := structs.PreparedQueryExecuteRequest{
|
||||
Datacenter: "dc1",
|
||||
QueryIDOrName: query.Query.ID,
|
||||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Execute", &req, &reply))
|
||||
|
||||
// Result should have two because we should get the native AND
|
||||
// the proxy (since the destination matches our service name).
|
||||
require.Len(reply.Nodes, 2)
|
||||
require.Equal(query.Query.Service.Service, reply.Service)
|
||||
require.Equal(query.Query.DNS, reply.DNS)
|
||||
require.True(reply.QueryMeta.KnownLeader, "queried leader")
|
||||
|
||||
// Make sure the native is the first one
|
||||
if !reply.Nodes[0].Service.Connect.Native {
|
||||
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
|
||||
}
|
||||
|
||||
require.True(reply.Nodes[0].Service.Connect.Native, "native")
|
||||
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
|
||||
|
||||
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
|
||||
require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination)
|
||||
}
|
||||
|
||||
// Unset the query
|
||||
query.Query.Service.Connect = false
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
|
||||
}
|
||||
|
||||
func TestPreparedQuery_tagFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
testNodes := func() structs.CheckServiceNodes {
|
||||
|
@ -2820,7 +2974,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 {
|
||||
|
@ -2836,7 +2990,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply)
|
||||
err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply)
|
||||
if err == nil || !strings.Contains(err.Error(), "XXX") {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
@ -2853,7 +3007,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 {
|
||||
|
@ -2876,7 +3030,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -2904,7 +3058,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -2925,7 +3079,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 0 ||
|
||||
|
@ -2954,7 +3108,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -2983,7 +3137,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -3012,7 +3166,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -3047,7 +3201,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -3079,7 +3233,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -3115,7 +3269,10 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 5, structs.QueryOptions{RequireConsistent: true}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{
|
||||
Limit: 5,
|
||||
QueryOptions: structs.QueryOptions{RequireConsistent: true},
|
||||
}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/consul/fsm"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
|
@ -96,6 +97,16 @@ type Server struct {
|
|||
// autopilotWaitGroup is used to block until Autopilot shuts down.
|
||||
autopilotWaitGroup sync.WaitGroup
|
||||
|
||||
// caProvider is the current CA provider in use for Connect. This is
|
||||
// only non-nil when we are the leader.
|
||||
caProvider ca.Provider
|
||||
// caProviderRoot is the CARoot that was stored along with the ca.Provider
|
||||
// active. It's only updated in lock-step with the caProvider. This prevents
|
||||
// races between state updates to active roots and the fetch of the provider
|
||||
// instance.
|
||||
caProviderRoot *structs.CARoot
|
||||
caProviderLock sync.RWMutex
|
||||
|
||||
// Consul configuration
|
||||
config *Config
|
||||
|
||||
|
@ -1066,6 +1077,12 @@ func (s *Server) GetLANCoordinate() (lib.CoordinateSet, error) {
|
|||
return cs, nil
|
||||
}
|
||||
|
||||
// ReloadConfig is used to have the Server do an online reload of
|
||||
// relevant configuration information
|
||||
func (s *Server) ReloadConfig(config *Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Atomically sets a readiness state flag when leadership is obtained, to indicate that server is past its barrier write
|
||||
func (s *Server) setConsistentReadReady() {
|
||||
atomic.StoreInt32(&s.readyForConsistentReads, 1)
|
||||
|
|
|
@ -4,7 +4,9 @@ func init() {
|
|||
registerEndpoint(func(s *Server) interface{} { return &ACL{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Catalog{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return NewCoordinate(s) })
|
||||
registerEndpoint(func(s *Server) interface{} { return &ConnectCA{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Health{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Intention{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Internal{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &KVS{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Operator{s} })
|
||||
|
|
|
@ -10,7 +10,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
"github.com/hashicorp/consul/lib/freeport"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
|
@ -91,6 +93,17 @@ func testServerConfig(t *testing.T) (string, *Config) {
|
|||
// looks like several depend on it.
|
||||
config.RPCHoldTimeout = 5 * time.Second
|
||||
|
||||
config.ConnectEnabled = true
|
||||
config.CAConfig = &structs.CAConfiguration{
|
||||
ClusterID: connect.TestClusterID,
|
||||
Provider: structs.ConsulCAProvider,
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "",
|
||||
"RootCert": "",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
return dir, config
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,10 @@ import (
|
|||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
const (
|
||||
servicesTableName = "services"
|
||||
)
|
||||
|
||||
// nodesTableSchema returns a new table schema used for storing node
|
||||
// information.
|
||||
func nodesTableSchema() *memdb.TableSchema {
|
||||
|
@ -87,6 +91,12 @@ func servicesTableSchema() *memdb.TableSchema {
|
|||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
"connect": &memdb.IndexSchema{
|
||||
Name: "connect",
|
||||
AllowMissing: true,
|
||||
Unique: false,
|
||||
Indexer: &IndexConnectService{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -779,15 +789,39 @@ func maxIndexForService(tx *memdb.Txn, serviceName string, checks bool) uint64 {
|
|||
return maxIndexTxn(tx, "nodes", "services")
|
||||
}
|
||||
|
||||
// ConnectServiceNodes returns the nodes associated with a Connect
|
||||
// compatible destination for the given service name. This will include
|
||||
// both proxies and native integrations.
|
||||
func (s *Store) ConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) {
|
||||
return s.serviceNodes(ws, serviceName, true)
|
||||
}
|
||||
|
||||
// ServiceNodes returns the nodes associated with a given service name.
|
||||
func (s *Store) ServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) {
|
||||
return s.serviceNodes(ws, serviceName, false)
|
||||
}
|
||||
|
||||
func (s *Store) serviceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.ServiceNodes, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexForService(tx, serviceName, false)
|
||||
|
||||
// Function for lookup
|
||||
var f func() (memdb.ResultIterator, error)
|
||||
if !connect {
|
||||
f = func() (memdb.ResultIterator, error) {
|
||||
return tx.Get("services", "service", serviceName)
|
||||
}
|
||||
} else {
|
||||
f = func() (memdb.ResultIterator, error) {
|
||||
return tx.Get("services", "connect", serviceName)
|
||||
}
|
||||
}
|
||||
|
||||
// List all the services.
|
||||
services, err := tx.Get("services", "service", serviceName)
|
||||
services, err := f()
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed service lookup: %s", err)
|
||||
}
|
||||
|
@ -1479,14 +1513,36 @@ func (s *Store) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID t
|
|||
|
||||
// CheckServiceNodes is used to query all nodes and checks for a given service.
|
||||
func (s *Store) CheckServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.checkServiceNodes(ws, serviceName, false)
|
||||
}
|
||||
|
||||
// CheckConnectServiceNodes is used to query all nodes and checks for Connect
|
||||
// compatible endpoints for a given service.
|
||||
func (s *Store) CheckConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.checkServiceNodes(ws, serviceName, true)
|
||||
}
|
||||
|
||||
func (s *Store) checkServiceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.CheckServiceNodes, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexForService(tx, serviceName, true)
|
||||
|
||||
// Function for lookup
|
||||
var f func() (memdb.ResultIterator, error)
|
||||
if !connect {
|
||||
f = func() (memdb.ResultIterator, error) {
|
||||
return tx.Get("services", "service", serviceName)
|
||||
}
|
||||
} else {
|
||||
f = func() (memdb.ResultIterator, error) {
|
||||
return tx.Get("services", "connect", serviceName)
|
||||
}
|
||||
}
|
||||
|
||||
// Query the state store for the service.
|
||||
iter, err := tx.Get("services", "service", serviceName)
|
||||
iter, err := f()
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed service lookup: %s", err)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/hashicorp/go-memdb"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func makeRandomNodeID(t *testing.T) types.NodeID {
|
||||
|
@ -981,6 +982,35 @@ func TestStateStore_EnsureService(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStateStore_EnsureService_connectProxy(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Create the service registration.
|
||||
ns1 := &structs.NodeService{
|
||||
Kind: structs.ServiceKindConnectProxy,
|
||||
ID: "connect-proxy",
|
||||
Service: "connect-proxy",
|
||||
Address: "1.1.1.1",
|
||||
Port: 1111,
|
||||
ProxyDestination: "foo",
|
||||
}
|
||||
|
||||
// Service successfully registers into the state store.
|
||||
testRegisterNode(t, s, 0, "node1")
|
||||
assert.Nil(s.EnsureService(10, "node1", ns1))
|
||||
|
||||
// Retrieve and verify
|
||||
_, out, err := s.NodeServices(nil, "node1")
|
||||
assert.Nil(err)
|
||||
assert.NotNil(out)
|
||||
assert.Len(out.Services, 1)
|
||||
|
||||
expect1 := *ns1
|
||||
expect1.CreateIndex, expect1.ModifyIndex = 10, 10
|
||||
assert.Equal(&expect1, out.Services["connect-proxy"])
|
||||
}
|
||||
|
||||
func TestStateStore_Services(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
|
@ -1542,6 +1572,51 @@ func TestStateStore_DeleteService(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStateStore_ConnectServiceNodes(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Listing with no results returns an empty list.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, nodes, err := s.ConnectServiceNodes(ws, "db")
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(0))
|
||||
assert.Len(nodes, 0)
|
||||
|
||||
// Create some nodes and services.
|
||||
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
|
||||
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
|
||||
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
|
||||
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
|
||||
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
|
||||
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
|
||||
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "native-db", Service: "db", Connect: structs.ServiceConnect{Native: true}}))
|
||||
assert.Nil(s.EnsureService(17, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}))
|
||||
assert.True(watchFired(ws))
|
||||
|
||||
// Read everything back.
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, nodes, err = s.ConnectServiceNodes(ws, "db")
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(idx))
|
||||
assert.Len(nodes, 3)
|
||||
|
||||
for _, n := range nodes {
|
||||
assert.True(
|
||||
n.ServiceKind == structs.ServiceKindConnectProxy ||
|
||||
n.ServiceConnect.Native,
|
||||
"either proxy or connect native")
|
||||
}
|
||||
|
||||
// Registering some unrelated node should not fire the watch.
|
||||
testRegisterNode(t, s, 17, "nope")
|
||||
assert.False(watchFired(ws))
|
||||
|
||||
// But removing a node with the "db" service should fire the watch.
|
||||
assert.Nil(s.DeleteNode(18, "bar"))
|
||||
assert.True(watchFired(ws))
|
||||
}
|
||||
|
||||
func TestStateStore_Service_Snapshot(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
|
@ -2457,6 +2532,48 @@ func TestStateStore_CheckServiceNodes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStateStore_CheckConnectServiceNodes(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Listing with no results returns an empty list.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, nodes, err := s.CheckConnectServiceNodes(ws, "db")
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(0))
|
||||
assert.Len(nodes, 0)
|
||||
|
||||
// Create some nodes and services.
|
||||
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
|
||||
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
|
||||
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
|
||||
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
|
||||
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
|
||||
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
|
||||
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}))
|
||||
assert.True(watchFired(ws))
|
||||
|
||||
// Register node checks
|
||||
testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing)
|
||||
testRegisterCheck(t, s, 18, "bar", "", "check2", api.HealthPassing)
|
||||
|
||||
// Register checks against the services.
|
||||
testRegisterCheck(t, s, 19, "foo", "db", "check3", api.HealthPassing)
|
||||
testRegisterCheck(t, s, 20, "bar", "proxy", "check4", api.HealthPassing)
|
||||
|
||||
// Read everything back.
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db")
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(idx))
|
||||
assert.Len(nodes, 2)
|
||||
|
||||
for _, n := range nodes {
|
||||
assert.Equal(structs.ServiceKindConnectProxy, n.Service.Kind)
|
||||
assert.Equal("db", n.Service.ProxyDestination)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCheckServiceNodes(b *testing.B) {
|
||||
s, err := NewStateStore(nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,435 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
const (
|
||||
caBuiltinProviderTableName = "connect-ca-builtin"
|
||||
caConfigTableName = "connect-ca-config"
|
||||
caRootTableName = "connect-ca-roots"
|
||||
)
|
||||
|
||||
// caBuiltinProviderTableSchema returns a new table schema used for storing
|
||||
// the built-in CA provider's state for connect. This is only used by
|
||||
// the internal Consul CA provider.
|
||||
func caBuiltinProviderTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: caBuiltinProviderTableName,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// caConfigTableSchema returns a new table schema used for storing
|
||||
// the CA config for Connect.
|
||||
func caConfigTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: caConfigTableName,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
// This table only stores one row, so this just ignores the ID field
|
||||
// and always overwrites the same config object.
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: true,
|
||||
Unique: true,
|
||||
Indexer: &memdb.ConditionalIndex{
|
||||
Conditional: func(obj interface{}) (bool, error) { return true, nil },
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// caRootTableSchema returns a new table schema used for storing
|
||||
// CA roots for Connect.
|
||||
func caRootTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: caRootTableName,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
registerSchema(caBuiltinProviderTableSchema)
|
||||
registerSchema(caConfigTableSchema)
|
||||
registerSchema(caRootTableSchema)
|
||||
}
|
||||
|
||||
// CAConfig is used to pull the CA config from the snapshot.
|
||||
func (s *Snapshot) CAConfig() (*structs.CAConfiguration, error) {
|
||||
c, err := s.tx.First(caConfigTableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config, ok := c.(*structs.CAConfiguration)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// CAConfig is used when restoring from a snapshot.
|
||||
func (s *Restore) CAConfig(config *structs.CAConfiguration) error {
|
||||
if err := s.tx.Insert(caConfigTableName, config); err != nil {
|
||||
return fmt.Errorf("failed restoring CA config: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CAConfig is used to get the current CA configuration.
|
||||
func (s *Store) CAConfig() (uint64, *structs.CAConfiguration, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the CA config
|
||||
c, err := tx.First(caConfigTableName, "id")
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed CA config lookup: %s", err)
|
||||
}
|
||||
|
||||
config, ok := c.(*structs.CAConfiguration)
|
||||
if !ok {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
return config.ModifyIndex, config, nil
|
||||
}
|
||||
|
||||
// CASetConfig is used to set the current CA configuration.
|
||||
func (s *Store) CASetConfig(idx uint64, config *structs.CAConfiguration) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.caSetConfigTxn(idx, tx, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CACheckAndSetConfig is used to try updating the CA configuration with a
|
||||
// given Raft index. If the CAS index specified is not equal to the last observed index
|
||||
// for the config, then the call is a noop,
|
||||
func (s *Store) CACheckAndSetConfig(idx, cidx uint64, config *structs.CAConfiguration) (bool, error) {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// Check for an existing config
|
||||
existing, err := tx.First(caConfigTableName, "id")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed CA config lookup: %s", err)
|
||||
}
|
||||
|
||||
// If the existing index does not match the provided CAS
|
||||
// index arg, then we shouldn't update anything and can safely
|
||||
// return early here.
|
||||
e, ok := existing.(*structs.CAConfiguration)
|
||||
if !ok || e.ModifyIndex != cidx {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := s.caSetConfigTxn(idx, tx, config); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *Store) caSetConfigTxn(idx uint64, tx *memdb.Txn, config *structs.CAConfiguration) error {
|
||||
// Check for an existing config
|
||||
prev, err := tx.First(caConfigTableName, "id")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed CA config lookup: %s", err)
|
||||
}
|
||||
|
||||
// Set the indexes, prevent the cluster ID from changing.
|
||||
if prev != nil {
|
||||
existing := prev.(*structs.CAConfiguration)
|
||||
config.CreateIndex = existing.CreateIndex
|
||||
config.ClusterID = existing.ClusterID
|
||||
} else {
|
||||
config.CreateIndex = idx
|
||||
}
|
||||
config.ModifyIndex = idx
|
||||
|
||||
if err := tx.Insert(caConfigTableName, config); err != nil {
|
||||
return fmt.Errorf("failed updating CA config: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CARoots is used to pull all the CA roots for the snapshot.
|
||||
func (s *Snapshot) CARoots() (structs.CARoots, error) {
|
||||
ixns, err := s.tx.Get(caRootTableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret structs.CARoots
|
||||
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
|
||||
ret = append(ret, wrapped.(*structs.CARoot))
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// CARoots is used when restoring from a snapshot.
|
||||
func (s *Restore) CARoot(r *structs.CARoot) error {
|
||||
// Insert
|
||||
if err := s.tx.Insert(caRootTableName, r); err != nil {
|
||||
return fmt.Errorf("failed restoring CA root: %s", err)
|
||||
}
|
||||
if err := indexUpdateMaxTxn(s.tx, r.ModifyIndex, caRootTableName); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CARoots returns the list of all CA roots.
|
||||
func (s *Store) CARoots(ws memdb.WatchSet) (uint64, structs.CARoots, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the index
|
||||
idx := maxIndexTxn(tx, caRootTableName)
|
||||
|
||||
// Get all
|
||||
iter, err := tx.Get(caRootTableName, "id")
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed CA root lookup: %s", err)
|
||||
}
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
var results structs.CARoots
|
||||
for v := iter.Next(); v != nil; v = iter.Next() {
|
||||
results = append(results, v.(*structs.CARoot))
|
||||
}
|
||||
return idx, results, nil
|
||||
}
|
||||
|
||||
// CARootActive returns the currently active CARoot.
|
||||
func (s *Store) CARootActive(ws memdb.WatchSet) (uint64, *structs.CARoot, error) {
|
||||
// Get all the roots since there should never be that many and just
|
||||
// do the filtering in this method.
|
||||
var result *structs.CARoot
|
||||
idx, roots, err := s.CARoots(ws)
|
||||
if err == nil {
|
||||
for _, r := range roots {
|
||||
if r.Active {
|
||||
result = r
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return idx, result, err
|
||||
}
|
||||
|
||||
// CARootSetCAS sets the current CA root state using a check-and-set operation.
|
||||
// On success, this will replace the previous set of CARoots completely with
|
||||
// the given set of roots.
|
||||
//
|
||||
// The first boolean result returns whether the transaction succeeded or not.
|
||||
func (s *Store) CARootSetCAS(idx, cidx uint64, rs []*structs.CARoot) (bool, error) {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// There must be exactly one active CA root.
|
||||
activeCount := 0
|
||||
for _, r := range rs {
|
||||
if r.Active {
|
||||
activeCount++
|
||||
}
|
||||
}
|
||||
if activeCount != 1 {
|
||||
return false, fmt.Errorf("there must be exactly one active CA")
|
||||
}
|
||||
|
||||
// Get the current max index
|
||||
if midx := maxIndexTxn(tx, caRootTableName); midx != cidx {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Go through and find any existing matching CAs so we can preserve and
|
||||
// update their Create/ModifyIndex values.
|
||||
for _, r := range rs {
|
||||
if r.ID == "" {
|
||||
return false, ErrMissingCARootID
|
||||
}
|
||||
|
||||
existing, err := tx.First(caRootTableName, "id", r.ID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed CA root lookup: %s", err)
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
r.CreateIndex = existing.(*structs.CARoot).CreateIndex
|
||||
} else {
|
||||
r.CreateIndex = idx
|
||||
}
|
||||
r.ModifyIndex = idx
|
||||
}
|
||||
|
||||
// Delete all
|
||||
_, err := tx.DeleteAll(caRootTableName, "id")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Insert all
|
||||
for _, r := range rs {
|
||||
if err := tx.Insert(caRootTableName, r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Update the index
|
||||
if err := tx.Insert("index", &IndexEntry{caRootTableName, idx}); err != nil {
|
||||
return false, fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CAProviderState is used to pull the built-in provider states from the snapshot.
|
||||
func (s *Snapshot) CAProviderState() ([]*structs.CAConsulProviderState, error) {
|
||||
ixns, err := s.tx.Get(caBuiltinProviderTableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret []*structs.CAConsulProviderState
|
||||
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
|
||||
ret = append(ret, wrapped.(*structs.CAConsulProviderState))
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// CAProviderState is used when restoring from a snapshot.
|
||||
func (s *Restore) CAProviderState(state *structs.CAConsulProviderState) error {
|
||||
if err := s.tx.Insert(caBuiltinProviderTableName, state); err != nil {
|
||||
return fmt.Errorf("failed restoring built-in CA state: %s", err)
|
||||
}
|
||||
if err := indexUpdateMaxTxn(s.tx, state.ModifyIndex, caBuiltinProviderTableName); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CAProviderState is used to get the Consul CA provider state for the given ID.
|
||||
func (s *Store) CAProviderState(id string) (uint64, *structs.CAConsulProviderState, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the index
|
||||
idx := maxIndexTxn(tx, caBuiltinProviderTableName)
|
||||
|
||||
// Get the provider config
|
||||
c, err := tx.First(caBuiltinProviderTableName, "id", id)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed built-in CA state lookup: %s", err)
|
||||
}
|
||||
|
||||
state, ok := c.(*structs.CAConsulProviderState)
|
||||
if !ok {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
return idx, state, nil
|
||||
}
|
||||
|
||||
// CASetProviderState is used to set the current built-in CA provider state.
|
||||
func (s *Store) CASetProviderState(idx uint64, state *structs.CAConsulProviderState) (bool, error) {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// Check for an existing config
|
||||
existing, err := tx.First(caBuiltinProviderTableName, "id", state.ID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed built-in CA state lookup: %s", err)
|
||||
}
|
||||
|
||||
// Set the indexes.
|
||||
if existing != nil {
|
||||
state.CreateIndex = existing.(*structs.CAConsulProviderState).CreateIndex
|
||||
} else {
|
||||
state.CreateIndex = idx
|
||||
}
|
||||
state.ModifyIndex = idx
|
||||
|
||||
if err := tx.Insert(caBuiltinProviderTableName, state); err != nil {
|
||||
return false, fmt.Errorf("failed updating built-in CA state: %s", err)
|
||||
}
|
||||
|
||||
// Update the index
|
||||
if err := tx.Insert("index", &IndexEntry{caBuiltinProviderTableName, idx}); err != nil {
|
||||
return false, fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CADeleteProviderState is used to remove the built-in Consul CA provider
|
||||
// state for the given ID.
|
||||
func (s *Store) CADeleteProviderState(id string) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the index
|
||||
idx := maxIndexTxn(tx, caBuiltinProviderTableName)
|
||||
|
||||
// Check for an existing config
|
||||
existing, err := tx.First(caBuiltinProviderTableName, "id", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed built-in CA state lookup: %s", err)
|
||||
}
|
||||
if existing == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providerState := existing.(*structs.CAConsulProviderState)
|
||||
|
||||
// Do the delete and update the index
|
||||
if err := tx.Delete(caBuiltinProviderTableName, providerState); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Insert("index", &IndexEntry{caBuiltinProviderTableName, idx}); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,449 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStore_CAConfig(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
expected := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "asdf",
|
||||
"RootCert": "qwer",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.CASetConfig(0, expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
idx, config, err := s.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if idx != 0 {
|
||||
t.Fatalf("bad: %d", idx)
|
||||
}
|
||||
if !reflect.DeepEqual(expected, config) {
|
||||
t.Fatalf("bad: %#v, %#v", expected, config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CAConfigCAS(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
expected := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
}
|
||||
|
||||
if err := s.CASetConfig(0, expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Do an extra operation to move the index up by 1 for the
|
||||
// check-and-set operation after this
|
||||
if err := s.CASetConfig(1, expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Do a CAS with an index lower than the entry
|
||||
ok, err := s.CACheckAndSetConfig(2, 0, &structs.CAConfiguration{
|
||||
Provider: "static",
|
||||
})
|
||||
if ok || err != nil {
|
||||
t.Fatalf("expected (false, nil), got: (%v, %#v)", ok, err)
|
||||
}
|
||||
|
||||
// Check that the index is untouched and the entry
|
||||
// has not been updated.
|
||||
idx, config, err := s.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if idx != 1 {
|
||||
t.Fatalf("bad: %d", idx)
|
||||
}
|
||||
if config.Provider != "consul" {
|
||||
t.Fatalf("bad: %#v", config)
|
||||
}
|
||||
|
||||
// Do another CAS, this time with the correct index
|
||||
ok, err = s.CACheckAndSetConfig(2, 1, &structs.CAConfiguration{
|
||||
Provider: "static",
|
||||
})
|
||||
if !ok || err != nil {
|
||||
t.Fatalf("expected (true, nil), got: (%v, %#v)", ok, err)
|
||||
}
|
||||
|
||||
// Make sure the config was updated
|
||||
idx, config, err = s.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if idx != 2 {
|
||||
t.Fatalf("bad: %d", idx)
|
||||
}
|
||||
if config.Provider != "static" {
|
||||
t.Fatalf("bad: %#v", config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CAConfig_Snapshot_Restore(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
before := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "asdf",
|
||||
"RootCert": "qwer",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
if err := s.CASetConfig(99, before); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
snap := s.Snapshot()
|
||||
defer snap.Close()
|
||||
|
||||
after := &structs.CAConfiguration{
|
||||
Provider: "static",
|
||||
Config: map[string]interface{}{},
|
||||
}
|
||||
if err := s.CASetConfig(100, after); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
snapped, err := snap.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
verify.Values(t, "", before, snapped)
|
||||
|
||||
s2 := testStateStore(t)
|
||||
restore := s2.Restore()
|
||||
if err := restore.CAConfig(snapped); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
idx, res, err := s2.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if idx != 99 {
|
||||
t.Fatalf("bad index: %d", idx)
|
||||
}
|
||||
verify.Values(t, "", before, res)
|
||||
}
|
||||
|
||||
func TestStore_CARootSetList(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call list to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(s.maxIndex(caRootTableName), uint64(1))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back out and verify it.
|
||||
expected := *ca1
|
||||
expected.RaftIndex = structs.RaftIndex{
|
||||
CreateIndex: 1,
|
||||
ModifyIndex: 1,
|
||||
}
|
||||
|
||||
ws = memdb.NewWatchSet()
|
||||
_, roots, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
assert.Len(roots, 1)
|
||||
actual := roots[0]
|
||||
assert.Equal(&expected, actual)
|
||||
}
|
||||
|
||||
func TestStore_CARootSet_emptyID(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call list to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca1.ID = ""
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), ErrMissingCARootID.Error())
|
||||
assert.False(ok)
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(s.maxIndex(caRootTableName), uint64(0))
|
||||
assert.False(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back out and verify it.
|
||||
ws = memdb.NewWatchSet()
|
||||
_, roots, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
assert.Len(roots, 0)
|
||||
}
|
||||
|
||||
func TestStore_CARootSet_noActive(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call list to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca1.Active = false
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
ca2.Active = false
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "exactly one active")
|
||||
assert.False(ok)
|
||||
}
|
||||
|
||||
func TestStore_CARootSet_multipleActive(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call list to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "exactly one active")
|
||||
assert.False(ok)
|
||||
}
|
||||
|
||||
func TestStore_CARootActive_valid(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca1.Active = false
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
ca3 := connect.TestCA(t, nil)
|
||||
ca3.Active = false
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2, ca3})
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Query
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.CARootActive(ws)
|
||||
assert.Equal(idx, uint64(1))
|
||||
assert.Nil(err)
|
||||
assert.NotNil(res)
|
||||
assert.Equal(ca2.ID, res.ID)
|
||||
}
|
||||
|
||||
// Test that querying the active CA returns the correct value.
|
||||
func TestStore_CARootActive_none(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Querying with no results returns nil.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.CARootActive(ws)
|
||||
assert.Equal(idx, uint64(0))
|
||||
assert.Nil(res)
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Create some intentions.
|
||||
roots := structs.CARoots{
|
||||
connect.TestCA(t, nil),
|
||||
connect.TestCA(t, nil),
|
||||
connect.TestCA(t, nil),
|
||||
}
|
||||
for _, r := range roots[1:] {
|
||||
r.Active = false
|
||||
}
|
||||
|
||||
// Force the sort order of the UUIDs before we create them so the
|
||||
// order is deterministic.
|
||||
id := testUUID()
|
||||
roots[0].ID = "a" + id[1:]
|
||||
roots[1].ID = "b" + id[1:]
|
||||
roots[2].ID = "c" + id[1:]
|
||||
|
||||
// Now create
|
||||
ok, err := s.CARootSetCAS(1, 0, roots)
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Snapshot the queries.
|
||||
snap := s.Snapshot()
|
||||
defer snap.Close()
|
||||
|
||||
// Alter the real state store.
|
||||
ok, err = s.CARootSetCAS(2, 1, roots[:1])
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Verify the snapshot.
|
||||
assert.Equal(snap.LastIndex(), uint64(1))
|
||||
dump, err := snap.CARoots()
|
||||
assert.Nil(err)
|
||||
assert.Equal(roots, dump)
|
||||
|
||||
// Restore the values into a new state store.
|
||||
func() {
|
||||
s := testStateStore(t)
|
||||
restore := s.Restore()
|
||||
for _, r := range dump {
|
||||
assert.Nil(restore.CARoot(r))
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
// Read the restored values back out and verify that they match.
|
||||
idx, actual, err := s.CARoots(nil)
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(2))
|
||||
assert.Equal(roots, actual)
|
||||
}()
|
||||
}
|
||||
|
||||
func TestStore_CABuiltinProvider(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
{
|
||||
expected := &structs.CAConsulProviderState{
|
||||
ID: "foo",
|
||||
PrivateKey: "a",
|
||||
RootCert: "b",
|
||||
}
|
||||
|
||||
ok, err := s.CASetProviderState(0, expected)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
|
||||
idx, state, err := s.CAProviderState(expected.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(0))
|
||||
assert.Equal(expected, state)
|
||||
}
|
||||
|
||||
{
|
||||
expected := &structs.CAConsulProviderState{
|
||||
ID: "bar",
|
||||
PrivateKey: "c",
|
||||
RootCert: "d",
|
||||
}
|
||||
|
||||
ok, err := s.CASetProviderState(1, expected)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
|
||||
idx, state, err := s.CAProviderState(expected.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(1))
|
||||
assert.Equal(expected, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Create multiple state entries.
|
||||
before := []*structs.CAConsulProviderState{
|
||||
{
|
||||
ID: "bar",
|
||||
PrivateKey: "y",
|
||||
RootCert: "z",
|
||||
},
|
||||
{
|
||||
ID: "foo",
|
||||
PrivateKey: "a",
|
||||
RootCert: "b",
|
||||
},
|
||||
}
|
||||
|
||||
for i, state := range before {
|
||||
ok, err := s.CASetProviderState(uint64(98+i), state)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
}
|
||||
|
||||
// Take a snapshot.
|
||||
snap := s.Snapshot()
|
||||
defer snap.Close()
|
||||
|
||||
// Modify the state store.
|
||||
after := &structs.CAConsulProviderState{
|
||||
ID: "foo",
|
||||
PrivateKey: "c",
|
||||
RootCert: "d",
|
||||
}
|
||||
ok, err := s.CASetProviderState(100, after)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
|
||||
snapped, err := snap.CAProviderState()
|
||||
assert.NoError(err)
|
||||
assert.Equal(before, snapped)
|
||||
|
||||
// Restore onto a new state store.
|
||||
s2 := testStateStore(t)
|
||||
restore := s2.Restore()
|
||||
for _, entry := range snapped {
|
||||
assert.NoError(restore.CAProviderState(entry))
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
// Verify the restored values match those from before the snapshot.
|
||||
for _, state := range before {
|
||||
idx, res, err := s2.CAProviderState(state.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(99))
|
||||
assert.Equal(state, res)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// IndexConnectService indexes a *struct.ServiceNode for querying by
|
||||
// services that support Connect to some target service. This will
|
||||
// properly index the proxy destination for proxies and the service name
|
||||
// for native services.
|
||||
type IndexConnectService struct{}
|
||||
|
||||
func (idx *IndexConnectService) FromObject(obj interface{}) (bool, []byte, error) {
|
||||
sn, ok := obj.(*structs.ServiceNode)
|
||||
if !ok {
|
||||
return false, nil, fmt.Errorf("Object must be ServiceNode, got %T", obj)
|
||||
}
|
||||
|
||||
var result []byte
|
||||
switch {
|
||||
case sn.ServiceKind == structs.ServiceKindConnectProxy:
|
||||
// For proxies, this service supports Connect for the destination
|
||||
result = []byte(strings.ToLower(sn.ServiceProxyDestination))
|
||||
|
||||
case sn.ServiceConnect.Native:
|
||||
// For native, this service supports Connect directly
|
||||
result = []byte(strings.ToLower(sn.ServiceName))
|
||||
|
||||
default:
|
||||
// Doesn't support Connect at all
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Return the result with the null terminator appended so we can
|
||||
// differentiate prefix vs. non-prefix matches.
|
||||
return true, append(result, '\x00'), nil
|
||||
}
|
||||
|
||||
func (idx *IndexConnectService) FromArgs(args ...interface{}) ([]byte, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, fmt.Errorf("must provide only a single argument")
|
||||
}
|
||||
|
||||
arg, ok := args[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
|
||||
}
|
||||
|
||||
// Add the null character as a terminator
|
||||
return append([]byte(strings.ToLower(arg)), '\x00'), nil
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIndexConnectService_FromObject(t *testing.T) {
|
||||
cases := []struct {
|
||||
Name string
|
||||
Input interface{}
|
||||
ExpectMatch bool
|
||||
ExpectVal []byte
|
||||
ExpectErr string
|
||||
}{
|
||||
{
|
||||
"not a ServiceNode",
|
||||
42,
|
||||
false,
|
||||
nil,
|
||||
"ServiceNode",
|
||||
},
|
||||
|
||||
{
|
||||
"typical service, not native",
|
||||
&structs.ServiceNode{
|
||||
ServiceName: "db",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"typical service, is native",
|
||||
&structs.ServiceNode{
|
||||
ServiceName: "dB",
|
||||
ServiceConnect: structs.ServiceConnect{Native: true},
|
||||
},
|
||||
true,
|
||||
[]byte("db\x00"),
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"proxy service",
|
||||
&structs.ServiceNode{
|
||||
ServiceKind: structs.ServiceKindConnectProxy,
|
||||
ServiceName: "db",
|
||||
ServiceProxyDestination: "fOo",
|
||||
},
|
||||
true,
|
||||
[]byte("foo\x00"),
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
var idx IndexConnectService
|
||||
match, val, err := idx.FromObject(tc.Input)
|
||||
if tc.ExpectErr != "" {
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), tc.ExpectErr)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
require.Equal(tc.ExpectMatch, match)
|
||||
require.Equal(tc.ExpectVal, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexConnectService_FromArgs(t *testing.T) {
|
||||
cases := []struct {
|
||||
Name string
|
||||
Args []interface{}
|
||||
ExpectVal []byte
|
||||
ExpectErr string
|
||||
}{
|
||||
{
|
||||
"multiple arguments",
|
||||
[]interface{}{"foo", "bar"},
|
||||
nil,
|
||||
"single",
|
||||
},
|
||||
|
||||
{
|
||||
"not a string",
|
||||
[]interface{}{42},
|
||||
nil,
|
||||
"must be a string",
|
||||
},
|
||||
|
||||
{
|
||||
"string",
|
||||
[]interface{}{"fOO"},
|
||||
[]byte("foo\x00"),
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
var idx IndexConnectService
|
||||
val, err := idx.FromArgs(tc.Args...)
|
||||
if tc.ExpectErr != "" {
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), tc.ExpectErr)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
require.Equal(tc.ExpectVal, val)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,366 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
const (
|
||||
intentionsTableName = "connect-intentions"
|
||||
)
|
||||
|
||||
// intentionsTableSchema returns a new table schema used for storing
|
||||
// intentions for Connect.
|
||||
func intentionsTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: intentionsTableName,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.UUIDFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
"destination": &memdb.IndexSchema{
|
||||
Name: "destination",
|
||||
AllowMissing: true,
|
||||
// This index is not unique since we need uniqueness across the whole
|
||||
// 4-tuple.
|
||||
Unique: false,
|
||||
Indexer: &memdb.CompoundIndex{
|
||||
Indexes: []memdb.Indexer{
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "DestinationNS",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "DestinationName",
|
||||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"source": &memdb.IndexSchema{
|
||||
Name: "source",
|
||||
AllowMissing: true,
|
||||
// This index is not unique since we need uniqueness across the whole
|
||||
// 4-tuple.
|
||||
Unique: false,
|
||||
Indexer: &memdb.CompoundIndex{
|
||||
Indexes: []memdb.Indexer{
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "SourceNS",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "SourceName",
|
||||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"source_destination": &memdb.IndexSchema{
|
||||
Name: "source_destination",
|
||||
AllowMissing: true,
|
||||
Unique: true,
|
||||
Indexer: &memdb.CompoundIndex{
|
||||
Indexes: []memdb.Indexer{
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "SourceNS",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "SourceName",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "DestinationNS",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "DestinationName",
|
||||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
registerSchema(intentionsTableSchema)
|
||||
}
|
||||
|
||||
// Intentions is used to pull all the intentions from the snapshot.
|
||||
func (s *Snapshot) Intentions() (structs.Intentions, error) {
|
||||
ixns, err := s.tx.Get(intentionsTableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret structs.Intentions
|
||||
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
|
||||
ret = append(ret, wrapped.(*structs.Intention))
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Intention is used when restoring from a snapshot.
|
||||
func (s *Restore) Intention(ixn *structs.Intention) error {
|
||||
// Insert the intention
|
||||
if err := s.tx.Insert(intentionsTableName, ixn); err != nil {
|
||||
return fmt.Errorf("failed restoring intention: %s", err)
|
||||
}
|
||||
if err := indexUpdateMaxTxn(s.tx, ixn.ModifyIndex, intentionsTableName); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Intentions returns the list of all intentions.
|
||||
func (s *Store) Intentions(ws memdb.WatchSet) (uint64, structs.Intentions, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the index
|
||||
idx := maxIndexTxn(tx, intentionsTableName)
|
||||
if idx < 1 {
|
||||
idx = 1
|
||||
}
|
||||
|
||||
// Get all intentions
|
||||
iter, err := tx.Get(intentionsTableName, "id")
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
var results structs.Intentions
|
||||
for ixn := iter.Next(); ixn != nil; ixn = iter.Next() {
|
||||
results = append(results, ixn.(*structs.Intention))
|
||||
}
|
||||
|
||||
// Sort by precedence just because that's nicer and probably what most clients
|
||||
// want for presentation.
|
||||
sort.Sort(structs.IntentionPrecedenceSorter(results))
|
||||
|
||||
return idx, results, nil
|
||||
}
|
||||
|
||||
// IntentionSet creates or updates an intention.
|
||||
func (s *Store) IntentionSet(idx uint64, ixn *structs.Intention) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.intentionSetTxn(tx, idx, ixn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// intentionSetTxn is the inner method used to insert an intention with
|
||||
// the proper indexes into the state store.
|
||||
func (s *Store) intentionSetTxn(tx *memdb.Txn, idx uint64, ixn *structs.Intention) error {
|
||||
// ID is required
|
||||
if ixn.ID == "" {
|
||||
return ErrMissingIntentionID
|
||||
}
|
||||
|
||||
// Ensure Precedence is populated correctly on "write"
|
||||
ixn.UpdatePrecedence()
|
||||
|
||||
// Check for an existing intention
|
||||
existing, err := tx.First(intentionsTableName, "id", ixn.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
if existing != nil {
|
||||
oldIxn := existing.(*structs.Intention)
|
||||
ixn.CreateIndex = oldIxn.CreateIndex
|
||||
ixn.CreatedAt = oldIxn.CreatedAt
|
||||
} else {
|
||||
ixn.CreateIndex = idx
|
||||
}
|
||||
ixn.ModifyIndex = idx
|
||||
|
||||
// Check for duplicates on the 4-tuple.
|
||||
duplicate, err := tx.First(intentionsTableName, "source_destination",
|
||||
ixn.SourceNS, ixn.SourceName, ixn.DestinationNS, ixn.DestinationName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
if duplicate != nil {
|
||||
dupIxn := duplicate.(*structs.Intention)
|
||||
// Same ID is OK - this is an update
|
||||
if dupIxn.ID != ixn.ID {
|
||||
return fmt.Errorf("duplicate intention found: %s", dupIxn.String())
|
||||
}
|
||||
}
|
||||
|
||||
// We always force meta to be non-nil so that we its an empty map.
|
||||
// This makes it easy for API responses to not nil-check this everywhere.
|
||||
if ixn.Meta == nil {
|
||||
ixn.Meta = make(map[string]string)
|
||||
}
|
||||
|
||||
// Insert
|
||||
if err := tx.Insert(intentionsTableName, ixn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Insert("index", &IndexEntry{intentionsTableName, idx}); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IntentionGet returns the given intention by ID.
|
||||
func (s *Store) IntentionGet(ws memdb.WatchSet, id string) (uint64, *structs.Intention, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexTxn(tx, intentionsTableName)
|
||||
if idx < 1 {
|
||||
idx = 1
|
||||
}
|
||||
|
||||
// Look up by its ID.
|
||||
watchCh, intention, err := tx.FirstWatch(intentionsTableName, "id", id)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
ws.Add(watchCh)
|
||||
|
||||
// Convert the interface{} if it is non-nil
|
||||
var result *structs.Intention
|
||||
if intention != nil {
|
||||
result = intention.(*structs.Intention)
|
||||
}
|
||||
|
||||
return idx, result, nil
|
||||
}
|
||||
|
||||
// IntentionDelete deletes the given intention by ID.
|
||||
func (s *Store) IntentionDelete(idx uint64, id string) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.intentionDeleteTxn(tx, idx, id); err != nil {
|
||||
return fmt.Errorf("failed intention delete: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// intentionDeleteTxn is the inner method used to delete a intention
|
||||
// with the proper indexes into the state store.
|
||||
func (s *Store) intentionDeleteTxn(tx *memdb.Txn, idx uint64, queryID string) error {
|
||||
// Pull the query.
|
||||
wrapped, err := tx.First(intentionsTableName, "id", queryID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
if wrapped == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete the query and update the index.
|
||||
if err := tx.Delete(intentionsTableName, wrapped); err != nil {
|
||||
return fmt.Errorf("failed intention delete: %s", err)
|
||||
}
|
||||
if err := tx.Insert("index", &IndexEntry{intentionsTableName, idx}); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IntentionMatch returns the list of intentions that match the namespace and
|
||||
// name for either a source or destination. This applies the resolution rules
|
||||
// so wildcards will match any value.
|
||||
//
|
||||
// The returned value is the list of intentions in the same order as the
|
||||
// entries in args. The intentions themselves are sorted based on the
|
||||
// intention precedence rules. i.e. result[0][0] is the highest precedent
|
||||
// rule to match for the first entry.
|
||||
func (s *Store) IntentionMatch(ws memdb.WatchSet, args *structs.IntentionQueryMatch) (uint64, []structs.Intentions, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexTxn(tx, intentionsTableName)
|
||||
if idx < 1 {
|
||||
idx = 1
|
||||
}
|
||||
|
||||
// Make all the calls and accumulate the results
|
||||
results := make([]structs.Intentions, len(args.Entries))
|
||||
for i, entry := range args.Entries {
|
||||
// Each search entry may require multiple queries to memdb, so this
|
||||
// returns the arguments for each necessary Get. Note on performance:
|
||||
// this is not the most optimal set of queries since we repeat some
|
||||
// many times (such as */*). We can work on improving that in the
|
||||
// future, the test cases shouldn't have to change for that.
|
||||
getParams, err := s.intentionMatchGetParams(entry)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
// Perform each call and accumulate the result.
|
||||
var ixns structs.Intentions
|
||||
for _, params := range getParams {
|
||||
iter, err := tx.Get(intentionsTableName, string(args.Type), params...)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
for ixn := iter.Next(); ixn != nil; ixn = iter.Next() {
|
||||
ixns = append(ixns, ixn.(*structs.Intention))
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the results by precedence
|
||||
sort.Sort(structs.IntentionPrecedenceSorter(ixns))
|
||||
|
||||
// Store the result
|
||||
results[i] = ixns
|
||||
}
|
||||
|
||||
return idx, results, nil
|
||||
}
|
||||
|
||||
// intentionMatchGetParams returns the tx.Get parameters to find all the
|
||||
// intentions for a certain entry.
|
||||
func (s *Store) intentionMatchGetParams(entry structs.IntentionMatchEntry) ([][]interface{}, error) {
|
||||
// We always query for "*/*" so include that. If the namespace is a
|
||||
// wildcard, then we're actually done.
|
||||
result := make([][]interface{}, 0, 3)
|
||||
result = append(result, []interface{}{"*", "*"})
|
||||
if entry.Namespace == structs.IntentionWildcard {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Search for NS/* intentions. If we have a wildcard name, then we're done.
|
||||
result = append(result, []interface{}{entry.Namespace, "*"})
|
||||
if entry.Name == structs.IntentionWildcard {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Search for the exact NS/N value.
|
||||
result = append(result, []interface{}{entry.Namespace, entry.Name})
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,559 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStore_IntentionGet_none(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Querying with no results returns nil.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.IntentionGet(ws, testUUID())
|
||||
assert.Equal(uint64(1), idx)
|
||||
assert.Nil(res)
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
func TestStore_IntentionSetGet_basic(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call Get to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.IntentionGet(ws, testUUID())
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid intention
|
||||
ixn := &structs.Intention{
|
||||
ID: testUUID(),
|
||||
SourceNS: "default",
|
||||
SourceName: "*",
|
||||
DestinationNS: "default",
|
||||
DestinationName: "web",
|
||||
Meta: map[string]string{},
|
||||
}
|
||||
|
||||
// Inserting a with empty ID is disallowed.
|
||||
assert.NoError(s.IntentionSet(1, ixn))
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(uint64(1), s.maxIndex(intentionsTableName))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back out and verify it.
|
||||
expected := &structs.Intention{
|
||||
ID: ixn.ID,
|
||||
SourceNS: "default",
|
||||
SourceName: "*",
|
||||
DestinationNS: "default",
|
||||
DestinationName: "web",
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 1,
|
||||
ModifyIndex: 1,
|
||||
},
|
||||
}
|
||||
expected.UpdatePrecedence()
|
||||
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, actual, err := s.IntentionGet(ws, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(expected.CreateIndex, idx)
|
||||
assert.Equal(expected, actual)
|
||||
|
||||
// Change a value and test updating
|
||||
ixn.SourceNS = "foo"
|
||||
assert.NoError(s.IntentionSet(2, ixn))
|
||||
|
||||
// Change a value that isn't in the unique 4 tuple and check we don't
|
||||
// incorrectly consider this a duplicate when updating.
|
||||
ixn.Action = structs.IntentionActionDeny
|
||||
assert.NoError(s.IntentionSet(2, ixn))
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(uint64(2), s.maxIndex(intentionsTableName))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back and verify the data was updated
|
||||
expected.SourceNS = ixn.SourceNS
|
||||
expected.Action = structs.IntentionActionDeny
|
||||
expected.ModifyIndex = 2
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, actual, err = s.IntentionGet(ws, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(expected.ModifyIndex, idx)
|
||||
assert.Equal(expected, actual)
|
||||
|
||||
// Attempt to insert another intention with duplicate 4-tuple
|
||||
ixn = &structs.Intention{
|
||||
ID: testUUID(),
|
||||
SourceNS: "default",
|
||||
SourceName: "*",
|
||||
DestinationNS: "default",
|
||||
DestinationName: "web",
|
||||
Meta: map[string]string{},
|
||||
}
|
||||
|
||||
// Duplicate 4-tuple should cause an error
|
||||
ws = memdb.NewWatchSet()
|
||||
assert.Error(s.IntentionSet(3, ixn))
|
||||
|
||||
// Make sure the index did NOT get updated.
|
||||
assert.Equal(uint64(2), s.maxIndex(intentionsTableName))
|
||||
assert.False(watchFired(ws), "watch not fired")
|
||||
}
|
||||
|
||||
func TestStore_IntentionSet_emptyId(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.IntentionGet(ws, testUUID())
|
||||
assert.NoError(err)
|
||||
|
||||
// Inserting a with empty ID is disallowed.
|
||||
err = s.IntentionSet(1, &structs.Intention{})
|
||||
assert.Error(err)
|
||||
assert.Contains(err.Error(), ErrMissingIntentionID.Error())
|
||||
|
||||
// Index is not updated if nothing is saved.
|
||||
assert.Equal(s.maxIndex(intentionsTableName), uint64(0))
|
||||
assert.False(watchFired(ws), "watch fired")
|
||||
}
|
||||
|
||||
func TestStore_IntentionSet_updateCreatedAt(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Build a valid intention
|
||||
now := time.Now()
|
||||
ixn := structs.Intention{
|
||||
ID: testUUID(),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
// Insert
|
||||
assert.NoError(s.IntentionSet(1, &ixn))
|
||||
|
||||
// Change a value and test updating
|
||||
ixnUpdate := ixn
|
||||
ixnUpdate.CreatedAt = now.Add(10 * time.Second)
|
||||
assert.NoError(s.IntentionSet(2, &ixnUpdate))
|
||||
|
||||
// Read it back and verify
|
||||
_, actual, err := s.IntentionGet(nil, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(now, actual.CreatedAt)
|
||||
}
|
||||
|
||||
func TestStore_IntentionSet_metaNil(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Build a valid intention
|
||||
ixn := structs.Intention{
|
||||
ID: testUUID(),
|
||||
}
|
||||
|
||||
// Insert
|
||||
assert.NoError(s.IntentionSet(1, &ixn))
|
||||
|
||||
// Read it back and verify
|
||||
_, actual, err := s.IntentionGet(nil, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(actual.Meta)
|
||||
}
|
||||
|
||||
func TestStore_IntentionSet_metaSet(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Build a valid intention
|
||||
ixn := structs.Intention{
|
||||
ID: testUUID(),
|
||||
Meta: map[string]string{"foo": "bar"},
|
||||
}
|
||||
|
||||
// Insert
|
||||
assert.NoError(s.IntentionSet(1, &ixn))
|
||||
|
||||
// Read it back and verify
|
||||
_, actual, err := s.IntentionGet(nil, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(ixn.Meta, actual.Meta)
|
||||
}
|
||||
|
||||
func TestStore_IntentionDelete(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call Get to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.IntentionGet(ws, testUUID())
|
||||
assert.NoError(err)
|
||||
|
||||
// Create
|
||||
ixn := &structs.Intention{ID: testUUID()}
|
||||
assert.NoError(s.IntentionSet(1, ixn))
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(s.maxIndex(intentionsTableName), uint64(1))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Delete
|
||||
assert.NoError(s.IntentionDelete(2, ixn.ID))
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(s.maxIndex(intentionsTableName), uint64(2))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Sanity check to make sure it's not there.
|
||||
idx, actual, err := s.IntentionGet(nil, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(2))
|
||||
assert.Nil(actual)
|
||||
}
|
||||
|
||||
func TestStore_IntentionsList(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Querying with no results returns nil.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.Intentions(ws)
|
||||
assert.NoError(err)
|
||||
assert.Nil(res)
|
||||
assert.Equal(uint64(1), idx)
|
||||
|
||||
// Create some intentions
|
||||
ixns := structs.Intentions{
|
||||
&structs.Intention{
|
||||
ID: testUUID(),
|
||||
Meta: map[string]string{},
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: testUUID(),
|
||||
Meta: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
// Force deterministic sort order
|
||||
ixns[0].ID = "a" + ixns[0].ID[1:]
|
||||
ixns[1].ID = "b" + ixns[1].ID[1:]
|
||||
|
||||
// Create
|
||||
for i, ixn := range ixns {
|
||||
assert.NoError(s.IntentionSet(uint64(1+i), ixn))
|
||||
}
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back and verify.
|
||||
expected := structs.Intentions{
|
||||
&structs.Intention{
|
||||
ID: ixns[0].ID,
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 1,
|
||||
ModifyIndex: 1,
|
||||
},
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: ixns[1].ID,
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 2,
|
||||
ModifyIndex: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range expected {
|
||||
expected[i].UpdatePrecedence() // to match what is returned...
|
||||
}
|
||||
idx, actual, err := s.Intentions(nil)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(2))
|
||||
assert.Equal(expected, actual)
|
||||
}
|
||||
|
||||
// Test the matrix of match logic.
|
||||
//
|
||||
// Note that this doesn't need to test the intention sort logic exhaustively
|
||||
// since this is tested in their sort implementation in the structs.
|
||||
func TestStore_IntentionMatch_table(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Insert [][]string // List of intentions to insert
|
||||
Query [][]string // List of intentions to match
|
||||
Expected [][][]string // List of matches, where each match is a list of intentions
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
"single exact namespace/name",
|
||||
[][]string{
|
||||
{"foo", "*"},
|
||||
{"foo", "bar"},
|
||||
{"foo", "baz"}, // shouldn't match
|
||||
{"bar", "bar"}, // shouldn't match
|
||||
{"bar", "*"}, // shouldn't match
|
||||
{"*", "*"},
|
||||
},
|
||||
[][]string{
|
||||
{"foo", "bar"},
|
||||
},
|
||||
[][][]string{
|
||||
{
|
||||
{"foo", "bar"},
|
||||
{"foo", "*"},
|
||||
{"*", "*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
"multiple exact namespace/name",
|
||||
[][]string{
|
||||
{"foo", "*"},
|
||||
{"foo", "bar"},
|
||||
{"foo", "baz"}, // shouldn't match
|
||||
{"bar", "bar"},
|
||||
{"bar", "*"},
|
||||
},
|
||||
[][]string{
|
||||
{"foo", "bar"},
|
||||
{"bar", "bar"},
|
||||
},
|
||||
[][][]string{
|
||||
{
|
||||
{"foo", "bar"},
|
||||
{"foo", "*"},
|
||||
},
|
||||
{
|
||||
{"bar", "bar"},
|
||||
{"bar", "*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
"single exact namespace/name with duplicate destinations",
|
||||
[][]string{
|
||||
// 4-tuple specifies src and destination to test duplicate destinations
|
||||
// with different sources. We flip them around to test in both
|
||||
// directions. The first pair are the ones searched on in both cases so
|
||||
// the duplicates need to be there.
|
||||
{"foo", "bar", "foo", "*"},
|
||||
{"foo", "bar", "bar", "*"},
|
||||
{"*", "*", "*", "*"},
|
||||
},
|
||||
[][]string{
|
||||
{"foo", "bar"},
|
||||
},
|
||||
[][][]string{
|
||||
{
|
||||
// Note the first two have the same precedence so we rely on arbitrary
|
||||
// lexicographical tie-break behaviour.
|
||||
{"foo", "bar", "bar", "*"},
|
||||
{"foo", "bar", "foo", "*"},
|
||||
{"*", "*", "*", "*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// testRunner implements the test for a single case, but can be
|
||||
// parameterized to run for both source and destination so we can
|
||||
// test both cases.
|
||||
testRunner := func(t *testing.T, tc testCase, typ structs.IntentionMatchType) {
|
||||
// Insert the set
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
var idx uint64 = 1
|
||||
for _, v := range tc.Insert {
|
||||
ixn := &structs.Intention{ID: testUUID()}
|
||||
switch typ {
|
||||
case structs.IntentionMatchDestination:
|
||||
ixn.DestinationNS = v[0]
|
||||
ixn.DestinationName = v[1]
|
||||
if len(v) == 4 {
|
||||
ixn.SourceNS = v[2]
|
||||
ixn.SourceName = v[3]
|
||||
}
|
||||
case structs.IntentionMatchSource:
|
||||
ixn.SourceNS = v[0]
|
||||
ixn.SourceName = v[1]
|
||||
if len(v) == 4 {
|
||||
ixn.DestinationNS = v[2]
|
||||
ixn.DestinationName = v[3]
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(s.IntentionSet(idx, ixn))
|
||||
|
||||
idx++
|
||||
}
|
||||
|
||||
// Build the arguments
|
||||
args := &structs.IntentionQueryMatch{Type: typ}
|
||||
for _, q := range tc.Query {
|
||||
args.Entries = append(args.Entries, structs.IntentionMatchEntry{
|
||||
Namespace: q[0],
|
||||
Name: q[1],
|
||||
})
|
||||
}
|
||||
|
||||
// Match
|
||||
_, matches, err := s.IntentionMatch(nil, args)
|
||||
assert.NoError(err)
|
||||
|
||||
// Should have equal lengths
|
||||
require.Len(t, matches, len(tc.Expected))
|
||||
|
||||
// Verify matches
|
||||
for i, expected := range tc.Expected {
|
||||
var actual [][]string
|
||||
for _, ixn := range matches[i] {
|
||||
switch typ {
|
||||
case structs.IntentionMatchDestination:
|
||||
if len(expected) > 1 && len(expected[0]) == 4 {
|
||||
actual = append(actual, []string{
|
||||
ixn.DestinationNS,
|
||||
ixn.DestinationName,
|
||||
ixn.SourceNS,
|
||||
ixn.SourceName,
|
||||
})
|
||||
} else {
|
||||
actual = append(actual, []string{ixn.DestinationNS, ixn.DestinationName})
|
||||
}
|
||||
case structs.IntentionMatchSource:
|
||||
if len(expected) > 1 && len(expected[0]) == 4 {
|
||||
actual = append(actual, []string{
|
||||
ixn.SourceNS,
|
||||
ixn.SourceName,
|
||||
ixn.DestinationNS,
|
||||
ixn.DestinationName,
|
||||
})
|
||||
} else {
|
||||
actual = append(actual, []string{ixn.SourceNS, ixn.SourceName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name+" (destination)", func(t *testing.T) {
|
||||
testRunner(t, tc, structs.IntentionMatchDestination)
|
||||
})
|
||||
|
||||
t.Run(tc.Name+" (source)", func(t *testing.T) {
|
||||
testRunner(t, tc, structs.IntentionMatchSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Intention_Snapshot_Restore(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Create some intentions.
|
||||
ixns := structs.Intentions{
|
||||
&structs.Intention{
|
||||
DestinationName: "foo",
|
||||
},
|
||||
&structs.Intention{
|
||||
DestinationName: "bar",
|
||||
},
|
||||
&structs.Intention{
|
||||
DestinationName: "baz",
|
||||
},
|
||||
}
|
||||
|
||||
// Force the sort order of the UUIDs before we create them so the
|
||||
// order is deterministic.
|
||||
id := testUUID()
|
||||
ixns[0].ID = "a" + id[1:]
|
||||
ixns[1].ID = "b" + id[1:]
|
||||
ixns[2].ID = "c" + id[1:]
|
||||
|
||||
// Now create
|
||||
for i, ixn := range ixns {
|
||||
assert.NoError(s.IntentionSet(uint64(4+i), ixn))
|
||||
}
|
||||
|
||||
// Snapshot the queries.
|
||||
snap := s.Snapshot()
|
||||
defer snap.Close()
|
||||
|
||||
// Alter the real state store.
|
||||
assert.NoError(s.IntentionDelete(7, ixns[0].ID))
|
||||
|
||||
// Verify the snapshot.
|
||||
assert.Equal(snap.LastIndex(), uint64(6))
|
||||
|
||||
// Expect them sorted in insertion order
|
||||
expected := structs.Intentions{
|
||||
&structs.Intention{
|
||||
ID: ixns[0].ID,
|
||||
DestinationName: "foo",
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 4,
|
||||
ModifyIndex: 4,
|
||||
},
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: ixns[1].ID,
|
||||
DestinationName: "bar",
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 5,
|
||||
ModifyIndex: 5,
|
||||
},
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: ixns[2].ID,
|
||||
DestinationName: "baz",
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 6,
|
||||
ModifyIndex: 6,
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range expected {
|
||||
expected[i].UpdatePrecedence() // to match what is returned...
|
||||
}
|
||||
dump, err := snap.Intentions()
|
||||
assert.NoError(err)
|
||||
assert.Equal(expected, dump)
|
||||
|
||||
// Restore the values into a new state store.
|
||||
func() {
|
||||
s := testStateStore(t)
|
||||
restore := s.Restore()
|
||||
for _, ixn := range dump {
|
||||
assert.NoError(restore.Intention(ixn))
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
// Read the restored values back out and verify that they match. Note that
|
||||
// Intentions are returned precedence sorted unlike the snapshot so we need
|
||||
// to rearrange the expected slice some.
|
||||
expected[0], expected[1], expected[2] = expected[1], expected[2], expected[0]
|
||||
idx, actual, err := s.Intentions(nil)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(6))
|
||||
assert.Equal(expected, actual)
|
||||
}()
|
||||
}
|
|
@ -28,6 +28,14 @@ var (
|
|||
// ErrMissingQueryID is returned when a Query set is called on
|
||||
// a Query with an empty ID.
|
||||
ErrMissingQueryID = errors.New("Missing Query ID")
|
||||
|
||||
// ErrMissingCARootID is returned when an CARoot set is called
|
||||
// with an CARoot with an empty ID.
|
||||
ErrMissingCARootID = errors.New("Missing CA Root ID")
|
||||
|
||||
// ErrMissingIntentionID is returned when an Intention set is called
|
||||
// with an Intention with an empty ID.
|
||||
ErrMissingIntentionID = errors.New("Missing Intention ID")
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
54
agent/dns.go
54
agent/dns.go
|
@ -51,6 +51,7 @@ type dnsConfig struct {
|
|||
ServiceTTL map[string]time.Duration
|
||||
UDPAnswerLimit int
|
||||
ARecordLimit int
|
||||
NodeMetaTXT bool
|
||||
}
|
||||
|
||||
// DNSServer is used to wrap an Agent and expose various
|
||||
|
@ -109,6 +110,7 @@ func GetDNSConfig(conf *config.RuntimeConfig) *dnsConfig {
|
|||
SegmentName: conf.SegmentName,
|
||||
ServiceTTL: conf.DNSServiceTTL,
|
||||
UDPAnswerLimit: conf.DNSUDPAnswerLimit,
|
||||
NodeMetaTXT: conf.DNSNodeMetaTXT,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -337,7 +339,7 @@ func (d *DNSServer) addSOA(msg *dns.Msg) {
|
|||
// nameservers returns the names and ip addresses of up to three random servers
|
||||
// in the current cluster which serve as authoritative name servers for zone.
|
||||
func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
|
||||
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "")
|
||||
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false)
|
||||
if err != nil {
|
||||
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
|
||||
return nil, nil
|
||||
|
@ -374,7 +376,7 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
|
|||
}
|
||||
ns = append(ns, nsrr)
|
||||
|
||||
glue := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns)
|
||||
glue := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, false)
|
||||
extra = append(extra, glue...)
|
||||
|
||||
// don't provide more than 3 servers
|
||||
|
@ -415,7 +417,7 @@ PARSE:
|
|||
n = n + 1
|
||||
}
|
||||
|
||||
switch labels[n-1] {
|
||||
switch kind := labels[n-1]; kind {
|
||||
case "service":
|
||||
if n == 1 {
|
||||
goto INVALID
|
||||
|
@ -433,7 +435,7 @@ PARSE:
|
|||
}
|
||||
|
||||
// _name._tag.service.consul
|
||||
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, req, resp)
|
||||
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp)
|
||||
|
||||
// Consul 0.3 and prior format for SRV queries
|
||||
} else {
|
||||
|
@ -445,9 +447,17 @@ PARSE:
|
|||
}
|
||||
|
||||
// tag[.tag].name.service.consul
|
||||
d.serviceLookup(network, datacenter, labels[n-2], tag, req, resp)
|
||||
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp)
|
||||
}
|
||||
|
||||
case "connect":
|
||||
if n == 1 {
|
||||
goto INVALID
|
||||
}
|
||||
|
||||
// name.connect.consul
|
||||
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp)
|
||||
|
||||
case "node":
|
||||
if n == 1 {
|
||||
goto INVALID
|
||||
|
@ -582,7 +592,7 @@ RPC:
|
|||
n := out.NodeServices.Node
|
||||
edns := req.IsEdns0() != nil
|
||||
addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses)
|
||||
records := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns)
|
||||
records := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, true)
|
||||
if records != nil {
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
}
|
||||
|
@ -610,7 +620,7 @@ func encodeKVasRFC1464(key, value string) (txt string) {
|
|||
}
|
||||
|
||||
// formatNodeRecord takes a Node and returns an A, AAAA, TXT or CNAME record
|
||||
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) {
|
||||
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns, answer bool) (records []dns.RR) {
|
||||
// Parse the IP
|
||||
ip := net.ParseIP(addr)
|
||||
var ipv4 net.IP
|
||||
|
@ -671,7 +681,20 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
|
|||
}
|
||||
}
|
||||
|
||||
if node != nil && (qType == dns.TypeANY || qType == dns.TypeTXT) {
|
||||
node_meta_txt := false
|
||||
|
||||
if node == nil {
|
||||
node_meta_txt = false
|
||||
} else if answer {
|
||||
node_meta_txt = true
|
||||
} else {
|
||||
// Use configuration when the TXT RR would
|
||||
// end up in the Additional section of the
|
||||
// DNS response
|
||||
node_meta_txt = d.config.NodeMetaTXT
|
||||
}
|
||||
|
||||
if node_meta_txt {
|
||||
for key, value := range node.Meta {
|
||||
txt := value
|
||||
if !strings.HasPrefix(strings.ToLower(key), "rfc1035-") {
|
||||
|
@ -782,8 +805,8 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
|
|||
originalNumRecords := len(resp.Answer)
|
||||
|
||||
// It is not possible to return more than 4k records even with compression
|
||||
// Since we are performing binary search it is not a big deal, but it
|
||||
// improves a bit performance, even with binary search
|
||||
// Since we are performing binary search it is not a big deal, but it
|
||||
// improves a bit performance, even with binary search
|
||||
truncateAt := 4096
|
||||
if req.Question[0].Qtype == dns.TypeSRV {
|
||||
// More than 1024 SRV records do not fit in 64k
|
||||
|
@ -898,8 +921,9 @@ func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed
|
|||
}
|
||||
|
||||
// lookupServiceNodes returns nodes with a given service.
|
||||
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs.IndexedCheckServiceNodes, error) {
|
||||
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool) (structs.IndexedCheckServiceNodes, error) {
|
||||
args := structs.ServiceSpecificRequest{
|
||||
Connect: connect,
|
||||
Datacenter: datacenter,
|
||||
ServiceName: service,
|
||||
ServiceTag: tag,
|
||||
|
@ -935,8 +959,8 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs
|
|||
}
|
||||
|
||||
// serviceLookup is used to handle a service query
|
||||
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) {
|
||||
out, err := d.lookupServiceNodes(datacenter, service, tag)
|
||||
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) {
|
||||
out, err := d.lookupServiceNodes(datacenter, service, tag, connect)
|
||||
if err != nil {
|
||||
d.logger.Printf("[ERR] dns: rpc error: %v", err)
|
||||
resp.SetRcode(req, dns.RcodeServerFailure)
|
||||
|
@ -1143,7 +1167,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
|
|||
handled[addr] = struct{}{}
|
||||
|
||||
// Add the node record
|
||||
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns)
|
||||
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, true)
|
||||
if records != nil {
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
count++
|
||||
|
@ -1192,7 +1216,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
|
|||
}
|
||||
|
||||
// Add the extra record
|
||||
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns)
|
||||
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, false)
|
||||
if len(records) > 0 {
|
||||
// Use the node address if it doesn't differ from the service address
|
||||
if addr == node.Node.Address {
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -472,6 +473,51 @@ func TestDNS_NodeLookup_TXT(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDNS_NodeLookup_TXT_DontSuppress(t *testing.T) {
|
||||
a := NewTestAgent(t.Name(), `dns_config = { enable_additional_node_meta_txt = false }`)
|
||||
defer a.Shutdown()
|
||||
|
||||
args := &structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: "google",
|
||||
Address: "127.0.0.1",
|
||||
NodeMeta: map[string]string{
|
||||
"rfc1035-00": "value0",
|
||||
"key0": "value1",
|
||||
},
|
||||
}
|
||||
|
||||
var out struct{}
|
||||
if err := a.RPC("Catalog.Register", args, &out); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("google.node.consul.", dns.TypeTXT)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, a.DNSAddr())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Should have the 1 TXT record reply
|
||||
if len(in.Answer) != 2 {
|
||||
t.Fatalf("Bad: %#v", in)
|
||||
}
|
||||
|
||||
txtRec, ok := in.Answer[0].(*dns.TXT)
|
||||
if !ok {
|
||||
t.Fatalf("Bad: %#v", in.Answer[0])
|
||||
}
|
||||
if len(txtRec.Txt) != 1 {
|
||||
t.Fatalf("Bad: %#v", in.Answer[0])
|
||||
}
|
||||
if txtRec.Txt[0] != "value0" && txtRec.Txt[0] != "key0=value1" {
|
||||
t.Fatalf("Bad: %#v", in.Answer[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_NodeLookup_ANY(t *testing.T) {
|
||||
a := NewTestAgent(t.Name(), ``)
|
||||
defer a.Shutdown()
|
||||
|
@ -510,7 +556,46 @@ func TestDNS_NodeLookup_ANY(t *testing.T) {
|
|||
},
|
||||
}
|
||||
verify.Values(t, "answer", in.Answer, wantAnswer)
|
||||
}
|
||||
|
||||
func TestDNS_NodeLookup_ANY_DontSuppressTXT(t *testing.T) {
|
||||
a := NewTestAgent(t.Name(), `dns_config = { enable_additional_node_meta_txt = false }`)
|
||||
defer a.Shutdown()
|
||||
|
||||
args := &structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: "bar",
|
||||
Address: "127.0.0.1",
|
||||
NodeMeta: map[string]string{
|
||||
"key": "value",
|
||||
},
|
||||
}
|
||||
|
||||
var out struct{}
|
||||
if err := a.RPC("Catalog.Register", args, &out); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("bar.node.consul.", dns.TypeANY)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, a.DNSAddr())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
wantAnswer := []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "bar.node.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4},
|
||||
A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1
|
||||
},
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: "bar.node.consul.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Rdlength: 0xa},
|
||||
Txt: []string{"key=value"},
|
||||
},
|
||||
}
|
||||
verify.Values(t, "answer", in.Answer, wantAnswer)
|
||||
}
|
||||
|
||||
func TestDNS_EDNS0(t *testing.T) {
|
||||
|
@ -1041,6 +1126,51 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) {
|
|||
verify.Values(t, "extra", in.Extra, wantExtra)
|
||||
}
|
||||
|
||||
func TestDNS_ConnectServiceLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
{
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Address = "127.0.0.55"
|
||||
args.Service.ProxyDestination = "db"
|
||||
args.Service.Address = ""
|
||||
args.Service.Port = 12345
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
}
|
||||
|
||||
// Look up the service
|
||||
questions := []string{
|
||||
"db.connect.consul.",
|
||||
}
|
||||
for _, question := range questions {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(question, dns.TypeSRV)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, a.DNSAddr())
|
||||
assert.Nil(err)
|
||||
assert.Len(in.Answer, 1)
|
||||
|
||||
srvRec, ok := in.Answer[0].(*dns.SRV)
|
||||
assert.True(ok)
|
||||
assert.Equal(uint16(12345), srvRec.Port)
|
||||
assert.Equal("foo.node.dc1.consul.", srvRec.Target)
|
||||
assert.Equal(uint32(0), srvRec.Hdr.Ttl)
|
||||
|
||||
cnameRec, ok := in.Extra[0].(*dns.A)
|
||||
assert.True(ok)
|
||||
assert.Equal("foo.node.dc1.consul.", cnameRec.Hdr.Name)
|
||||
assert.Equal(uint32(0), srvRec.Hdr.Ttl)
|
||||
assert.Equal("127.0.0.55", cnameRec.A.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_ExternalServiceLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
|
@ -4613,6 +4743,93 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDNS_ServiceLookup_MetaTXT(t *testing.T) {
|
||||
a := NewTestAgent(t.Name(), `dns_config = { enable_additional_node_meta_txt = true }`)
|
||||
defer a.Shutdown()
|
||||
|
||||
args := &structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: "bar",
|
||||
Address: "127.0.0.1",
|
||||
NodeMeta: map[string]string{
|
||||
"key": "value",
|
||||
},
|
||||
Service: &structs.NodeService{
|
||||
Service: "db",
|
||||
Tags: []string{"master"},
|
||||
Port: 12345,
|
||||
},
|
||||
}
|
||||
|
||||
var out struct{}
|
||||
if err := a.RPC("Catalog.Register", args, &out); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("db.service.consul.", dns.TypeSRV)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, a.DNSAddr())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
wantAdditional := []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4},
|
||||
A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1
|
||||
},
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Rdlength: 0xa},
|
||||
Txt: []string{"key=value"},
|
||||
},
|
||||
}
|
||||
verify.Values(t, "additional", in.Extra, wantAdditional)
|
||||
}
|
||||
|
||||
func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) {
|
||||
a := NewTestAgent(t.Name(), `dns_config = { enable_additional_node_meta_txt = false }`)
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register a node with a service.
|
||||
args := &structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: "bar",
|
||||
Address: "127.0.0.1",
|
||||
NodeMeta: map[string]string{
|
||||
"key": "value",
|
||||
},
|
||||
Service: &structs.NodeService{
|
||||
Service: "db",
|
||||
Tags: []string{"master"},
|
||||
Port: 12345,
|
||||
},
|
||||
}
|
||||
|
||||
var out struct{}
|
||||
if err := a.RPC("Catalog.Register", args, &out); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("db.service.consul.", dns.TypeSRV)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, a.DNSAddr())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
wantAdditional := []dns.RR{
|
||||
&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4},
|
||||
A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1
|
||||
},
|
||||
}
|
||||
verify.Values(t, "additional", in.Extra, wantAdditional)
|
||||
}
|
||||
|
||||
func TestDNS_AddressLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
|
|
|
@ -143,9 +143,17 @@ RETRY_ONCE:
|
|||
return out.HealthChecks, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) HealthConnectServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
return s.healthServiceNodes(resp, req, true)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
return s.healthServiceNodes(resp, req, false)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) healthServiceNodes(resp http.ResponseWriter, req *http.Request, connect bool) (interface{}, error) {
|
||||
// Set default DC
|
||||
args := structs.ServiceSpecificRequest{}
|
||||
args := structs.ServiceSpecificRequest{Connect: connect}
|
||||
s.parseSource(req, &args.Source)
|
||||
args.NodeMetaFilters = s.parseMetaFilter(req)
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
|
@ -159,8 +167,14 @@ func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Requ
|
|||
args.TagFilter = true
|
||||
}
|
||||
|
||||
// Determine the prefix
|
||||
prefix := "/v1/health/service/"
|
||||
if connect {
|
||||
prefix = "/v1/health/connect/"
|
||||
}
|
||||
|
||||
// Pull out the service name
|
||||
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/health/service/")
|
||||
args.ServiceName = strings.TrimPrefix(req.URL.Path, prefix)
|
||||
if args.ServiceName == "" {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Missing service name")
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHealthChecksInState(t *testing.T) {
|
||||
|
@ -770,6 +771,105 @@ func TestHealthServiceNodes_WanTranslation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHealthConnectServiceNodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?dc=dc1", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
// Should be a non-nil empty list for checks
|
||||
nodes := obj.(structs.CheckServiceNodes)
|
||||
assert.Len(nodes, 1)
|
||||
assert.Len(nodes[0].Checks, 0)
|
||||
}
|
||||
|
||||
func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Check = &structs.HealthCheck{
|
||||
Node: args.Node,
|
||||
Name: "check",
|
||||
ServiceID: args.Service.Service,
|
||||
Status: api.HealthCritical,
|
||||
}
|
||||
var out struct{}
|
||||
assert.Nil(t, a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
t.Run("bc_no_query_value", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?passing", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
// Should be 0 health check for consul
|
||||
nodes := obj.(structs.CheckServiceNodes)
|
||||
assert.Len(nodes, 0)
|
||||
})
|
||||
|
||||
t.Run("passing_true", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?passing=true", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
// Should be 0 health check for consul
|
||||
nodes := obj.(structs.CheckServiceNodes)
|
||||
assert.Len(nodes, 0)
|
||||
})
|
||||
|
||||
t.Run("passing_false", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?passing=false", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
// Should be 1
|
||||
nodes := obj.(structs.CheckServiceNodes)
|
||||
assert.Len(nodes, 1)
|
||||
})
|
||||
|
||||
t.Run("passing_bad", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?passing=nope-nope", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Equal(400, resp.Code)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.Nil(err)
|
||||
assert.True(bytes.Contains(body, []byte("Invalid value for ?passing")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterNonPassing(t *testing.T) {
|
||||
t.Parallel()
|
||||
nodes := structs.CheckServiceNodes{
|
||||
|
|
|
@ -3,6 +3,7 @@ package agent
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
@ -157,9 +159,9 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler {
|
|||
}
|
||||
|
||||
if s.IsUIEnabled() {
|
||||
new_ui, err := strconv.ParseBool(os.Getenv("CONSUL_UI_BETA"))
|
||||
legacy_ui, err := strconv.ParseBool(os.Getenv("CONSUL_UI_LEGACY"))
|
||||
if err != nil {
|
||||
new_ui = false
|
||||
legacy_ui = false
|
||||
}
|
||||
var uifs http.FileSystem
|
||||
|
||||
|
@ -169,15 +171,15 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler {
|
|||
} else {
|
||||
fs := assetFS()
|
||||
|
||||
if new_ui {
|
||||
fs.Prefix += "/v2/"
|
||||
} else {
|
||||
if legacy_ui {
|
||||
fs.Prefix += "/v1/"
|
||||
} else {
|
||||
fs.Prefix += "/v2/"
|
||||
}
|
||||
uifs = fs
|
||||
}
|
||||
|
||||
if new_ui {
|
||||
if !legacy_ui {
|
||||
uifs = &redirectFS{fs: uifs}
|
||||
}
|
||||
|
||||
|
@ -384,6 +386,13 @@ func (s *HTTPServer) Index(resp http.ResponseWriter, req *http.Request) {
|
|||
|
||||
// decodeBody is used to decode a JSON request body
|
||||
func decodeBody(req *http.Request, out interface{}, cb func(interface{}) error) error {
|
||||
// This generally only happens in tests since real HTTP requests set
|
||||
// a non-nil body with no content. We guard against it anyways to prevent
|
||||
// a panic. The EOF response is the same behavior as an empty reader.
|
||||
if req.Body == nil {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
var raw interface{}
|
||||
dec := json.NewDecoder(req.Body)
|
||||
if err := dec.Decode(&raw); err != nil {
|
||||
|
@ -409,6 +418,14 @@ func setTranslateAddr(resp http.ResponseWriter, active bool) {
|
|||
|
||||
// setIndex is used to set the index response header
|
||||
func setIndex(resp http.ResponseWriter, index uint64) {
|
||||
// If we ever return X-Consul-Index of 0 blocking clients will go into a busy
|
||||
// loop and hammer us since ?index=0 will never block. It's always safe to
|
||||
// return index=1 since the very first Raft write is always an internal one
|
||||
// writing the raft config for the cluster so no user-facing blocking query
|
||||
// will ever legitimately have an X-Consul-Index of 1.
|
||||
if index == 0 {
|
||||
index = 1
|
||||
}
|
||||
resp.Header().Set("X-Consul-Index", strconv.FormatUint(index, 10))
|
||||
}
|
||||
|
||||
|
@ -444,6 +461,15 @@ func setMeta(resp http.ResponseWriter, m *structs.QueryMeta) {
|
|||
setConsistency(resp, m.ConsistencyLevel)
|
||||
}
|
||||
|
||||
// setCacheMeta sets http response headers to indicate cache status.
|
||||
func setCacheMeta(resp http.ResponseWriter, m *cache.ResultMeta) {
|
||||
str := "MISS"
|
||||
if m != nil && m.Hit {
|
||||
str = "HIT"
|
||||
}
|
||||
resp.Header().Set("X-Cache", str)
|
||||
}
|
||||
|
||||
// setHeaders is used to set canonical response header fields
|
||||
func setHeaders(resp http.ResponseWriter, headers map[string]string) {
|
||||
for field, value := range headers {
|
||||
|
|
|
@ -29,16 +29,27 @@ func init() {
|
|||
registerEndpoint("/v1/agent/check/warn/", []string{"PUT"}, (*HTTPServer).AgentCheckWarn)
|
||||
registerEndpoint("/v1/agent/check/fail/", []string{"PUT"}, (*HTTPServer).AgentCheckFail)
|
||||
registerEndpoint("/v1/agent/check/update/", []string{"PUT"}, (*HTTPServer).AgentCheckUpdate)
|
||||
registerEndpoint("/v1/agent/connect/authorize", []string{"POST"}, (*HTTPServer).AgentConnectAuthorize)
|
||||
registerEndpoint("/v1/agent/connect/ca/roots", []string{"GET"}, (*HTTPServer).AgentConnectCARoots)
|
||||
registerEndpoint("/v1/agent/connect/ca/leaf/", []string{"GET"}, (*HTTPServer).AgentConnectCALeafCert)
|
||||
registerEndpoint("/v1/agent/connect/proxy/", []string{"GET"}, (*HTTPServer).AgentConnectProxyConfig)
|
||||
registerEndpoint("/v1/agent/service/register", []string{"PUT"}, (*HTTPServer).AgentRegisterService)
|
||||
registerEndpoint("/v1/agent/service/deregister/", []string{"PUT"}, (*HTTPServer).AgentDeregisterService)
|
||||
registerEndpoint("/v1/agent/service/maintenance/", []string{"PUT"}, (*HTTPServer).AgentServiceMaintenance)
|
||||
registerEndpoint("/v1/catalog/register", []string{"PUT"}, (*HTTPServer).CatalogRegister)
|
||||
registerEndpoint("/v1/catalog/connect/", []string{"GET"}, (*HTTPServer).CatalogConnectServiceNodes)
|
||||
registerEndpoint("/v1/catalog/deregister", []string{"PUT"}, (*HTTPServer).CatalogDeregister)
|
||||
registerEndpoint("/v1/catalog/datacenters", []string{"GET"}, (*HTTPServer).CatalogDatacenters)
|
||||
registerEndpoint("/v1/catalog/nodes", []string{"GET"}, (*HTTPServer).CatalogNodes)
|
||||
registerEndpoint("/v1/catalog/services", []string{"GET"}, (*HTTPServer).CatalogServices)
|
||||
registerEndpoint("/v1/catalog/service/", []string{"GET"}, (*HTTPServer).CatalogServiceNodes)
|
||||
registerEndpoint("/v1/catalog/node/", []string{"GET"}, (*HTTPServer).CatalogNodeServices)
|
||||
registerEndpoint("/v1/connect/ca/configuration", []string{"GET", "PUT"}, (*HTTPServer).ConnectCAConfiguration)
|
||||
registerEndpoint("/v1/connect/ca/roots", []string{"GET"}, (*HTTPServer).ConnectCARoots)
|
||||
registerEndpoint("/v1/connect/intentions", []string{"GET", "POST"}, (*HTTPServer).IntentionEndpoint)
|
||||
registerEndpoint("/v1/connect/intentions/match", []string{"GET"}, (*HTTPServer).IntentionMatch)
|
||||
registerEndpoint("/v1/connect/intentions/check", []string{"GET"}, (*HTTPServer).IntentionCheck)
|
||||
registerEndpoint("/v1/connect/intentions/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).IntentionSpecific)
|
||||
registerEndpoint("/v1/coordinate/datacenters", []string{"GET"}, (*HTTPServer).CoordinateDatacenters)
|
||||
registerEndpoint("/v1/coordinate/nodes", []string{"GET"}, (*HTTPServer).CoordinateNodes)
|
||||
registerEndpoint("/v1/coordinate/node/", []string{"GET"}, (*HTTPServer).CoordinateNode)
|
||||
|
@ -49,6 +60,7 @@ func init() {
|
|||
registerEndpoint("/v1/health/checks/", []string{"GET"}, (*HTTPServer).HealthServiceChecks)
|
||||
registerEndpoint("/v1/health/state/", []string{"GET"}, (*HTTPServer).HealthChecksInState)
|
||||
registerEndpoint("/v1/health/service/", []string{"GET"}, (*HTTPServer).HealthServiceNodes)
|
||||
registerEndpoint("/v1/health/connect/", []string{"GET"}, (*HTTPServer).HealthConnectServiceNodes)
|
||||
registerEndpoint("/v1/internal/ui/nodes", []string{"GET"}, (*HTTPServer).UINodes)
|
||||
registerEndpoint("/v1/internal/ui/node/", []string{"GET"}, (*HTTPServer).UINodeInfo)
|
||||
registerEndpoint("/v1/internal/ui/services", []string{"GET"}, (*HTTPServer).UIServices)
|
||||
|
|
|
@ -0,0 +1,302 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// /v1/connection/intentions
|
||||
func (s *HTTPServer) IntentionEndpoint(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
return s.IntentionList(resp, req)
|
||||
|
||||
case "POST":
|
||||
return s.IntentionCreate(resp, req)
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "POST"}}
|
||||
}
|
||||
}
|
||||
|
||||
// GET /v1/connect/intentions
|
||||
func (s *HTTPServer) IntentionList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
var args structs.DCSpecificRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply structs.IndexedIntentions
|
||||
if err := s.agent.RPC("Intention.List", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reply.Intentions, nil
|
||||
}
|
||||
|
||||
// POST /v1/connect/intentions
|
||||
func (s *HTTPServer) IntentionCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
args := structs.IntentionRequest{
|
||||
Op: structs.IntentionOpCreate,
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
if err := decodeBody(req, &args.Intention, nil); err != nil {
|
||||
return nil, fmt.Errorf("Failed to decode request body: %s", err)
|
||||
}
|
||||
|
||||
var reply string
|
||||
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return intentionCreateResponse{reply}, nil
|
||||
}
|
||||
|
||||
// GET /v1/connect/intentions/match
|
||||
func (s *HTTPServer) IntentionMatch(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Prepare args
|
||||
args := &structs.IntentionQueryRequest{Match: &structs.IntentionQueryMatch{}}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
|
||||
// Extract the "by" query parameter
|
||||
if by, ok := q["by"]; !ok || len(by) != 1 {
|
||||
return nil, fmt.Errorf("required query parameter 'by' not set")
|
||||
} else {
|
||||
switch v := structs.IntentionMatchType(by[0]); v {
|
||||
case structs.IntentionMatchSource, structs.IntentionMatchDestination:
|
||||
args.Match.Type = v
|
||||
default:
|
||||
return nil, fmt.Errorf("'by' parameter must be one of 'source' or 'destination'")
|
||||
}
|
||||
}
|
||||
|
||||
// Extract all the match names
|
||||
names, ok := q["name"]
|
||||
if !ok || len(names) == 0 {
|
||||
return nil, fmt.Errorf("required query parameter 'name' not set")
|
||||
}
|
||||
|
||||
// Build the entries in order. The order matters since that is the
|
||||
// order of the returned responses.
|
||||
args.Match.Entries = make([]structs.IntentionMatchEntry, len(names))
|
||||
for i, n := range names {
|
||||
entry, err := parseIntentionMatchEntry(n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("name %q is invalid: %s", n, err)
|
||||
}
|
||||
|
||||
args.Match.Entries[i] = entry
|
||||
}
|
||||
|
||||
var reply structs.IndexedIntentionMatches
|
||||
if err := s.agent.RPC("Intention.Match", args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We must have an identical count of matches
|
||||
if len(reply.Matches) != len(names) {
|
||||
return nil, fmt.Errorf("internal error: match response count didn't match input count")
|
||||
}
|
||||
|
||||
// Use empty list instead of nil.
|
||||
response := make(map[string]structs.Intentions)
|
||||
for i, ixns := range reply.Matches {
|
||||
response[names[i]] = ixns
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GET /v1/connect/intentions/check
|
||||
func (s *HTTPServer) IntentionCheck(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Prepare args
|
||||
args := &structs.IntentionQueryRequest{Check: &structs.IntentionQueryCheck{}}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
|
||||
// Set the source type if set
|
||||
args.Check.SourceType = structs.IntentionSourceConsul
|
||||
if sourceType, ok := q["source-type"]; ok && len(sourceType) > 0 {
|
||||
args.Check.SourceType = structs.IntentionSourceType(sourceType[0])
|
||||
}
|
||||
|
||||
// Extract the source/destination
|
||||
source, ok := q["source"]
|
||||
if !ok || len(source) != 1 {
|
||||
return nil, fmt.Errorf("required query parameter 'source' not set")
|
||||
}
|
||||
destination, ok := q["destination"]
|
||||
if !ok || len(destination) != 1 {
|
||||
return nil, fmt.Errorf("required query parameter 'destination' not set")
|
||||
}
|
||||
|
||||
// We parse them the same way as matches to extract namespace/name
|
||||
args.Check.SourceName = source[0]
|
||||
if args.Check.SourceType == structs.IntentionSourceConsul {
|
||||
entry, err := parseIntentionMatchEntry(source[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source %q is invalid: %s", source[0], err)
|
||||
}
|
||||
args.Check.SourceNS = entry.Namespace
|
||||
args.Check.SourceName = entry.Name
|
||||
}
|
||||
|
||||
// The destination is always in the Consul format
|
||||
entry, err := parseIntentionMatchEntry(destination[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("destination %q is invalid: %s", destination[0], err)
|
||||
}
|
||||
args.Check.DestinationNS = entry.Namespace
|
||||
args.Check.DestinationName = entry.Name
|
||||
|
||||
var reply structs.IntentionQueryCheckResponse
|
||||
if err := s.agent.RPC("Intention.Check", args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &reply, nil
|
||||
}
|
||||
|
||||
// IntentionSpecific handles the endpoint for /v1/connection/intentions/:id
|
||||
func (s *HTTPServer) IntentionSpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
id := strings.TrimPrefix(req.URL.Path, "/v1/connect/intentions/")
|
||||
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
return s.IntentionSpecificGet(id, resp, req)
|
||||
|
||||
case "PUT":
|
||||
return s.IntentionSpecificUpdate(id, resp, req)
|
||||
|
||||
case "DELETE":
|
||||
return s.IntentionSpecificDelete(id, resp, req)
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}}
|
||||
}
|
||||
}
|
||||
|
||||
// GET /v1/connect/intentions/:id
|
||||
func (s *HTTPServer) IntentionSpecificGet(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
args := structs.IntentionQueryRequest{
|
||||
IntentionID: id,
|
||||
}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply structs.IndexedIntentions
|
||||
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
|
||||
// We have to check the string since the RPC sheds the error type
|
||||
if err.Error() == consul.ErrIntentionNotFound.Error() {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// This shouldn't happen since the called API documents it shouldn't,
|
||||
// but we check since the alternative if it happens is a panic.
|
||||
if len(reply.Intentions) == 0 {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return reply.Intentions[0], nil
|
||||
}
|
||||
|
||||
// PUT /v1/connect/intentions/:id
|
||||
func (s *HTTPServer) IntentionSpecificUpdate(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
args := structs.IntentionRequest{
|
||||
Op: structs.IntentionOpUpdate,
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
if err := decodeBody(req, &args.Intention, nil); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Request decode failed: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Use the ID from the URL
|
||||
args.Intention.ID = id
|
||||
|
||||
var reply string
|
||||
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update uses the same create response
|
||||
return intentionCreateResponse{reply}, nil
|
||||
|
||||
}
|
||||
|
||||
// DELETE /v1/connect/intentions/:id
|
||||
func (s *HTTPServer) IntentionSpecificDelete(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
args := structs.IntentionRequest{
|
||||
Op: structs.IntentionOpDelete,
|
||||
Intention: &structs.Intention{ID: id},
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
|
||||
var reply string
|
||||
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// intentionCreateResponse is the response structure for creating an intention.
|
||||
type intentionCreateResponse struct{ ID string }
|
||||
|
||||
// parseIntentionMatchEntry parses the query parameter for an intention
|
||||
// match query entry.
|
||||
func parseIntentionMatchEntry(input string) (structs.IntentionMatchEntry, error) {
|
||||
var result structs.IntentionMatchEntry
|
||||
result.Namespace = structs.IntentionDefaultNamespace
|
||||
|
||||
// TODO(mitchellh): when namespaces are introduced, set the default
|
||||
// namespace to be the namespace of the requestor.
|
||||
|
||||
// Get the index to the '/'. If it doesn't exist, we have just a name
|
||||
// so just set that and return.
|
||||
idx := strings.IndexByte(input, '/')
|
||||
if idx == -1 {
|
||||
result.Name = input
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result.Namespace = input[:idx]
|
||||
result.Name = input[idx+1:]
|
||||
if strings.IndexByte(result.Name, '/') != -1 {
|
||||
return result, fmt.Errorf("input can contain at most one '/'")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,502 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntentionsList_empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Make sure an empty list is non-nil.
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/intentions", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionList(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(structs.Intentions)
|
||||
assert.NotNil(value)
|
||||
assert.Len(value, 0)
|
||||
}
|
||||
|
||||
func TestIntentionsList_values(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Create some intentions, note we create the lowest precedence first to test
|
||||
// sorting.
|
||||
for _, v := range []string{"*", "foo", "bar"} {
|
||||
req := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: structs.TestIntention(t),
|
||||
}
|
||||
req.Intention.SourceName = v
|
||||
|
||||
var reply string
|
||||
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
|
||||
}
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/intentions", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionList(resp, req)
|
||||
assert.NoError(err)
|
||||
|
||||
value := obj.(structs.Intentions)
|
||||
assert.Len(value, 3)
|
||||
|
||||
expected := []string{"bar", "foo", "*"}
|
||||
actual := []string{
|
||||
value[0].SourceName,
|
||||
value[1].SourceName,
|
||||
value[2].SourceName,
|
||||
}
|
||||
assert.Equal(expected, actual)
|
||||
}
|
||||
|
||||
func TestIntentionsMatch_basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Create some intentions
|
||||
{
|
||||
insert := [][]string{
|
||||
{"foo", "*", "foo", "*"},
|
||||
{"foo", "*", "foo", "bar"},
|
||||
{"foo", "*", "foo", "baz"}, // shouldn't match
|
||||
{"foo", "*", "bar", "bar"}, // shouldn't match
|
||||
{"foo", "*", "bar", "*"}, // shouldn't match
|
||||
{"foo", "*", "*", "*"},
|
||||
{"bar", "*", "foo", "bar"}, // duplicate destination different source
|
||||
}
|
||||
|
||||
for _, v := range insert {
|
||||
ixn := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: structs.TestIntention(t),
|
||||
}
|
||||
ixn.Intention.SourceNS = v[0]
|
||||
ixn.Intention.SourceName = v[1]
|
||||
ixn.Intention.DestinationNS = v[2]
|
||||
ixn.Intention.DestinationName = v[3]
|
||||
|
||||
// Create
|
||||
var reply string
|
||||
assert.Nil(a.RPC("Intention.Apply", &ixn, &reply))
|
||||
}
|
||||
}
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/match?by=destination&name=foo/bar", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionMatch(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(map[string]structs.Intentions)
|
||||
assert.Len(value, 1)
|
||||
|
||||
var actual [][]string
|
||||
expected := [][]string{
|
||||
{"bar", "*", "foo", "bar"},
|
||||
{"foo", "*", "foo", "bar"},
|
||||
{"foo", "*", "foo", "*"},
|
||||
{"foo", "*", "*", "*"},
|
||||
}
|
||||
for _, ixn := range value["foo/bar"] {
|
||||
actual = append(actual, []string{
|
||||
ixn.SourceNS,
|
||||
ixn.SourceName,
|
||||
ixn.DestinationNS,
|
||||
ixn.DestinationName,
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(expected, actual)
|
||||
}
|
||||
|
||||
func TestIntentionsMatch_noBy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/match?name=foo/bar", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionMatch(resp, req)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "by")
|
||||
assert.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsMatch_byInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/match?by=datacenter", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionMatch(resp, req)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "'by' parameter")
|
||||
assert.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsMatch_noName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/match?by=source", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionMatch(resp, req)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "'name' not set")
|
||||
assert.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsCheck_basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Create some intentions
|
||||
{
|
||||
insert := [][]string{
|
||||
{"foo", "*", "foo", "*"},
|
||||
{"foo", "*", "foo", "bar"},
|
||||
{"bar", "*", "foo", "bar"},
|
||||
}
|
||||
|
||||
for _, v := range insert {
|
||||
ixn := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: structs.TestIntention(t),
|
||||
}
|
||||
ixn.Intention.SourceNS = v[0]
|
||||
ixn.Intention.SourceName = v[1]
|
||||
ixn.Intention.DestinationNS = v[2]
|
||||
ixn.Intention.DestinationName = v[3]
|
||||
ixn.Intention.Action = structs.IntentionActionDeny
|
||||
|
||||
// Create
|
||||
var reply string
|
||||
require.Nil(a.RPC("Intention.Apply", &ixn, &reply))
|
||||
}
|
||||
}
|
||||
|
||||
// Request matching intention
|
||||
{
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/test?source=foo/bar&destination=foo/baz", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCheck(resp, req)
|
||||
require.Nil(err)
|
||||
value := obj.(*structs.IntentionQueryCheckResponse)
|
||||
require.False(value.Allowed)
|
||||
}
|
||||
|
||||
// Request non-matching intention
|
||||
{
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/test?source=foo/bar&destination=bar/qux", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCheck(resp, req)
|
||||
require.Nil(err)
|
||||
value := obj.(*structs.IntentionQueryCheckResponse)
|
||||
require.True(value.Allowed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntentionsCheck_noSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/test?destination=B", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCheck(resp, req)
|
||||
require.NotNil(err)
|
||||
require.Contains(err.Error(), "'source' not set")
|
||||
require.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsCheck_noDestination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/test?source=B", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCheck(resp, req)
|
||||
require.NotNil(err)
|
||||
require.Contains(err.Error(), "'destination' not set")
|
||||
require.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsCreate_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Make sure an empty list is non-nil.
|
||||
args := structs.TestIntention(t)
|
||||
args.SourceName = "foo"
|
||||
req, _ := http.NewRequest("POST", "/v1/connect/intentions", jsonReader(args))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCreate(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(intentionCreateResponse)
|
||||
assert.NotEqual("", value.ID)
|
||||
|
||||
// Read the value
|
||||
{
|
||||
req := &structs.IntentionQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
IntentionID: value.ID,
|
||||
}
|
||||
var resp structs.IndexedIntentions
|
||||
assert.Nil(a.RPC("Intention.Get", req, &resp))
|
||||
assert.Len(resp.Intentions, 1)
|
||||
actual := resp.Intentions[0]
|
||||
assert.Equal("foo", actual.SourceName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntentionsCreate_noBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Create with no body
|
||||
req, _ := http.NewRequest("POST", "/v1/connect/intentions", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.IntentionCreate(resp, req)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIntentionsSpecificGet_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// The intention
|
||||
ixn := structs.TestIntention(t)
|
||||
|
||||
// Create an intention directly
|
||||
var reply string
|
||||
{
|
||||
req := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: ixn,
|
||||
}
|
||||
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
|
||||
}
|
||||
|
||||
// Get the value
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf("/v1/connect/intentions/%s", reply), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionSpecific(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(*structs.Intention)
|
||||
assert.Equal(reply, value.ID)
|
||||
|
||||
ixn.ID = value.ID
|
||||
ixn.RaftIndex = value.RaftIndex
|
||||
ixn.CreatedAt, ixn.UpdatedAt = value.CreatedAt, value.UpdatedAt
|
||||
assert.Equal(ixn, value)
|
||||
}
|
||||
|
||||
func TestIntentionsSpecificUpdate_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// The intention
|
||||
ixn := structs.TestIntention(t)
|
||||
|
||||
// Create an intention directly
|
||||
var reply string
|
||||
{
|
||||
req := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: ixn,
|
||||
}
|
||||
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
|
||||
}
|
||||
|
||||
// Update the intention
|
||||
ixn.ID = "bogus"
|
||||
ixn.SourceName = "bar"
|
||||
req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/connect/intentions/%s", reply), jsonReader(ixn))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionSpecific(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(intentionCreateResponse)
|
||||
assert.Equal(reply, value.ID)
|
||||
|
||||
// Read the value
|
||||
{
|
||||
req := &structs.IntentionQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
IntentionID: reply,
|
||||
}
|
||||
var resp structs.IndexedIntentions
|
||||
assert.Nil(a.RPC("Intention.Get", req, &resp))
|
||||
assert.Len(resp.Intentions, 1)
|
||||
actual := resp.Intentions[0]
|
||||
assert.Equal("bar", actual.SourceName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntentionsSpecificDelete_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// The intention
|
||||
ixn := structs.TestIntention(t)
|
||||
ixn.SourceName = "foo"
|
||||
|
||||
// Create an intention directly
|
||||
var reply string
|
||||
{
|
||||
req := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: ixn,
|
||||
}
|
||||
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
|
||||
}
|
||||
|
||||
// Sanity check that the intention exists
|
||||
{
|
||||
req := &structs.IntentionQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
IntentionID: reply,
|
||||
}
|
||||
var resp structs.IndexedIntentions
|
||||
assert.Nil(a.RPC("Intention.Get", req, &resp))
|
||||
assert.Len(resp.Intentions, 1)
|
||||
actual := resp.Intentions[0]
|
||||
assert.Equal("foo", actual.SourceName)
|
||||
}
|
||||
|
||||
// Delete the intention
|
||||
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/v1/connect/intentions/%s", reply), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionSpecific(resp, req)
|
||||
assert.Nil(err)
|
||||
assert.Equal(true, obj)
|
||||
|
||||
// Verify the intention is gone
|
||||
{
|
||||
req := &structs.IntentionQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
IntentionID: reply,
|
||||
}
|
||||
var resp structs.IndexedIntentions
|
||||
err := a.RPC("Intention.Get", req, &resp)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIntentionMatchEntry(t *testing.T) {
|
||||
cases := []struct {
|
||||
Input string
|
||||
Expected structs.IntentionMatchEntry
|
||||
Err bool
|
||||
}{
|
||||
{
|
||||
"foo",
|
||||
structs.IntentionMatchEntry{
|
||||
Namespace: structs.IntentionDefaultNamespace,
|
||||
Name: "foo",
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"foo/bar",
|
||||
structs.IntentionMatchEntry{
|
||||
Namespace: "foo",
|
||||
Name: "bar",
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"foo/bar/baz",
|
||||
structs.IntentionMatchEntry{},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Input, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
actual, err := parseIntentionMatchEntry(tc.Input)
|
||||
assert.Equal(err != nil, tc.Err, err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tc.Expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue