Merge branch 'master' of github.com:hashicorp/consul into WinService

This commit is contained in:
Siva 2018-06-26 16:49:50 -04:00
commit 2182e289a3
1232 changed files with 287819 additions and 7094 deletions

.dockerignore Normal file

@ -0,0 +1,3 @@
pkg/
.git
bin/


@ -1,13 +1,21 @@
## UNRELEASED
## 1.2.0 (June 26, 2018)
FEATURES:
* **Connect Feature Beta**: This version includes a major new feature for Consul named Connect. Connect enables secure service-to-service communication with automatic TLS encryption and identity-based authorization. For more details and links to demos and getting started guides, see the [announcement blog post](https://www.hashicorp.com/blog/consul-1-2-service-mesh).
* Connect must be enabled explicitly in configuration so upgrading a cluster will not affect any existing functionality until it's enabled.
  * This is a Beta feature; we don't recommend enabling it in production yet. Please see the documentation for more information.
* dns: Enable PTR record lookups for services with IPs that have no registered node [[PR-4083](https://github.com/hashicorp/consul/pull/4083)]
* ui: Default to serving the new UI. Setting the `CONSUL_UI_LEGACY` environment variable to `1` or `true` will revert to serving the old UI
IMPROVEMENTS:
* agent: A Consul user-agent string is now sent to providers when making retry-join requests [GH-4013](https://github.com/hashicorp/consul/pull/4013)
* agent: A Consul user-agent string is now sent to providers when making retry-join requests [[GH-4013](https://github.com/hashicorp/consul/issues/4013)]
* client: Add metrics for failed RPCs [PR-4220](https://github.com/hashicorp/consul/pull/4220)
* agent: Add configuration entry to control including TXT records for node meta in DNS responses [PR-4215](https://github.com/hashicorp/consul/pull/4215)
* client: Make RPC rate limit configuration reloadable [GH-4012](https://github.com/hashicorp/consul/issues/4012)
BUG FIXES:


@ -16,35 +16,97 @@ GOTEST_PKGS ?= "./..."
else
GOTEST_PKGS=$(shell go list ./... | sed 's/github.com\/hashicorp\/consul/./' | egrep -v "^($(GOTEST_PKGS_EXCLUDE))$$")
endif
GOOS=$(shell go env GOOS)
GOARCH=$(shell go env GOARCH)
GOOS?=$(shell go env GOOS)
GOARCH?=$(shell go env GOARCH)
GOPATH=$(shell go env GOPATH)
ASSETFS_PATH?=agent/bindata_assetfs.go
# Get the git commit
GIT_COMMIT=$(shell git rev-parse --short HEAD)
GIT_DIRTY=$(shell test -n "`git status --porcelain`" && echo "+CHANGES" || true)
GIT_DESCRIBE=$(shell git describe --tags --always)
GIT_COMMIT?=$(shell git rev-parse --short HEAD)
GIT_DIRTY?=$(shell test -n "`git status --porcelain`" && echo "+CHANGES" || true)
GIT_DESCRIBE?=$(shell git describe --tags --always)
GIT_IMPORT=github.com/hashicorp/consul/version
GOLDFLAGS=-X $(GIT_IMPORT).GitCommit=$(GIT_COMMIT)$(GIT_DIRTY) -X $(GIT_IMPORT).GitDescribe=$(GIT_DESCRIBE)
ifeq ($(FORCE_REBUILD),1)
NOCACHE=--no-cache
else
NOCACHE=
endif
DOCKER_BUILD_QUIET?=1
ifeq (${DOCKER_BUILD_QUIET},1)
QUIET=-q
else
QUIET=
endif
CONSUL_DEV_IMAGE?=consul-dev
GO_BUILD_TAG?=consul-build-go
UI_BUILD_TAG?=consul-build-ui
UI_LEGACY_BUILD_TAG?=consul-build-ui-legacy
BUILD_CONTAINER_NAME?=consul-builder
DIST_TAG?=1
DIST_BUILD?=1
DIST_SIGN?=1
ifdef DIST_VERSION
DIST_VERSION_ARG=-v "$(DIST_VERSION)"
else
DIST_VERSION_ARG=
endif
ifdef DIST_RELEASE_DATE
DIST_DATE_ARG=-d "$(DIST_RELEASE_DATE)"
else
DIST_DATE_ARG=
endif
ifdef DIST_PRERELEASE
DIST_REL_ARG=-r "$(DIST_PRERELEASE)"
else
DIST_REL_ARG=
endif
PUB_GIT?=1
PUB_WEBSITE?=1
ifeq ($(PUB_GIT),1)
PUB_GIT_ARG=-g
else
PUB_GIT_ARG=
endif
ifeq ($(PUB_WEBSITE),1)
PUB_WEBSITE_ARG=-g
else
PUB_WEBSITE_ARG=
endif
export GO_BUILD_TAG
export UI_BUILD_TAG
export UI_LEGACY_BUILD_TAG
export BUILD_CONTAINER_NAME
export GIT_COMMIT
export GIT_DIRTY
export GIT_DESCRIBE
export GOTAGS
export GOLDFLAGS
# all builds binaries for all targets
all: bin
bin: tools
@mkdir -p bin/
@GOTAGS='$(GOTAGS)' sh -c "'$(CURDIR)/scripts/build.sh'"
bin: tools dev-build
# dev creates binaries for testing locally - these are put into ./bin and $GOPATH
dev: changelogfmt vendorfmt dev-build
dev-build:
@echo "--> Building consul"
mkdir -p pkg/$(GOOS)_$(GOARCH)/ bin/
go install -ldflags '$(GOLDFLAGS)' -tags '$(GOTAGS)'
cp $(GOPATH)/bin/consul bin/
cp $(GOPATH)/bin/consul pkg/$(GOOS)_$(GOARCH)
@$(SHELL) $(CURDIR)/build-support/scripts/build-local.sh -o $(GOOS) -a $(GOARCH)
dev-docker:
@docker build -t '$(CONSUL_DEV_IMAGE)' --build-arg 'GIT_COMMIT=$(GIT_COMMIT)' --build-arg 'GIT_DIRTY=$(GIT_DIRTY)' --build-arg 'GIT_DESCRIBE=$(GIT_DESCRIBE)' -f $(CURDIR)/build-support/docker/Consul-Dev.dockerfile $(CURDIR)
vendorfmt:
@echo "--> Formatting vendor/vendor.json"
@ -57,12 +119,17 @@ changelogfmt:
# linux builds a linux package independent of the source platform
linux:
mkdir -p pkg/linux_amd64/
GOOS=linux GOARCH=amd64 go build -ldflags '$(GOLDFLAGS)' -tags '$(GOTAGS)' -o pkg/linux_amd64/consul
@$(SHELL) $(CURDIR)/build-support/scripts/build-local.sh -o linux -a amd64
# dist builds binaries for all platforms and packages them for distribution
dist:
@GOTAGS='$(GOTAGS)' sh -c "'$(CURDIR)/scripts/dist.sh'"
@$(SHELL) $(CURDIR)/build-support/scripts/release.sh -t '$(DIST_TAG)' -b '$(DIST_BUILD)' -S '$(DIST_SIGN)' $(DIST_VERSION_ARG) $(DIST_DATE_ARG) $(DIST_REL_ARG)
publish:
@$(SHELL) $(CURDIR)/build-support/scripts/publish.sh $(PUB_GIT_ARG) $(PUB_WEBSITE_ARG)
dev-tree:
@$(SHELL) $(CURDIR)/build-support/scripts/dev.sh
cov:
gocov test $(GOFILES) | gocov-html > /tmp/coverage.html
@ -78,10 +145,16 @@ test: other-consul dev-build vet
@# _something_ to stop them terminating us due to inactivity...
{ go test $(GOTEST_FLAGS) -tags '$(GOTAGS)' -timeout 5m $(GOTEST_PKGS) 2>&1 ; echo $$? > exit-code ; } | tee test.log | egrep '^(ok|FAIL)\s*github.com/hashicorp/consul'
@echo "Exit code: $$(cat exit-code)" >> test.log
@grep -A5 'DATA RACE' test.log || true
@grep -A10 'panic: test timed out' test.log || true
@grep -A1 -- '--- SKIP:' test.log || true
@grep -A1 -- '--- FAIL:' test.log || true
	@# This prints each race report between the ====== lines
@awk '/^WARNING: DATA RACE/ {do_print=1; print "=================="} do_print==1 {print} /^={10,}/ {do_print=0}' test.log || true
@grep -A10 'panic: ' test.log || true
	@# Prints all the failure output until the next non-indented line - testify
	@# helpers often output multiple lines for readability, which are useless if
	@# we can't see them. The unintuitive order of matches is necessary. No
	@# || true here because awk always returns true even if there is no match,
	@# and adding it breaks non-bash shells locally.
@awk '/^[^[:space:]]/ {do_print=0} /--- SKIP/ {do_print=1} do_print==1 {print}' test.log
@awk '/^[^[:space:]]/ {do_print=0} /--- FAIL/ {do_print=1} do_print==1 {print}' test.log
@grep '^FAIL' test.log || true
@if [ "$$(cat exit-code)" == "0" ] ; then echo "PASS" ; exit 0 ; else exit 1 ; fi
@ -111,20 +184,57 @@ vet:
exit 1; \
fi
# Build the static web ui and build static assets inside a Docker container, the
# same way a release build works. This implicitly does a "make static-assets" at
# the end.
ui:
@sh -c "'$(CURDIR)/scripts/ui.sh'"
# If you've run "make ui" manually then this will get called for you. This is
# also run as part of the release build script when it verifies that there are no
# changes to the UI assets that aren't checked in.
static-assets:
@go-bindata-assetfs -pkg agent -prefix pkg -o agent/bindata_assetfs.go ./pkg/web_ui/...
@go-bindata-assetfs -pkg agent -prefix pkg -o $(ASSETFS_PATH) ./pkg/web_ui/...
$(MAKE) format
# Build the static web ui and build static assets inside a Docker container
ui: ui-legacy-docker ui-docker static-assets-docker
tools:
go get -u -v $(GOTOOLS)
.PHONY: all ci bin dev dist cov test cover format vet ui static-assets tools vendorfmt
version:
@echo -n "Version: "
@$(SHELL) $(CURDIR)/build-support/scripts/version.sh
@echo -n "Version + release: "
@$(SHELL) $(CURDIR)/build-support/scripts/version.sh -r
@echo -n "Version + git: "
@$(SHELL) $(CURDIR)/build-support/scripts/version.sh -g
@echo -n "Version + release + git: "
@$(SHELL) $(CURDIR)/build-support/scripts/version.sh -r -g
docker-images: go-build-image ui-build-image ui-legacy-build-image
go-build-image:
@echo "Building Golang build container"
@docker build $(NOCACHE) $(QUIET) --build-arg 'GOTOOLS=$(GOTOOLS)' -t $(GO_BUILD_TAG) - < build-support/docker/Build-Go.dockerfile
ui-build-image:
@echo "Building UI build container"
@docker build $(NOCACHE) $(QUIET) -t $(UI_BUILD_TAG) - < build-support/docker/Build-UI.dockerfile
ui-legacy-build-image:
@echo "Building Legacy UI build container"
@docker build $(NOCACHE) $(QUIET) -t $(UI_LEGACY_BUILD_TAG) - < build-support/docker/Build-UI-Legacy.dockerfile
static-assets-docker: go-build-image
@$(SHELL) $(CURDIR)/build-support/scripts/build-docker.sh static-assets
consul-docker: go-build-image
@$(SHELL) $(CURDIR)/build-support/scripts/build-docker.sh consul
ui-docker: ui-build-image
@$(SHELL) $(CURDIR)/build-support/scripts/build-docker.sh ui
ui-legacy-docker: ui-legacy-build-image
@$(SHELL) $(CURDIR)/build-support/scripts/build-docker.sh ui-legacy
.PHONY: all ci bin dev dist cov test cover format vet ui static-assets tools vendorfmt
.PHONY: docker-images go-build-image ui-build-image ui-legacy-build-image static-assets-docker consul-docker ui-docker ui-legacy-docker version
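
Because GOOS, GOARCH, and the GIT_* variables are now assigned with `?=`, values from the environment or the make command line take precedence over the computed defaults; for example, `GOOS=linux GOARCH=amd64 make dev-build` cross-builds without editing the Makefile (a usage sketch based on the `dev-build` target above).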


@ -25,6 +25,9 @@ Consul provides several key features:
* **Multi-Datacenter** - Consul is built to be datacenter aware, and can
support any number of regions without complex configuration.
* **Service Segmentation** - Consul Connect enables secure service-to-service
communication with automatic TLS encryption and identity-based authorization.
Consul runs on Linux, Mac OS X, FreeBSD, Solaris, and Windows. A commercial
version called [Consul Enterprise](https://www.hashicorp.com/products/consul)
is also available.


@ -60,6 +60,17 @@ type ACL interface {
// EventWrite determines if a specific event may be fired.
EventWrite(string) bool
// IntentionDefaultAllow determines the default authorized behavior
// when no intentions match a Connect request.
IntentionDefaultAllow() bool
// IntentionRead determines if a specific intention can be read.
IntentionRead(string) bool
// IntentionWrite determines if a specific intention can be
// created, modified, or deleted.
IntentionWrite(string) bool
// KeyList checks for permission to list keys under a prefix
KeyList(string) bool
@ -154,6 +165,18 @@ func (s *StaticACL) EventWrite(string) bool {
return s.defaultAllow
}
func (s *StaticACL) IntentionDefaultAllow() bool {
return s.defaultAllow
}
func (s *StaticACL) IntentionRead(string) bool {
return s.defaultAllow
}
func (s *StaticACL) IntentionWrite(string) bool {
return s.defaultAllow
}
func (s *StaticACL) KeyRead(string) bool {
return s.defaultAllow
}
@ -275,6 +298,9 @@ type PolicyACL struct {
// agentRules contains the agent policies
agentRules *radix.Tree
// intentionRules contains the service intention policies
intentionRules *radix.Tree
// keyRules contains the key policies
keyRules *radix.Tree
@ -308,6 +334,7 @@ func New(parent ACL, policy *Policy, sentinel sentinel.Evaluator) (*PolicyACL, e
p := &PolicyACL{
parent: parent,
agentRules: radix.New(),
intentionRules: radix.New(),
keyRules: radix.New(),
nodeRules: radix.New(),
serviceRules: radix.New(),
@ -347,6 +374,25 @@ func New(parent ACL, policy *Policy, sentinel sentinel.Evaluator) (*PolicyACL, e
sentinelPolicy: sp.Sentinel,
}
p.serviceRules.Insert(sp.Name, policyRule)
// Determine the intention. The intention could be blank (not set).
// If the intention is not set, the value depends on the value of
// the service policy.
intention := sp.Intentions
if intention == "" {
switch sp.Policy {
case PolicyRead, PolicyWrite:
intention = PolicyRead
default:
intention = PolicyDeny
}
}
policyRule = PolicyRule{
aclPolicy: intention,
sentinelPolicy: sp.Sentinel,
}
p.intentionRules.Insert(sp.Name, policyRule)
}
// Load the session policy
@ -455,6 +501,51 @@ func (p *PolicyACL) EventWrite(name string) bool {
return p.parent.EventWrite(name)
}
// IntentionDefaultAllow returns whether the default behavior when there are
// no matching intentions is to allow or deny.
func (p *PolicyACL) IntentionDefaultAllow() bool {
// We always go up, this can't be determined by a policy.
return p.parent.IntentionDefaultAllow()
}
// IntentionRead checks if reading of an intention is allowed.
func (p *PolicyACL) IntentionRead(prefix string) bool {
// Check for an exact rule or catch-all
_, rule, ok := p.intentionRules.LongestPrefix(prefix)
if ok {
pr := rule.(PolicyRule)
switch pr.aclPolicy {
case PolicyRead, PolicyWrite:
return true
default:
return false
}
}
// No matching rule, use the parent.
return p.parent.IntentionRead(prefix)
}
// IntentionWrite checks if writing (creating, updating, or deleting) of an
// intention is allowed.
func (p *PolicyACL) IntentionWrite(prefix string) bool {
// Check for an exact rule or catch-all
_, rule, ok := p.intentionRules.LongestPrefix(prefix)
if ok {
pr := rule.(PolicyRule)
switch pr.aclPolicy {
case PolicyWrite:
return true
default:
return false
}
}
// No matching rule, use the parent.
return p.parent.IntentionWrite(prefix)
}
// KeyRead returns if a key is allowed to be read
func (p *PolicyACL) KeyRead(key string) bool {
// Look for a matching rule


@ -53,6 +53,12 @@ func TestStaticACL(t *testing.T) {
if !all.EventWrite("foobar") {
t.Fatalf("should allow")
}
if !all.IntentionDefaultAllow() {
t.Fatalf("should allow")
}
if !all.IntentionWrite("foobar") {
t.Fatalf("should allow")
}
if !all.KeyRead("foobar") {
t.Fatalf("should allow")
}
@ -123,6 +129,12 @@ func TestStaticACL(t *testing.T) {
if none.EventWrite("") {
t.Fatalf("should not allow")
}
if none.IntentionDefaultAllow() {
t.Fatalf("should not allow")
}
if none.IntentionWrite("foo") {
t.Fatalf("should not allow")
}
if none.KeyRead("foobar") {
t.Fatalf("should not allow")
}
@ -187,6 +199,12 @@ func TestStaticACL(t *testing.T) {
if !manage.EventWrite("foobar") {
t.Fatalf("should allow")
}
if !manage.IntentionDefaultAllow() {
t.Fatalf("should allow")
}
if !manage.IntentionWrite("foobar") {
t.Fatalf("should allow")
}
if !manage.KeyRead("foobar") {
t.Fatalf("should allow")
}
@ -305,8 +323,14 @@ func TestPolicyACL(t *testing.T) {
Policy: PolicyDeny,
},
&ServicePolicy{
Name: "barfoo",
Policy: PolicyWrite,
Name: "barfoo",
Policy: PolicyWrite,
Intentions: PolicyWrite,
},
&ServicePolicy{
Name: "intbaz",
Policy: PolicyWrite,
Intentions: PolicyDeny,
},
},
}
@ -344,6 +368,31 @@ func TestPolicyACL(t *testing.T) {
}
}
// Test the intentions
type intentioncase struct {
inp string
read bool
write bool
}
icases := []intentioncase{
{"other", true, false},
{"foo", true, false},
{"bar", false, false},
{"foobar", true, false},
{"barfo", false, false},
{"barfoo", true, true},
{"barfoo2", true, true},
{"intbaz", false, false},
}
for _, c := range icases {
if c.read != acl.IntentionRead(c.inp) {
t.Fatalf("Read fail: %#v", c)
}
if c.write != acl.IntentionWrite(c.inp) {
t.Fatalf("Write fail: %#v", c)
}
}
// Test the services
type servicecase struct {
inp string
@ -414,6 +463,11 @@ func TestPolicyACL(t *testing.T) {
t.Fatalf("Prepared query fail: %#v", c)
}
}
// Check default intentions bubble up
if !acl.IntentionDefaultAllow() {
t.Fatal("should allow")
}
}
func TestPolicyACL_Parent(t *testing.T) {
@ -567,6 +621,11 @@ func TestPolicyACL_Parent(t *testing.T) {
if acl.Snapshot() {
t.Fatalf("should not allow")
}
// Check default intentions
if acl.IntentionDefaultAllow() {
t.Fatal("should not allow")
}
}
func TestPolicyACL_Agent(t *testing.T) {


@ -73,6 +73,11 @@ type ServicePolicy struct {
Name string `hcl:",key"`
Policy string
Sentinel Sentinel
// Intentions is the policy for intentions where this service is the
// destination. This may be empty, in which case the Policy determines
// the intentions policy.
Intentions string
}
func (s *ServicePolicy) GoString() string {
@ -197,6 +202,9 @@ func Parse(rules string, sentinel sentinel.Evaluator) (*Policy, error) {
if !isPolicyValid(sp.Policy) {
return nil, fmt.Errorf("Invalid service policy: %#v", sp)
}
if sp.Intentions != "" && !isPolicyValid(sp.Intentions) {
return nil, fmt.Errorf("Invalid service intentions policy: %#v", sp)
}
if err := isSentinelValid(sentinel, sp.Policy, sp.Sentinel); err != nil {
return nil, fmt.Errorf("Invalid service Sentinel policy: %#v, got error:%v", sp, err)
}
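
As a quick illustration of the new `Intentions` field and the validation above, a minimal sketch using the exported `Parse` (the service name and rule values are illustrative; passing `nil` disables Sentinel evaluation, as in the tests below):

package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/consul/acl"
)

func main() {
	// A service rule that grants write on the service itself but denies
	// managing intentions that target it.
	policy, err := acl.Parse(`
service "web" {
  policy     = "write"
  intentions = "deny"
}
`, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(policy.Services[0].Intentions) // "deny"
}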


@ -4,8 +4,85 @@ import (
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParse_table(t *testing.T) {
// Note that the table tests are newer than other tests. Many of the
// other aspects of policy parsing are tested in older tests below. New
// parsing tests should be added to this table as it's easier to maintain.
cases := []struct {
Name string
Input string
Expected *Policy
Err string
}{
{
"service no intentions",
`
service "foo" {
policy = "write"
}
`,
&Policy{
Services: []*ServicePolicy{
{
Name: "foo",
Policy: "write",
},
},
},
"",
},
{
"service intentions",
`
service "foo" {
policy = "write"
intentions = "read"
}
`,
&Policy{
Services: []*ServicePolicy{
{
Name: "foo",
Policy: "write",
Intentions: "read",
},
},
},
"",
},
{
"service intention: invalid value",
`
service "foo" {
policy = "write"
intentions = "foo"
}
`,
nil,
"service intentions",
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
actual, err := Parse(tc.Input, nil)
assert.Equal(tc.Err != "", err != nil, err)
if err != nil {
assert.Contains(err.Error(), tc.Err)
return
}
assert.Equal(tc.Expected, actual)
})
}
}
func TestACLPolicy_Parse_HCL(t *testing.T) {
inp := `
agent "foo" {


@ -8,6 +8,7 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/local"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/golang-lru"
@ -239,6 +240,18 @@ func (a *Agent) resolveToken(id string) (acl.ACL, error) {
return a.acls.lookupACL(a, id)
}
// resolveProxyToken attempts to resolve an ACL ID to a local proxy token.
// If a local proxy isn't found with that token, nil is returned.
func (a *Agent) resolveProxyToken(id string) *local.ManagedProxy {
for _, p := range a.State.Proxies() {
if p.ProxyToken == id {
return p
}
}
return nil
}
// vetServiceRegister makes sure the service registration action is allowed by
// the given token.
func (a *Agent) vetServiceRegister(token string, service *structs.NodeService) error {


@ -21,16 +21,20 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/ae"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/checks"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/local"
"github.com/hashicorp/consul/agent/proxy"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/systemd"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/lib/file"
"github.com/hashicorp/consul/logger"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/consul/watch"
@ -46,6 +50,9 @@ const (
// Path to save agent service definitions
servicesDir = "services"
// Path to save agent proxy definitions
proxyDir = "proxies"
// Path to save local agent checks
checksDir = "checks"
checkStateDir = "checks/state"
@ -73,6 +80,7 @@ type delegate interface {
SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error
Shutdown() error
Stats() map[string]map[string]string
ReloadConfig(config *consul.Config) error
enterpriseDelegate
}
@ -118,6 +126,9 @@ type Agent struct {
// and the remote state.
sync *ae.StateSyncer
// cache is the in-memory cache for data the Agent requests.
cache *cache.Cache
// checkReapAfter maps the check ID to a timeout after which we should
// reap its associated service
checkReapAfter map[types.CheckID]time.Duration
@ -194,6 +205,9 @@ type Agent struct {
// be updated at runtime, so should always be used instead of going to
// the configuration directly.
tokens *token.Store
// proxyManager is the proxy process manager for managed Connect proxies.
proxyManager *proxy.Manager
}
func New(c *config.RuntimeConfig) (*Agent, error) {
@ -246,6 +260,8 @@ func LocalConfig(cfg *config.RuntimeConfig) local.Config {
NodeID: cfg.NodeID,
NodeName: cfg.NodeName,
TaggedAddresses: map[string]string{},
ProxyBindMinPort: cfg.ConnectProxyBindMinPort,
ProxyBindMaxPort: cfg.ConnectProxyBindMaxPort,
}
for k, v := range cfg.TaggedAddresses {
lc.TaggedAddresses[k] = v
@ -288,6 +304,9 @@ func (a *Agent) Start() error {
// regular and on-demand state synchronizations (anti-entropy).
a.sync = ae.NewStateSyncer(a.State, c.AEInterval, a.shutdownCh, a.logger)
// create the cache
a.cache = cache.New(nil)
// create the config for the rpc server/client
consulCfg, err := a.consulConfig()
if err != nil {
@ -324,10 +343,17 @@ func (a *Agent) Start() error {
a.State.Delegate = a.delegate
a.State.TriggerSyncChanges = a.sync.SyncChanges.Trigger
// Register the cache. We do this much later so the delegate is
// populated from above.
a.registerCache()
// Load checks/services/metadata.
if err := a.loadServices(c); err != nil {
return err
}
if err := a.loadProxies(c); err != nil {
return err
}
if err := a.loadChecks(c); err != nil {
return err
}
@ -335,6 +361,28 @@ func (a *Agent) Start() error {
return err
}
// create the proxy process manager and start it. This is purposely
// done here after the local state above is loaded in so we can have
// a more accurate initial state view.
if !c.ConnectTestDisableManagedProxies {
a.proxyManager = proxy.NewManager()
a.proxyManager.AllowRoot = a.config.ConnectProxyAllowManagedRoot
a.proxyManager.State = a.State
a.proxyManager.Logger = a.logger
if a.config.DataDir != "" {
// DataDir is required for all non-dev-mode agents, but we want to allow
// dev-mode agents to set a data dir for demos and so on, so do the data
// dir check above instead of checking DevMode.
a.proxyManager.DataDir = filepath.Join(a.config.DataDir, "proxy")
// Restore from our snapshot (if it exists)
if err := a.proxyManager.Restore(a.proxyManager.SnapshotPath()); err != nil {
a.logger.Printf("[WARN] agent: error restoring proxy state: %s", err)
}
}
go a.proxyManager.Run()
}
// Start watching for critical services to deregister, based on their
// checks.
go a.reapServices()
@ -604,6 +652,16 @@ func (a *Agent) reloadWatches(cfg *config.RuntimeConfig) error {
return fmt.Errorf("Handler type '%s' not recognized", params["handler_type"])
}
// Don't let people use connect watches via this mechanism for now as it
// needs thought about how to do it securely and shouldn't be necessary. Note
// that if the type assertion fails and the type is not a string then
// ParseExempt below will error so we don't need to handle that case.
if typ, ok := params["type"].(string); ok {
if strings.HasPrefix(typ, "connect_") {
return fmt.Errorf("Watch type %s is not allowed in agent config", typ)
}
}
// Parse the watches, excluding 'handler' and 'args'
wp, err := watch.ParseExempt(params, []string{"handler", "args"})
if err != nil {
@ -877,6 +935,47 @@ func (a *Agent) consulConfig() (*consul.Config, error) {
base.TLSCipherSuites = a.config.TLSCipherSuites
base.TLSPreferServerCipherSuites = a.config.TLSPreferServerCipherSuites
// Copy the Connect CA bootstrap config
if a.config.ConnectEnabled {
base.ConnectEnabled = true
// Allow config to specify cluster_id provided it's a valid UUID. This is
// meant only for tests, where a deterministic ID makes fixtures much simpler
// to work with, but since it's only read on initial cluster bootstrap it's
// not much of a liability in production. The worst a user could do is
// configure logically separate clusters with the same ID by mistake, but we
// can avoid documenting that this is even an option.
if clusterID, ok := a.config.ConnectCAConfig["cluster_id"]; ok {
if cIDStr, ok := clusterID.(string); ok {
if _, err := uuid.ParseUUID(cIDStr); err == nil {
// Valid UUID configured, use that
base.CAConfig.ClusterID = cIDStr
}
}
if base.CAConfig.ClusterID == "" {
// If they tried to specify an ID but typoed it, don't ignore it, as they
// would then bootstrap with a new ID and have to throw away the whole
// cluster and start again.
a.logger.Println("[ERR] connect CA config cluster_id specified but " +
"is not a valid UUID, aborting startup")
return nil, fmt.Errorf("cluster_id was supplied but was not a valid UUID")
}
}
if a.config.ConnectCAProvider != "" {
base.CAConfig.Provider = a.config.ConnectCAProvider
// Merge with the default config if it's the consul provider.
if a.config.ConnectCAProvider == "consul" {
for k, v := range a.config.ConnectCAConfig {
base.CAConfig.Config[k] = v
}
} else {
base.CAConfig.Config = a.config.ConnectCAConfig
}
}
}
// Setup the user event callback
base.UserEventHandler = func(e serf.UserEvent) {
select {
@ -1009,7 +1108,7 @@ func (a *Agent) setupNodeID(config *config.RuntimeConfig) error {
}
// For dev mode we have no filesystem access so just make one.
if a.config.DevMode {
if a.config.DataDir == "" {
id, err := a.makeNodeID()
if err != nil {
return err
@ -1223,6 +1322,24 @@ func (a *Agent) ShutdownAgent() error {
chk.Stop()
}
// Stop the proxy manager
if a.proxyManager != nil {
// If persistence is disabled (which implies DevMode, but is a narrower
// condition than DevMode) then don't leave the proxies running, since the
// agent will not be able to recover them later.
if a.config.DataDir == "" {
a.logger.Printf("[WARN] agent: dev mode disabled persistence, killing " +
"all proxies since we can't recover them")
if err := a.proxyManager.Kill(); err != nil {
a.logger.Printf("[WARN] agent: error shutting down proxy manager: %s", err)
}
} else {
if err := a.proxyManager.Close(); err != nil {
a.logger.Printf("[WARN] agent: error shutting down proxy manager: %s", err)
}
}
}
var err error
if a.delegate != nil {
err = a.delegate.Shutdown()
@ -1492,7 +1609,7 @@ func (a *Agent) persistService(service *structs.NodeService) error {
return err
}
return writeFileAtomic(svcPath, encoded)
return file.WriteAtomic(svcPath, encoded)
}
// purgeService removes a persisted service definition file from the data dir
@ -1504,6 +1621,39 @@ func (a *Agent) purgeService(serviceID string) error {
return nil
}
// persistedProxy is used to wrap a proxy definition and bundle it with a
// proxy token so we can continue to authenticate the running proxy after a
// restart.
type persistedProxy struct {
ProxyToken string
Proxy *structs.ConnectManagedProxy
}
// persistProxy saves a proxy definition to a JSON file in the data dir
func (a *Agent) persistProxy(proxy *local.ManagedProxy) error {
proxyPath := filepath.Join(a.config.DataDir, proxyDir,
stringHash(proxy.Proxy.ProxyService.ID))
wrapped := persistedProxy{
ProxyToken: proxy.ProxyToken,
Proxy: proxy.Proxy,
}
encoded, err := json.Marshal(wrapped)
if err != nil {
return err
}
return file.WriteAtomic(proxyPath, encoded)
}
// purgeProxy removes a persisted proxy definition file from the data dir
func (a *Agent) purgeProxy(proxyID string) error {
proxyPath := filepath.Join(a.config.DataDir, proxyDir, stringHash(proxyID))
if _, err := os.Stat(proxyPath); err == nil {
return os.Remove(proxyPath)
}
return nil
}
// persistCheck saves a check definition to the local agent's state directory
func (a *Agent) persistCheck(check *structs.HealthCheck, chkType *structs.CheckType) error {
checkPath := filepath.Join(a.config.DataDir, checksDir, checkIDHash(check.CheckID))
@ -1520,7 +1670,7 @@ func (a *Agent) persistCheck(check *structs.HealthCheck, chkType *structs.CheckT
return err
}
return writeFileAtomic(checkPath, encoded)
return file.WriteAtomic(checkPath, encoded)
}
// purgeCheck removes a persisted check definition file from the data dir
@ -1532,43 +1682,6 @@ func (a *Agent) purgeCheck(checkID types.CheckID) error {
return nil
}
// writeFileAtomic writes the given contents to a temporary file in the same
// directory, does an fsync and then renames the file to its real path
func writeFileAtomic(path string, contents []byte) error {
uuid, err := uuid.GenerateUUID()
if err != nil {
return err
}
tempPath := fmt.Sprintf("%s-%s.tmp", path, uuid)
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return err
}
fh, err := os.OpenFile(tempPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return err
}
if _, err := fh.Write(contents); err != nil {
fh.Close()
os.Remove(tempPath)
return err
}
if err := fh.Sync(); err != nil {
fh.Close()
os.Remove(tempPath)
return err
}
if err := fh.Close(); err != nil {
os.Remove(tempPath)
return err
}
if err := os.Rename(tempPath, path); err != nil {
os.Remove(tempPath)
return err
}
return nil
}
// AddService is used to add a service entry.
// This entry is persistent and the agent will make a best effort to
// ensure it is registered
@ -1622,7 +1735,7 @@ func (a *Agent) AddService(service *structs.NodeService, chkTypes []*structs.Che
a.State.AddService(service, token)
// Persist the service to a file
if persist && !a.config.DevMode {
if persist && a.config.DataDir != "" {
if err := a.persistService(service); err != nil {
return err
}
@ -1921,7 +2034,7 @@ func (a *Agent) AddCheck(check *structs.HealthCheck, chkType *structs.CheckType,
}
// Persist the check
if persist && !a.config.DevMode {
if persist && a.config.DataDir != "" {
return a.persistCheck(check, chkType)
}
@ -1973,6 +2086,277 @@ func (a *Agent) RemoveCheck(checkID types.CheckID, persist bool) error {
return nil
}
// AddProxy adds a new local Connect Proxy instance to be managed by the agent.
//
// It REQUIRES that the service being proxied is already present in the local
// state. Note that this is only used for agent-managed proxies, so we can
// ensure that is always true. For externally managed and registered proxies
// we explicitly allow the proxy to be registered first, to make bootstrap
// ordering of a new service simpler, but that does not apply here since this
// is only ever called when setting up a _managed_ proxy, which is registered
// as part of a service registration, either from config or an HTTP API call.
//
// The restoredProxyToken argument should only be used when restoring proxy
// definitions from disk; new proxies must leave it blank to get a new token
// assigned. Restoring the token from disk lets us continue authenticating
// running proxies that already had that credential injected.
func (a *Agent) AddProxy(proxy *structs.ConnectManagedProxy, persist bool,
restoredProxyToken string) error {
// Lookup the target service token in state if there is one.
token := a.State.ServiceToken(proxy.TargetServiceID)
// Copy the basic proxy structure so it isn't modified w/ defaults
proxyCopy := *proxy
proxy = &proxyCopy
if err := a.applyProxyDefaults(proxy); err != nil {
return err
}
// Add the proxy to local state first since we may need to assign a port which
// needs to be coordinated under the state lock. AddProxy will generate the
// NodeService for the proxy populated with the allocated (or configured) port
// and an ID, but it doesn't add it to the agent directly since that could
// deadlock and we may need to coordinate adding it and persisting etc.
proxyState, err := a.State.AddProxy(proxy, token, restoredProxyToken)
if err != nil {
return err
}
proxyService := proxyState.Proxy.ProxyService
// Register proxy TCP check. The built-in proxy doesn't listen publicly
// until it has loaded its certs, so this ensures we won't route traffic
// until it's ready.
proxyCfg, err := a.applyProxyConfigDefaults(proxyState.Proxy)
if err != nil {
return err
}
chkTypes := []*structs.CheckType{
&structs.CheckType{
Name: "Connect Proxy Listening",
TCP: fmt.Sprintf("%s:%d", proxyCfg["bind_address"],
proxyCfg["bind_port"]),
Interval: 10 * time.Second,
},
}
err = a.AddService(proxyService, chkTypes, persist, token)
if err != nil {
// Remove the state too
a.State.RemoveProxy(proxyService.ID)
return err
}
// Persist the proxy
if persist && a.config.DataDir != "" {
return a.persistProxy(proxyState)
}
return nil
}
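
For reference, the two call patterns the doc comment describes both appear later in this diff; a fragment-level sketch (not self-contained: `a`, `proxy`, and `p` come from the surrounding agent code):

// Fresh registration (config load or HTTP API): a blank restoredProxyToken
// makes local state mint a brand new proxy token.
if err := a.AddProxy(proxy, true, ""); err != nil {
	return err
}

// Restore path on agent start: re-inject the persisted token so the
// already-running proxy's credential keeps authenticating.
if err := a.AddProxy(p.Proxy, false, p.ProxyToken); err != nil {
	return err
}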
// applyProxyConfigDefaults takes a *structs.ConnectManagedProxy and returns
// its Config map merged with any defaults from the Agent's config. It would be
// nicer if this were defined as a method on structs.ConnectManagedProxy but we
// can't do that because of the import cycle it causes with agent/config.
func (a *Agent) applyProxyConfigDefaults(p *structs.ConnectManagedProxy) (map[string]interface{}, error) {
if p == nil || p.ProxyService == nil {
// Should never happen but protect from panic
return nil, fmt.Errorf("invalid proxy state")
}
// Lookup the target service
target := a.State.Service(p.TargetServiceID)
if target == nil {
// Can happen during deregistration race between proxy and scheduler.
return nil, fmt.Errorf("unknown target service ID: %s", p.TargetServiceID)
}
// Merge global defaults
config := make(map[string]interface{})
for k, v := range a.config.ConnectProxyDefaultConfig {
if _, ok := config[k]; !ok {
config[k] = v
}
}
// Copy config from the proxy
for k, v := range p.Config {
config[k] = v
}
// Set defaults for anything that is still not specified but required.
// Note that these are not included in the content hash, since we expect
// them to be static in general, though some, like the default target
// service port, might not be. In that edge case services can set the value
// explicitly when they re-register, which will then be caught.
if _, ok := config["bind_port"]; !ok {
config["bind_port"] = p.ProxyService.Port
}
if _, ok := config["bind_address"]; !ok {
// Default to binding to the same address the agent is configured to
// bind to.
config["bind_address"] = a.config.BindAddr.String()
}
if _, ok := config["local_service_address"]; !ok {
// Default to localhost and the port the service registered with
config["local_service_address"] = fmt.Sprintf("127.0.0.1:%d", target.Port)
}
// Basic type conversions for expected types.
if raw, ok := config["bind_port"]; ok {
switch v := raw.(type) {
case float64:
// Common since HCL/JSON parse as float64
config["bind_port"] = int(v)
// NOTE(mitchellh): No default case since errors and validation
// are handled by the ServiceDefinition.Validate function.
}
}
return config, nil
}
// applyProxyDefaults modifies the given proxy by applying any configured
// defaults, such as the default execution mode, command, etc.
func (a *Agent) applyProxyDefaults(proxy *structs.ConnectManagedProxy) error {
// Set the default exec mode
if proxy.ExecMode == structs.ProxyExecModeUnspecified {
mode, err := structs.NewProxyExecMode(a.config.ConnectProxyDefaultExecMode)
if err != nil {
return err
}
proxy.ExecMode = mode
}
if proxy.ExecMode == structs.ProxyExecModeUnspecified {
proxy.ExecMode = structs.ProxyExecModeDaemon
}
// Set the default command to the globally configured default
if len(proxy.Command) == 0 {
switch proxy.ExecMode {
case structs.ProxyExecModeDaemon:
proxy.Command = a.config.ConnectProxyDefaultDaemonCommand
case structs.ProxyExecModeScript:
proxy.Command = a.config.ConnectProxyDefaultScriptCommand
}
}
// If there is no globally configured default we need to get the
// default command so we can do "consul connect proxy"
if len(proxy.Command) == 0 {
command, err := defaultProxyCommand(a.config)
if err != nil {
return err
}
proxy.Command = command
}
return nil
}
// RemoveProxy stops and removes a local proxy instance.
func (a *Agent) RemoveProxy(proxyID string, persist bool) error {
// Validate proxyID
if proxyID == "" {
return fmt.Errorf("proxyID missing")
}
// Remove the proxy from the local state
p, err := a.State.RemoveProxy(proxyID)
if err != nil {
return err
}
// Remove the proxy service as well. The proxy ID is also the ID
// of the service, but we might as well use the service pointer.
if err := a.RemoveService(p.Proxy.ProxyService.ID, persist); err != nil {
return err
}
if persist && a.config.DataDir != "" {
return a.purgeProxy(proxyID)
}
return nil
}
// verifyProxyToken takes a token and attempts to verify it against the
// targetService name. If targetProxy is specified, then the local proxy token
// must exactly match the given proxy ID.
//
// The given token may be a local-only proxy token or it may be an ACL token. We
// will attempt to verify the local proxy token first.
//
// The effective ACL token is returned along with a boolean which is true if the
// match was against a proxy token rather than an ACL token, and any error. In
// the case the token matches a proxy token, then the ACL token used to register
// that proxy's target service is returned for use in any RPC calls the proxy
// needs to make on behalf of that service. If the token was an ACL token
// already then it is always returned. Provided the error is nil, a valid
// ACL token is always returned.
func (a *Agent) verifyProxyToken(token, targetService,
targetProxy string) (string, bool, error) {
// If we specify a target proxy, we look up that proxy directly. Otherwise,
// we resolve with any proxy we can find.
var proxy *local.ManagedProxy
if targetProxy != "" {
proxy = a.State.Proxy(targetProxy)
if proxy == nil {
return "", false, fmt.Errorf("unknown proxy service ID: %q", targetProxy)
}
// If the token DOESN'T match, then we reset the proxy which will
// cause the logic below to fall back to normal ACLs. Otherwise,
// we keep the proxy set because we also have to verify that the
// target service matches on the proxy.
if token != proxy.ProxyToken {
proxy = nil
}
} else {
proxy = a.resolveProxyToken(token)
}
// The existence of a token isn't enough, we also need to verify
// that the service name of the matching proxy matches our target
// service.
if proxy != nil {
// Get the target service since we only have the name. The nil
// check below should never be true since a proxy token always
// represents the existence of a local service.
target := a.State.Service(proxy.Proxy.TargetServiceID)
if target == nil {
return "", false, fmt.Errorf("proxy target service not found: %q",
proxy.Proxy.TargetServiceID)
}
if target.Service != targetService {
return "", false, acl.ErrPermissionDenied
}
// Resolve the actual ACL token used to register the proxy/service and
// return that for use in RPC calls.
return a.State.ServiceToken(proxy.Proxy.TargetServiceID), true, nil
}
// Doesn't match, we have to do a full token resolution. The required
// permission for any proxy-related endpoint is service:write, since
// to register a proxy you require that permission and sensitive data
// is usually present in the configuration.
rule, err := a.resolveToken(token)
if err != nil {
return "", false, err
}
if rule != nil && !rule.ServiceWrite(targetService, nil) {
return "", false, acl.ErrPermissionDenied
}
return token, false, nil
}
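
A sketch of how the return values are meant to be consumed (a fragment inside an *Agent method; the service name is illustrative, and the real callers are the Connect endpoints later in this diff):

// Resolve the effective ACL token for a request that may carry a local
// proxy token instead of a real ACL token.
effectiveToken, isProxy, err := a.verifyProxyToken(reqToken, "web", "")
if err != nil {
	return nil, err // e.g. acl.ErrPermissionDenied on a service mismatch
}
// effectiveToken is now valid for RPCs made on behalf of "web", and
// isProxy can gate data that should only be shown to managed proxies.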
func (a *Agent) cancelCheckMonitors(checkID types.CheckID) {
// Stop any monitors
delete(a.checkReapAfter, checkID)
@ -2017,7 +2401,7 @@ func (a *Agent) updateTTLCheck(checkID types.CheckID, status, output string) err
check.SetStatus(status, output)
// We don't write any files in dev mode so bail here.
if a.config.DevMode {
if a.config.DataDir == "" {
return nil
}
@ -2366,6 +2750,96 @@ func (a *Agent) unloadChecks() error {
return nil
}
// loadProxies will load connect proxy definitions from configuration and
// persisted definitions on disk, and load them into the local agent.
func (a *Agent) loadProxies(conf *config.RuntimeConfig) error {
for _, svc := range conf.Services {
if svc.Connect != nil {
proxy, err := svc.ConnectManagedProxy()
if err != nil {
return fmt.Errorf("failed adding proxy: %s", err)
}
if proxy == nil {
continue
}
if err := a.AddProxy(proxy, false, ""); err != nil {
return fmt.Errorf("failed adding proxy: %s", err)
}
}
}
// Load any persisted proxies
proxyDir := filepath.Join(a.config.DataDir, proxyDir)
files, err := ioutil.ReadDir(proxyDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("Failed reading proxies dir %q: %s", proxyDir, err)
}
for _, fi := range files {
// Skip all dirs
if fi.IsDir() {
continue
}
// Skip all partially written temporary files
if strings.HasSuffix(fi.Name(), "tmp") {
a.logger.Printf("[WARN] agent: Ignoring temporary proxy file %v", fi.Name())
continue
}
// Open the file for reading
file := filepath.Join(proxyDir, fi.Name())
fh, err := os.Open(file)
if err != nil {
return fmt.Errorf("failed opening proxy file %q: %s", file, err)
}
// Read the contents into a buffer
buf, err := ioutil.ReadAll(fh)
fh.Close()
if err != nil {
return fmt.Errorf("failed reading proxy file %q: %s", file, err)
}
// Try decoding the proxy definition
var p persistedProxy
if err := json.Unmarshal(buf, &p); err != nil {
a.logger.Printf("[ERR] agent: Failed decoding proxy file %q: %s", file, err)
continue
}
proxyID := p.Proxy.ProxyService.ID
if a.State.Proxy(proxyID) != nil {
// Purge previously persisted proxy. This allows config to be preferred
// over proxies persisted from the API.
a.logger.Printf("[DEBUG] agent: proxy %q exists, not restoring from %q",
proxyID, file)
if err := a.purgeProxy(proxyID); err != nil {
return fmt.Errorf("failed purging proxy %q: %s", proxyID, err)
}
} else {
a.logger.Printf("[DEBUG] agent: restored proxy definition %q from %q",
proxyID, file)
if err := a.AddProxy(p.Proxy, false, p.ProxyToken); err != nil {
return fmt.Errorf("failed adding proxy %q: %s", proxyID, err)
}
}
}
return nil
}
// unloadProxies will deregister all proxies known to the local agent.
func (a *Agent) unloadProxies() error {
for id := range a.State.Proxies() {
if err := a.RemoveProxy(id, false); err != nil {
return fmt.Errorf("Failed deregistering proxy '%s': %s", id, err)
}
}
return nil
}
// snapshotCheckState is used to snapshot the current state of the health
// checks. This is done before we reload our checks, so that we can properly
// restore into the same state.
@ -2491,6 +2965,11 @@ func (a *Agent) DisableNodeMaintenance() {
a.logger.Printf("[INFO] agent: Node left maintenance mode")
}
func (a *Agent) loadLimits(conf *config.RuntimeConfig) {
a.config.RPCRateLimit = conf.RPCRateLimit
a.config.RPCMaxBurst = conf.RPCMaxBurst
}
func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
// Bulk update the services and checks
a.PauseSync()
@ -2502,6 +2981,9 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
// First unload all checks, services, and metadata. This lets us begin the reload
// with a clean slate.
if err := a.unloadProxies(); err != nil {
return fmt.Errorf("Failed unloading proxies: %s", err)
}
if err := a.unloadServices(); err != nil {
return fmt.Errorf("Failed unloading services: %s", err)
}
@ -2514,6 +2996,9 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
if err := a.loadServices(newCfg); err != nil {
return fmt.Errorf("Failed reloading services: %s", err)
}
if err := a.loadProxies(newCfg); err != nil {
return fmt.Errorf("Failed reloading proxies: %s", err)
}
if err := a.loadChecks(newCfg); err != nil {
return fmt.Errorf("Failed reloading checks: %s", err)
}
@ -2525,10 +3010,75 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
return fmt.Errorf("Failed reloading watches: %v", err)
}
a.loadLimits(newCfg)
// create the config for the rpc server/client
consulCfg, err := a.consulConfig()
if err != nil {
return err
}
if err := a.delegate.ReloadConfig(consulCfg); err != nil {
return err
}
// Update filtered metrics
metrics.UpdateFilter(newCfg.TelemetryAllowedPrefixes, newCfg.TelemetryBlockedPrefixes)
metrics.UpdateFilter(newCfg.Telemetry.AllowedPrefixes,
newCfg.Telemetry.BlockedPrefixes)
a.State.SetDiscardCheckOutput(newCfg.DiscardCheckOutput)
return nil
}
// registerCache configures the cache and registers all the supported
// types onto the cache. This is NOT safe to call multiple times so
// care should be taken to call this exactly once after the cache
// field has been initialized.
func (a *Agent) registerCache() {
a.cache.RegisterType(cachetype.ConnectCARootName, &cachetype.ConnectCARoot{
RPC: a.delegate,
}, &cache.RegisterOptions{
// Maintain a blocking query, retry dropped connections quickly
Refresh: true,
RefreshTimer: 0 * time.Second,
RefreshTimeout: 10 * time.Minute,
})
a.cache.RegisterType(cachetype.ConnectCALeafName, &cachetype.ConnectCALeaf{
RPC: a.delegate,
Cache: a.cache,
}, &cache.RegisterOptions{
// Maintain a blocking query, retry dropped connections quickly
Refresh: true,
RefreshTimer: 0 * time.Second,
RefreshTimeout: 10 * time.Minute,
})
a.cache.RegisterType(cachetype.IntentionMatchName, &cachetype.IntentionMatch{
RPC: a.delegate,
}, &cache.RegisterOptions{
// Maintain a blocking query, retry dropped connections quickly
Refresh: true,
RefreshTimer: 0 * time.Second,
RefreshTimeout: 10 * time.Minute,
})
}
// defaultProxyCommand returns the default Connect managed proxy command.
func defaultProxyCommand(agentCfg *config.RuntimeConfig) ([]string, error) {
// Get the path to the current executable. This is cached once by the
// library so this is effectively just a variable read.
execPath, err := os.Executable()
if err != nil {
return nil, err
}
// "consul connect proxy" default value for managed daemon proxy
cmd := []string{execPath, "connect", "proxy"}
if agentCfg != nil && agentCfg.LogLevel != "INFO" {
cmd = append(cmd, "-log-level", agentCfg.LogLevel)
}
return cmd, nil
}


@ -4,12 +4,21 @@ import (
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/mitchellh/mapstructure"
"github.com/hashicorp/go-memdb"
"github.com/mitchellh/hashstructure"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/checks"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/ipaddr"
@ -97,7 +106,7 @@ func (s *HTTPServer) AgentMetrics(resp http.ResponseWriter, req *http.Request) (
return nil, acl.ErrPermissionDenied
}
if enablePrometheusOutput(req) {
if s.agent.config.TelemetryPrometheusRetentionTime < 1 {
if s.agent.config.Telemetry.PrometheusRetentionTime < 1 {
resp.WriteHeader(http.StatusUnsupportedMediaType)
fmt.Fprint(resp, "Prometheus is not enabled since its retention time is not positive")
return nil, nil
@ -153,25 +162,49 @@ func (s *HTTPServer) AgentServices(resp http.ResponseWriter, req *http.Request)
return nil, err
}
proxies := s.agent.State.Proxies()
// Convert into api.AgentService since that includes Connect config, which so
// far NodeService doesn't need internally. They are otherwise identical, and
// api.AgentService is the struct the client uses to read this endpoint's
// output anyway.
agentSvcs := make(map[string]*api.AgentService)
// Use empty list instead of nil
for id, s := range services {
if s.Tags == nil || s.Meta == nil {
clone := *s
if s.Tags == nil {
clone.Tags = make([]string, 0)
} else {
clone.Tags = s.Tags
}
if s.Meta == nil {
clone.Meta = make(map[string]string)
} else {
clone.Meta = s.Meta
}
services[id] = &clone
as := &api.AgentService{
Kind: api.ServiceKind(s.Kind),
ID: s.ID,
Service: s.Service,
Tags: s.Tags,
Meta: s.Meta,
Port: s.Port,
Address: s.Address,
EnableTagOverride: s.EnableTagOverride,
CreateIndex: s.CreateIndex,
ModifyIndex: s.ModifyIndex,
ProxyDestination: s.ProxyDestination,
}
if as.Tags == nil {
as.Tags = []string{}
}
if as.Meta == nil {
as.Meta = map[string]string{}
}
// Attach Connect configs if they exist
if proxy, ok := proxies[id+"-proxy"]; ok {
as.Connect = &api.AgentServiceConnect{
Proxy: &api.AgentServiceConnectProxy{
ExecMode: api.ProxyExecMode(proxy.Proxy.ExecMode.String()),
Command: proxy.Proxy.Command,
Config: proxy.Proxy.Config,
},
}
}
agentSvcs[id] = as
}
return services, nil
return agentSvcs, nil
}
func (s *HTTPServer) AgentChecks(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
@ -554,6 +587,14 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
return nil, nil
}
// Run validation. This is the same validation that would happen on
// the catalog endpoint so it helps ensure the sync will work properly.
if err := ns.Validate(); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, err.Error())
return nil, nil
}
// Verify the check type.
chkTypes, err := args.CheckTypes()
if err != nil {
@ -576,10 +617,30 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
return nil, err
}
// Get any proxy registrations
proxy, err := args.ConnectManagedProxy()
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, err.Error())
return nil, nil
}
// If we have a proxy, verify that we're allowed to add a proxy via the API
if proxy != nil && !s.agent.config.ConnectProxyAllowManagedAPIRegistration {
return nil, &BadRequestError{
Reason: "Managed proxy registration via the API is disallowed."}
}
// Add the service.
if err := s.agent.AddService(ns, chkTypes, true, token); err != nil {
return nil, err
}
// Add proxy (which will add proxy service so do it before we trigger sync)
if proxy != nil {
if err := s.agent.AddProxy(proxy, true, ""); err != nil {
return nil, err
}
}
s.syncChanges()
return nil, nil
}
@ -594,9 +655,27 @@ func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http.
return nil, err
}
// Verify this isn't a proxy
if s.agent.State.Proxy(serviceID) != nil {
return nil, &BadRequestError{
Reason: "Managed proxy service cannot be deregistered directly. " +
"Deregister the service that has a managed proxy to automatically " +
"deregister the managed proxy itself."}
}
if err := s.agent.RemoveService(serviceID, true); err != nil {
return nil, err
}
// Remove the associated managed proxy if it exists
for proxyID, p := range s.agent.State.Proxies() {
if p.Proxy.TargetServiceID == serviceID {
if err := s.agent.RemoveProxy(proxyID, true); err != nil {
return nil, err
}
}
}
s.syncChanges()
return nil, nil
}
@ -828,3 +907,394 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in
s.agent.logger.Printf("[INFO] agent: Updated agent's ACL token %q", target)
return nil, nil
}
// AgentConnectCARoots returns the trusted CA roots.
func (s *HTTPServer) AgentConnectCARoots(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var args structs.DCSpecificRequest
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
raw, m, err := s.agent.cache.Get(cachetype.ConnectCARootName, &args)
if err != nil {
return nil, err
}
defer setCacheMeta(resp, &m)
// Add cache hit
reply, ok := raw.(*structs.IndexedCARoots)
if !ok {
// This should never happen, but we want to protect against panics
return nil, fmt.Errorf("internal error: response type not correct")
}
defer setMeta(resp, &reply.QueryMeta)
return *reply, nil
}
// AgentConnectCALeafCert returns the certificate bundle for a service
// instance. This supports blocking queries to update the returned bundle.
func (s *HTTPServer) AgentConnectCALeafCert(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Get the service name. Note that this is the name of the service,
// not the ID of the service instance.
serviceName := strings.TrimPrefix(req.URL.Path, "/v1/agent/connect/ca/leaf/")
args := cachetype.ConnectCALeafRequest{
Service: serviceName, // Need name not ID
}
var qOpts structs.QueryOptions
// Store DC in the ConnectCALeafRequest but query opts separately
if done := s.parse(resp, req, &args.Datacenter, &qOpts); done {
return nil, nil
}
args.MinQueryIndex = qOpts.MinQueryIndex
// Verify the proxy token. This will check both the local proxy token
// as well as the ACL if the token isn't local.
effectiveToken, _, err := s.agent.verifyProxyToken(qOpts.Token, serviceName, "")
if err != nil {
return nil, err
}
args.Token = effectiveToken
raw, m, err := s.agent.cache.Get(cachetype.ConnectCALeafName, &args)
if err != nil {
return nil, err
}
defer setCacheMeta(resp, &m)
reply, ok := raw.(*structs.IssuedCert)
if !ok {
// This should never happen, but we want to protect against panics
return nil, fmt.Errorf("internal error: response type not correct")
}
setIndex(resp, reply.ModifyIndex)
return reply, nil
}
// GET /v1/agent/connect/proxy/:proxy_service_id
//
// Returns the local proxy config for the identified proxy. Requires token=
// param with the correct local ProxyToken (not ACL token).
func (s *HTTPServer) AgentConnectProxyConfig(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Get the proxy ID. Note that this is the ID of a proxy's service instance.
id := strings.TrimPrefix(req.URL.Path, "/v1/agent/connect/proxy/")
// Maybe block
var queryOpts structs.QueryOptions
if parseWait(resp, req, &queryOpts) {
// parseWait returns an error itself
return nil, nil
}
// Parse the token
var token string
s.parseToken(req, &token)
// Parse hash specially since it's only this endpoint that uses it currently.
// Eventually this should happen in parseWait and end up in QueryOptions but I
// didn't want to make very general changes right away.
hash := req.URL.Query().Get("hash")
return s.agentLocalBlockingQuery(resp, hash, &queryOpts,
func(ws memdb.WatchSet) (string, interface{}, error) {
// Retrieve the proxy specified
proxy := s.agent.State.Proxy(id)
if proxy == nil {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp, "unknown proxy service ID: %s", id)
return "", nil, nil
}
// Lookup the target service as a convenience
target := s.agent.State.Service(proxy.Proxy.TargetServiceID)
if target == nil {
// Not found since this endpoint is only useful for agent-managed proxies, so
// a missing service means it was deregistered in a race with this call.
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp, "unknown target service ID: %s", proxy.Proxy.TargetServiceID)
return "", nil, nil
}
// Validate the ACL token
_, isProxyToken, err := s.agent.verifyProxyToken(token, target.Service, id)
if err != nil {
return "", nil, err
}
// Watch the proxy for changes
ws.Add(proxy.WatchCh)
hash, err := hashstructure.Hash(proxy.Proxy, nil)
if err != nil {
return "", nil, err
}
contentHash := fmt.Sprintf("%x", hash)
// Set defaults
config, err := s.agent.applyProxyConfigDefaults(proxy.Proxy)
if err != nil {
return "", nil, err
}
// Only merge in telemetry config from the agent if the requester is
// authorized with a proxy token. This prevents us leaking potentially
// sensitive config like a Circonus API token via a public endpoint. Proxy
// tokens are only ever generated in-memory and passed via ENV to a child
// proxy process, so the potential for abuse here seems small. This endpoint
// in general is only useful for managed proxies now, so it should _always_
// be the case that auth is via a proxy token, but it would be inconvenient
// for testing to lock it down that strictly.
if isProxyToken {
// Add telemetry config. Copy the global config so we can customize the
// prefix.
telemetryCfg := s.agent.config.Telemetry
telemetryCfg.MetricsPrefix = telemetryCfg.MetricsPrefix + ".proxy." + target.ID
// First see if the user has specified telemetry
if userRaw, ok := config["telemetry"]; ok {
// User specified something, see if it is compatible with agent
// telemetry config:
var uCfg lib.TelemetryConfig
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: &uCfg,
// Make sure that if the user passes something that isn't just a
// simple override of a valid TelemetryConfig that we fail so that we
// don't clobber their custom config.
ErrorUnused: true,
})
if err == nil {
if err = dec.Decode(userRaw); err == nil {
// It did decode! Merge any unspecified fields from agent config.
uCfg.MergeDefaults(&telemetryCfg)
config["telemetry"] = uCfg
}
}
// Failed to decode, just keep user's config["telemetry"] verbatim
// with no agent merge.
} else {
// Add agent telemetry config.
config["telemetry"] = telemetryCfg
}
}
reply := &api.ConnectProxyConfig{
ProxyServiceID: proxy.Proxy.ProxyService.ID,
TargetServiceID: target.ID,
TargetServiceName: target.Service,
ContentHash: contentHash,
ExecMode: api.ProxyExecMode(proxy.Proxy.ExecMode.String()),
Command: proxy.Proxy.Command,
Config: config,
}
return contentHash, reply, nil
})
}
type agentLocalBlockingFunc func(ws memdb.WatchSet) (string, interface{}, error)
// agentLocalBlockingQuery performs a blocking query in a generic way against
// local agent state that has no RPC or raft to back it. It uses a `hash`
// parameter instead of an `index`. The resp is needed to write the
// `X-Consul-ContentHash` header back on return; no status or body content is
// ever written to it.
func (s *HTTPServer) agentLocalBlockingQuery(resp http.ResponseWriter, hash string,
queryOpts *structs.QueryOptions, fn agentLocalBlockingFunc) (interface{}, error) {
// If we are not blocking we can skip tracking and allocating - a nil
// WatchSet is still valid to call Add on and will just be a no-op.
var ws memdb.WatchSet
var timeout *time.Timer
if hash != "" {
// TODO(banks) at least define these defaults somewhere in a const. Would be
// nice not to duplicate the ones in consul/rpc.go too...
wait := queryOpts.MaxQueryTime
if wait == 0 {
wait = 5 * time.Minute
}
if wait > 10*time.Minute {
wait = 10 * time.Minute
}
// Apply a small amount of jitter to the request.
wait += lib.RandomStagger(wait / 16)
timeout = time.NewTimer(wait)
}
for {
// Must reset this every loop in case the WatchSet is already closed but
// the hash remains the same. In that case we'll need to re-block on
// ws.Watch() again.
ws = memdb.NewWatchSet()
curHash, curResp, err := fn(ws)
if err != nil {
return curResp, err
}
// Return immediately if there is no timeout, the hash is different or the
// Watch returns true (indicating timeout fired). Note that Watch on a nil
// WatchSet immediately returns false which would incorrectly cause this to
// loop and repeat again, however we rely on the invariant that ws == nil
// IFF timeout == nil in which case the Watch call is never invoked.
if timeout == nil || hash != curHash || ws.Watch(timeout.C) {
resp.Header().Set("X-Consul-ContentHash", curHash)
return curResp, err
}
// Watch returned false indicating a change was detected, loop and repeat
// the callback to load the new value.
}
}
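// As an illustration only (the endpoint name and state accessor below are
// hypothetical, not part of this change), a handler built on
// agentLocalBlockingQuery follows this shape: look up the watched value,
// register its watch channel, and return its hash plus the reply:
//
//	func (s *HTTPServer) exampleLocalState(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
//		queryOpts := structs.QueryOptions{MaxQueryTime: 5 * time.Minute}
//		hash := req.URL.Query().Get("hash")
//		return s.agentLocalBlockingQuery(resp, hash, &queryOpts,
//			func(ws memdb.WatchSet) (string, interface{}, error) {
//				value := s.agent.State.Example() // hypothetical accessor exposing a WatchCh
//				ws.Add(value.WatchCh)
//				h, err := hashstructure.Hash(value, nil)
//				if err != nil {
//					return "", nil, err
//				}
//				return fmt.Sprintf("%x", h), value, nil
//			})
//	}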
// AgentConnectAuthorize
//
// POST /v1/agent/connect/authorize
//
// Note: when this logic changes, consider if the Intention.Check RPC method
// also needs to be updated.
func (s *HTTPServer) AgentConnectAuthorize(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Fetch the token
var token string
s.parseToken(req, &token)
// Decode the request from the request body
var authReq structs.ConnectAuthorizeRequest
if err := decodeBody(req, &authReq, nil); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
}
// We need to have a target to check intentions
if authReq.Target == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Target service must be specified")
return nil, nil
}
// Parse the certificate URI from the client ID
uriRaw, err := url.Parse(authReq.ClientCertURI)
if err != nil {
return &connectAuthorizeResp{
Authorized: false,
Reason: fmt.Sprintf("Client ID must be a URI: %s", err),
}, nil
}
uri, err := connect.ParseCertURI(uriRaw)
if err != nil {
return &connectAuthorizeResp{
Authorized: false,
Reason: fmt.Sprintf("Invalid client ID: %s", err),
}, nil
}
uriService, ok := uri.(*connect.SpiffeIDService)
if !ok {
return &connectAuthorizeResp{
Authorized: false,
Reason: "Client ID must be a valid SPIFFE service URI",
}, nil
}
// We need to verify service:write permissions for the given token.
// We do this manually here since the RPC request below only verifies
// service:read.
rule, err := s.agent.resolveToken(token)
if err != nil {
return nil, err
}
if rule != nil && !rule.ServiceWrite(authReq.Target, nil) {
return nil, acl.ErrPermissionDenied
}
// Validate the trust domain matches ours. Later we will support explicit
// external federation, but that is not built yet.
rootArgs := &structs.DCSpecificRequest{Datacenter: s.agent.config.Datacenter}
raw, _, err := s.agent.cache.Get(cachetype.ConnectCARootName, rootArgs)
if err != nil {
return nil, err
}
roots, ok := raw.(*structs.IndexedCARoots)
if !ok {
return nil, fmt.Errorf("internal error: roots response type not correct")
}
if roots.TrustDomain == "" {
return nil, fmt.Errorf("connect CA not bootstrapped yet")
}
if roots.TrustDomain != strings.ToLower(uriService.Host) {
return &connectAuthorizeResp{
Authorized: false,
Reason: fmt.Sprintf("Identity from an external trust domain: %s",
uriService.Host),
}, nil
}
// TODO(banks): Implement revocation list checking here.
// Get the intentions for this target service.
args := &structs.IntentionQueryRequest{
Datacenter: s.agent.config.Datacenter,
Match: &structs.IntentionQueryMatch{
Type: structs.IntentionMatchDestination,
Entries: []structs.IntentionMatchEntry{
{
Namespace: structs.IntentionDefaultNamespace,
Name: authReq.Target,
},
},
},
}
args.Token = token
raw, m, err := s.agent.cache.Get(cachetype.IntentionMatchName, args)
if err != nil {
return nil, err
}
setCacheMeta(resp, &m)
reply, ok := raw.(*structs.IndexedIntentionMatches)
if !ok {
return nil, fmt.Errorf("internal error: response type not correct")
}
if len(reply.Matches) != 1 {
return nil, fmt.Errorf("Internal error loading matches")
}
// Test the authorization for each match
for _, ixn := range reply.Matches[0] {
if auth, ok := uriService.Authorize(ixn); ok {
return &connectAuthorizeResp{
Authorized: auth,
Reason: fmt.Sprintf("Matched intention: %s", ixn.String()),
}, nil
}
}
// No match, so we need to determine the default behavior. We do this by
// specifying the anonymous token, which will get that behavior. The
// default behavior if ACLs are disabled is to allow connections
// to mimic the behavior of Consul itself: everything is allowed if
// ACLs are disabled.
rule, err = s.agent.resolveToken("")
if err != nil {
return nil, err
}
authz := true
reason := "ACLs disabled, access is allowed by default"
if rule != nil {
authz = rule.IntentionDefaultAllow()
reason = "Default behavior configured by ACLs"
}
return &connectAuthorizeResp{
Authorized: authz,
Reason: reason,
}, nil
}
// connectAuthorizeResp is the response format/structure for the
// /v1/agent/connect/authorize endpoint.
type connectAuthorizeResp struct {
Authorized bool // True if authorized, false if not
Reason string // Reason for the Authorized value (whether true or false)
}
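// For illustration (all values invented), an exchange against this endpoint
// looks like the following; Target and ClientCertURI are the fields decoded
// into structs.ConnectAuthorizeRequest above:
//
//	POST /v1/agent/connect/authorize
//	{
//	  "Target": "db",
//	  "ClientCertURI": "spiffe://fake-trust-domain.consul/ns/default/dc/dc1/svc/web"
//	}
//
//	200 OK
//	{"Authorized": true, "Reason": "Matched intention: web => db (allow)"}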

File diff suppressed because it is too large


@ -16,6 +16,7 @@ import (
"time"
"github.com/hashicorp/consul/agent/checks"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/testutil"
@ -23,6 +24,8 @@ import (
"github.com/hashicorp/consul/types"
uuid "github.com/hashicorp/go-uuid"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func externalIP() (string, error) {
@ -51,10 +54,62 @@ func TestAgent_MultiStartStop(t *testing.T) {
}
}
func TestAgent_ConnectClusterIDConfig(t *testing.T) {
tests := []struct {
name string
hcl string
wantClusterID string
wantPanic bool
}{
{
name: "default TestAgent has fixed cluster id",
hcl: "",
wantClusterID: connect.TestClusterID,
},
{
name: "no cluster ID specified sets to test ID",
hcl: "connect { enabled = true }",
wantClusterID: connect.TestClusterID,
},
{
name: "non-UUID cluster_id is fatal",
hcl: `connect {
enabled = true
ca_config {
cluster_id = "fake-id"
}
}`,
wantClusterID: "",
wantPanic: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Indirection to support panic recovery cleanly
testFn := func() {
a := &TestAgent{Name: "test", HCL: tt.hcl}
a.ExpectConfigError = tt.wantPanic
a.Start()
defer a.Shutdown()
cfg := a.consulConfig()
assert.Equal(t, tt.wantClusterID, cfg.CAConfig.ClusterID)
}
if tt.wantPanic {
require.Panics(t, testFn)
} else {
testFn()
}
})
}
}
func TestAgent_StartStop(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
if err := a.Leave(); err != nil {
t.Fatalf("err: %v", err)
@ -1294,6 +1349,187 @@ func TestAgent_PurgeServiceOnDuplicate(t *testing.T) {
}
}
func TestAgent_PersistProxy(t *testing.T) {
t.Parallel()
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
cfg := `
server = false
bootstrap = false
data_dir = "` + dataDir + `"
`
a := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
a.Start()
defer os.RemoveAll(dataDir)
defer a.Shutdown()
require := require.New(t)
assert := assert.New(t)
// Add a service to proxy (precondition for AddProxy)
svc1 := &structs.NodeService{
ID: "redis",
Service: "redis",
Tags: []string{"foo"},
Port: 8000,
}
require.NoError(a.AddService(svc1, nil, true, ""))
// Add a proxy for it
proxy := &structs.ConnectManagedProxy{
TargetServiceID: svc1.ID,
Command: []string{"/bin/sleep", "3600"},
}
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash("redis-proxy"))
// Proxy is not persisted unless requested
require.NoError(a.AddProxy(proxy, false, ""))
_, err := os.Stat(file)
require.Error(err, "proxy should not be persisted")
// Proxy is persisted if requested
require.NoError(a.AddProxy(proxy, true, ""))
_, err = os.Stat(file)
require.NoError(err, "proxy should be persisted")
content, err := ioutil.ReadFile(file)
require.NoError(err)
var gotProxy persistedProxy
require.NoError(json.Unmarshal(content, &gotProxy))
assert.Equal(proxy.Command, gotProxy.Proxy.Command)
assert.Len(gotProxy.ProxyToken, 36) // sanity check for UUID
// Updates service definition on disk
proxy.Config = map[string]interface{}{
"foo": "bar",
}
require.NoError(a.AddProxy(proxy, true, ""))
content, err = ioutil.ReadFile(file)
require.NoError(err)
require.NoError(json.Unmarshal(content, &gotProxy))
assert.Equal(gotProxy.Proxy.Command, proxy.Command)
assert.Equal(gotProxy.Proxy.Config, proxy.Config)
assert.Len(gotProxy.ProxyToken, 36) // sanity check for UUID
a.Shutdown()
// Should load it back during later start
a2 := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
a2.Start()
defer a2.Shutdown()
restored := a2.State.Proxy("redis-proxy")
require.NotNil(restored)
assert.Equal(gotProxy.ProxyToken, restored.ProxyToken)
// Ensure the port that was auto picked at random is the same again
assert.Equal(gotProxy.Proxy.ProxyService.Port, restored.Proxy.ProxyService.Port)
assert.Equal(gotProxy.Proxy.Command, restored.Proxy.Command)
}
func TestAgent_PurgeProxy(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
require := require.New(t)
// Add a service to proxy (precondition for AddProxy)
svc1 := &structs.NodeService{
ID: "redis",
Service: "redis",
Tags: []string{"foo"},
Port: 8000,
}
require.NoError(a.AddService(svc1, nil, true, ""))
// Add a proxy for it
proxy := &structs.ConnectManagedProxy{
TargetServiceID: svc1.ID,
Command: []string{"/bin/sleep", "3600"},
}
proxyID := "redis-proxy"
require.NoError(a.AddProxy(proxy, true, ""))
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash("redis-proxy"))
// Not removed
require.NoError(a.RemoveProxy(proxyID, false))
_, err := os.Stat(file)
require.NoError(err, "should not be removed")
// Re-add the proxy
require.NoError(a.AddProxy(proxy, true, ""))
// Removed
require.NoError(a.RemoveProxy(proxyID, true))
_, err = os.Stat(file)
require.Error(err, "should be removed")
}
func TestAgent_PurgeProxyOnDuplicate(t *testing.T) {
t.Parallel()
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
cfg := `
data_dir = "` + dataDir + `"
server = false
bootstrap = false
`
a := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
a.Start()
defer a.Shutdown()
defer os.RemoveAll(dataDir)
require := require.New(t)
// Add a service to proxy (precondition for AddProxy)
svc1 := &structs.NodeService{
ID: "redis",
Service: "redis",
Tags: []string{"foo"},
Port: 8000,
}
require.NoError(a.AddService(svc1, nil, true, ""))
// Add a proxy for it
proxy := &structs.ConnectManagedProxy{
TargetServiceID: svc1.ID,
Command: []string{"/bin/sleep", "3600"},
}
proxyID := "redis-proxy"
require.NoError(a.AddProxy(proxy, true, ""))
a.Shutdown()
// Try bringing the agent back up with the service already
// existing in the config
a2 := &TestAgent{Name: t.Name() + "-a2", HCL: cfg + `
service = {
id = "redis"
name = "redis"
tags = ["bar"]
port = 9000
connect {
proxy {
command = ["/bin/sleep", "3600"]
}
}
}
`, DataDir: dataDir}
a2.Start()
defer a2.Shutdown()
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash(proxyID))
_, err := os.Stat(file)
require.Error(err, "should have removed remote state")
result := a2.State.Proxy(proxyID)
require.NotNil(result)
require.Equal(proxy.Command, result.Proxy.Command)
}
func TestAgent_PersistCheck(t *testing.T) {
t.Parallel()
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
@ -1629,6 +1865,96 @@ func TestAgent_unloadServices(t *testing.T) {
}
}
func TestAgent_loadProxies(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
service = {
id = "rabbitmq"
name = "rabbitmq"
port = 5672
token = "abc123"
connect {
proxy {
config {
bind_port = 1234
}
}
}
}
`)
defer a.Shutdown()
services := a.State.Services()
if _, ok := services["rabbitmq"]; !ok {
t.Fatalf("missing service")
}
if token := a.State.ServiceToken("rabbitmq"); token != "abc123" {
t.Fatalf("bad: %s", token)
}
if _, ok := services["rabbitmq-proxy"]; !ok {
t.Fatalf("missing proxy service")
}
if token := a.State.ServiceToken("rabbitmq-proxy"); token != "abc123" {
t.Fatalf("bad: %s", token)
}
proxies := a.State.Proxies()
if _, ok := proxies["rabbitmq-proxy"]; !ok {
t.Fatalf("missing proxy")
}
}
func TestAgent_loadProxies_nilProxy(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
service = {
id = "rabbitmq"
name = "rabbitmq"
port = 5672
token = "abc123"
connect {
}
}
`)
defer a.Shutdown()
services := a.State.Services()
require.Contains(t, services, "rabbitmq")
require.Equal(t, "abc123", a.State.ServiceToken("rabbitmq"))
require.NotContains(t, services, "rabbitmq-proxy")
require.Empty(t, a.State.Proxies())
}
func TestAgent_unloadProxies(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
service = {
id = "rabbitmq"
name = "rabbitmq"
port = 5672
token = "abc123"
connect {
proxy {
config {
bind_port = 1234
}
}
}
}
`)
defer a.Shutdown()
// Sanity check it's there
require.NotNil(t, a.State.Proxy("rabbitmq-proxy"))
// Unload all proxies
if err := a.unloadProxies(); err != nil {
t.Fatalf("err: %s", err)
}
if len(a.State.Proxies()) != 0 {
t.Fatalf("should have unloaded proxies")
}
}
func TestAgent_Service_MaintenanceMode(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
@ -2179,6 +2505,18 @@ func TestAgent_reloadWatches(t *testing.T) {
t.Fatalf("bad: %s", err)
}
// Should fail to reload with connect watches
newConf.Watches = []map[string]interface{}{
{
"type": "connect_roots",
"key": "asdf",
"args": []interface{}{"ls"},
},
}
if err := a.reloadWatches(&newConf); err == nil || !strings.Contains(err.Error(), "not allowed in agent config") {
t.Fatalf("bad: %s", err)
}
// Should still succeed with only HTTPS addresses
newConf.HTTPSAddrs = newConf.HTTPAddrs
newConf.HTTPAddrs = make([]net.Addr, 0)
@ -2226,3 +2564,217 @@ func TestAgent_reloadWatchesHTTPS(t *testing.T) {
t.Fatalf("bad: %s", err)
}
}
func TestAgent_AddProxy(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
node_name = "node1"
connect {
proxy_defaults {
exec_mode = "script"
daemon_command = ["foo", "bar"]
script_command = ["bar", "foo"]
}
}
ports {
proxy_min_port = 20000
proxy_max_port = 20000
}
`)
defer a.Shutdown()
// Register a target service we can use
reg := &structs.NodeService{
Service: "web",
Port: 8080,
}
require.NoError(t, a.AddService(reg, nil, false, ""))
tests := []struct {
desc string
proxy, wantProxy *structs.ConnectManagedProxy
wantTCPCheck string
wantErr bool
}{
{
desc: "basic proxy adding, unregistered service",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
Config: map[string]interface{}{
"foo": "bar",
},
TargetServiceID: "db", // non-existent service.
},
// Target service must be registered.
wantErr: true,
},
{
desc: "basic proxy adding, registered service",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
Config: map[string]interface{}{
"foo": "bar",
},
TargetServiceID: "web",
},
wantErr: false,
},
{
desc: "default global exec mode",
proxy: &structs.ConnectManagedProxy{
Command: []string{"consul", "connect", "proxy"},
TargetServiceID: "web",
},
wantProxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeScript,
Command: []string{"consul", "connect", "proxy"},
TargetServiceID: "web",
},
wantErr: false,
},
{
desc: "default daemon command",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
TargetServiceID: "web",
},
wantProxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"foo", "bar"},
TargetServiceID: "web",
},
wantErr: false,
},
{
desc: "default script command",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeScript,
TargetServiceID: "web",
},
wantProxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeScript,
Command: []string{"bar", "foo"},
TargetServiceID: "web",
},
wantErr: false,
},
{
desc: "managed proxy with custom bind port",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
Config: map[string]interface{}{
"foo": "bar",
"bind_address": "127.10.10.10",
"bind_port": 1234,
},
TargetServiceID: "web",
},
wantTCPCheck: "127.10.10.10:1234",
wantErr: false,
},
{
// This test is necessary since JSON and HCL will both parse
// numbers as float64.
desc: "managed proxy with custom bind port (float64)",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
Config: map[string]interface{}{
"foo": "bar",
"bind_address": "127.10.10.10",
"bind_port": float64(1234),
},
TargetServiceID: "web",
},
wantTCPCheck: "127.10.10.10:1234",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
require := require.New(t)
err := a.AddProxy(tt.proxy, false, "")
if tt.wantErr {
require.Error(err)
return
}
require.NoError(err)
// Test the ID was created as we expect.
got := a.State.Proxy("web-proxy")
wantProxy := tt.wantProxy
if wantProxy == nil {
wantProxy = tt.proxy
}
wantProxy.ProxyService = got.Proxy.ProxyService
require.Equal(wantProxy, got.Proxy)
// Ensure a TCP check was created for the service.
gotCheck := a.State.Check("service:web-proxy")
require.NotNil(gotCheck)
require.Equal("Connect Proxy Listening", gotCheck.Name)
// Confusingly, a.State.Check("service:web-proxy") will return the state
// but its Definition field will be empty. This appears to be expected
// when adding Checks as part of `AddService`. Notice how `AddService`
// tests in this file don't assert on that state but instead look at the
// agent's check state directly to ensure the right thing was registered.
// We'll do the same for now.
gotTCP, ok := a.checkTCPs["service:web-proxy"]
require.True(ok)
wantTCPCheck := tt.wantTCPCheck
if wantTCPCheck == "" {
wantTCPCheck = "127.0.0.1:20000"
}
require.Equal(wantTCPCheck, gotTCP.TCP)
require.Equal(10*time.Second, gotTCP.Interval)
})
}
}
func TestAgent_RemoveProxy(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
node_name = "node1"
`)
defer a.Shutdown()
require := require.New(t)
// Register a target service we can use
reg := &structs.NodeService{
Service: "web",
Port: 8080,
}
require.NoError(a.AddService(reg, nil, false, ""))
// Add a proxy for web
pReg := &structs.ConnectManagedProxy{
TargetServiceID: "web",
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"foo"},
}
require.NoError(a.AddProxy(pReg, false, ""))
// Test the ID was created as we expect.
gotProxy := a.State.Proxy("web-proxy")
require.NotNil(gotProxy)
err := a.RemoveProxy("web-proxy", false)
require.NoError(err)
gotProxy = a.State.Proxy("web-proxy")
require.Nil(gotProxy)
require.Nil(a.State.Service("web-proxy"), "web-proxy service")
// Removing invalid proxy should be an error
err = a.RemoveProxy("foobar", false)
require.Error(err)
}

File diff suppressed because one or more lines are too long


@ -0,0 +1,240 @@
package cachetype
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)
// Recommended name for registration.
const ConnectCALeafName = "connect-ca-leaf"
// ConnectCALeaf supports fetching and generating Connect leaf
// certificates.
type ConnectCALeaf struct {
caIndex uint64 // Current index for CA roots
issuedCertsLock sync.RWMutex
issuedCerts map[string]*structs.IssuedCert
RPC RPC // RPC client for remote requests
Cache *cache.Cache // Cache that has CA root certs via ConnectCARoot
}
func (c *ConnectCALeaf) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
var result cache.FetchResult
// Get the correct type
reqReal, ok := req.(*ConnectCALeafRequest)
if !ok {
return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req)
}
// This channel watches our overall timeout. The other goroutines
// launched in this function should end all around the same time so
// they clean themselves up.
timeoutCh := time.After(opts.Timeout)
// Kick off the goroutine that waits for new CA roots. The channel buffer
// is so that the goroutine doesn't block forever if we return for other
// reasons.
newRootCACh := make(chan error, 1)
go c.waitNewRootCA(reqReal.Datacenter, newRootCACh, opts.Timeout)
// Get our prior cert (if we had one) and use that to determine our
// expiration time. If no cert exists, we expire immediately since we
// need to generate one.
c.issuedCertsLock.RLock()
lastCert := c.issuedCerts[reqReal.Service]
c.issuedCertsLock.RUnlock()
var leafExpiryCh <-chan time.Time
if lastCert != nil {
// Determine how long we wait until triggering. If we've already
// expired, we trigger immediately.
if expiryDur := lastCert.ValidBefore.Sub(time.Now()); expiryDur > 0 {
leafExpiryCh = time.After(expiryDur - 1*time.Hour)
// TODO(mitchellh): 1 hour buffer is hardcoded above
}
}
if leafExpiryCh == nil {
// If the channel is still nil then it means we need to generate
// a cert no matter what: we either don't have an existing one or
// it is expired.
leafExpiryCh = time.After(0)
}
// Block on the events that wake us up.
select {
case <-timeoutCh:
// On a timeout, we just return the empty result and no error.
// It isn't an error to time out; it's just the limit of time the
// caching system wants us to block for. By returning an empty result
// the caching system will ignore it.
return result, nil
case err := <-newRootCACh:
// A new root CA triggers us to refresh the leaf certificate.
// If there was an error while getting the root CA then we return.
// Otherwise, we leave the select statement and move to generation.
if err != nil {
return result, err
}
case <-leafExpiryCh:
// The existing leaf certificate is expiring soon, so we generate a
// new cert with a healthy overlapping validity period (determined
// by the above channel).
}
// Need to look up the RootCAs response to discover the trust domain. First
// just look it up with no blocking info - this should be a cache hit most
// of the time.
rawRoots, _, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{
Datacenter: reqReal.Datacenter,
})
if err != nil {
return result, err
}
roots, ok := rawRoots.(*structs.IndexedCARoots)
if !ok {
return result, errors.New("invalid RootCA response type")
}
if roots.TrustDomain == "" {
return result, errors.New("cluster has no CA bootstrapped")
}
// Build the service ID
serviceID := &connect.SpiffeIDService{
Host: roots.TrustDomain,
Datacenter: reqReal.Datacenter,
Namespace: "default",
Service: reqReal.Service,
}
// Create a new private key
pk, pkPEM, err := connect.GeneratePrivateKey()
if err != nil {
return result, err
}
// Create a CSR.
csr, err := connect.CreateCSR(serviceID, pk)
if err != nil {
return result, err
}
// Request signing
var reply structs.IssuedCert
args := structs.CASignRequest{
WriteRequest: structs.WriteRequest{Token: reqReal.Token},
Datacenter: reqReal.Datacenter,
CSR: csr,
}
if err := c.RPC.RPC("ConnectCA.Sign", &args, &reply); err != nil {
return result, err
}
reply.PrivateKeyPEM = pkPEM
// Lock the issued certs map so we can insert it. We only insert if
// we didn't happen to get a newer one. This should never happen since
// the Cache should ensure only one Fetch per service, but we sanity
// check just in case.
c.issuedCertsLock.Lock()
defer c.issuedCertsLock.Unlock()
lastCert = c.issuedCerts[reqReal.Service]
if lastCert == nil || lastCert.ModifyIndex < reply.ModifyIndex {
if c.issuedCerts == nil {
c.issuedCerts = make(map[string]*structs.IssuedCert)
}
c.issuedCerts[reqReal.Service] = &reply
lastCert = &reply
}
result.Value = lastCert
result.Index = lastCert.ModifyIndex
return result, nil
}
// waitNewRootCA blocks until a new root CA is available or the timeout is
// reached (on timeout ErrTimeout is returned on the channel).
func (c *ConnectCALeaf) waitNewRootCA(datacenter string, ch chan<- error,
timeout time.Duration) {
// We always want to block on at least an initial value. If we have never
// seen a CA index yet, use 1 so the query below blocks rather than
// returning immediately.
minIndex := atomic.LoadUint64(&c.caIndex)
if minIndex == 0 {
minIndex = 1
}
// Fetch some new roots. This will block until our MinQueryIndex is
// matched or the timeout is reached.
rawRoots, _, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{
Datacenter: datacenter,
QueryOptions: structs.QueryOptions{
MinQueryIndex: minIndex,
MaxQueryTime: timeout,
},
})
if err != nil {
ch <- err
return
}
roots, ok := rawRoots.(*structs.IndexedCARoots)
if !ok {
// This should never happen but we don't want to even risk a panic
ch <- fmt.Errorf(
"internal error: CA root cache returned bad type: %T", rawRoots)
return
}
// We do a loop here because there can be multiple waitNewRootCA calls
// happening simultaneously. Each Fetch kicks off one call. These are
// multiplexed through Cache.Get which should ensure we only ever
// actually make a single RPC call. However, there is a race to set
// the caIndex field so do a basic CAS loop here.
for {
// We only set our index if it's newer than what was previously set.
old := atomic.LoadUint64(&c.caIndex)
if old >= roots.Index {
break
}
// Set the new index atomically. If the caIndex value changed
// in the meantime, retry.
if atomic.CompareAndSwapUint64(&c.caIndex, old, roots.Index) {
break
}
}
// Trigger the channel since we updated.
ch <- nil
}
// ConnectCALeafRequest is the cache.Request implementation for the
// ConnectCALeaf cache type. This is implemented here and not in structs
// since this is only used for cache-related requests and not forwarded
// directly to any Consul servers.
type ConnectCALeafRequest struct {
Token string
Datacenter string
Service string // Service name, not ID
MinQueryIndex uint64
}
func (r *ConnectCALeafRequest) CacheInfo() cache.RequestInfo {
return cache.RequestInfo{
Token: r.Token,
Key: r.Service,
Datacenter: r.Datacenter,
MinIndex: r.MinQueryIndex,
}
}
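// Example wiring, as a sketch ("rpcClient" stands for any implementation of
// the RPC interface, e.g. the agent's delegate; field values are invented):
//
//	c := cache.New(nil)
//	c.RegisterType(ConnectCARootName, &ConnectCARoot{RPC: rpcClient}, &cache.RegisterOptions{
//		Refresh:        true,
//		RefreshTimeout: 10 * time.Minute,
//	})
//	c.RegisterType(ConnectCALeafName, &ConnectCALeaf{RPC: rpcClient, Cache: c}, nil)
//
//	raw, _, err := c.Get(ConnectCALeafName, &ConnectCALeafRequest{
//		Datacenter: "dc1",
//		Service:    "web",
//	})
//	if err == nil {
//		cert := raw.(*structs.IssuedCert)
//		_ = cert.PrivateKeyPEM // key generated locally, never sent to servers
//	}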


@ -0,0 +1,209 @@
package cachetype
import (
"fmt"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// Test that after an initial signing, new CA roots (new ID) will
// trigger a blocking query to execute.
func TestConnectCALeaf_changingRoots(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ, rootsCh := testCALeafType(t, rpc)
defer close(rootsCh)
rootsCh <- structs.IndexedCARoots{
ActiveRootID: "1",
TrustDomain: "fake-trust-domain.consul",
QueryMeta: structs.QueryMeta{Index: 1},
}
// Instrument ConnectCA.Sign to return signed cert
var resp *structs.IssuedCert
var idx uint64
rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
reply := args.Get(2).(*structs.IssuedCert)
reply.ValidBefore = time.Now().Add(12 * time.Hour)
reply.CreateIndex = atomic.AddUint64(&idx, 1)
reply.ModifyIndex = reply.CreateIndex
resp = reply
})
// We'll reuse the fetch options and request
opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second}
req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"}
// First fetch should return immediately
fetchCh := TestFetchCh(t, typ, opts, req)
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
require.Equal(cache.FetchResult{
Value: resp,
Index: 1,
}, result)
}
// Second fetch should block with set index
fetchCh = TestFetchCh(t, typ, opts, req)
select {
case result := <-fetchCh:
t.Fatalf("should not return: %#v", result)
case <-time.After(100 * time.Millisecond):
}
// Let's send in new roots, which should trigger the sign req
rootsCh <- structs.IndexedCARoots{
ActiveRootID: "2",
TrustDomain: "fake-trust-domain.consul",
QueryMeta: structs.QueryMeta{Index: 2},
}
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
require.Equal(cache.FetchResult{
Value: resp,
Index: 2,
}, result)
}
// Third fetch should block
fetchCh = TestFetchCh(t, typ, opts, req)
select {
case result := <-fetchCh:
t.Fatalf("should not return: %#v", result)
case <-time.After(100 * time.Millisecond):
}
}
// Test that after an initial signing, an expiring leaf will trigger a
// blocking query to re-sign.
func TestConnectCALeaf_expiringLeaf(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ, rootsCh := testCALeafType(t, rpc)
defer close(rootsCh)
rootsCh <- structs.IndexedCARoots{
ActiveRootID: "1",
TrustDomain: "fake-trust-domain.consul",
QueryMeta: structs.QueryMeta{Index: 1},
}
// Instrument ConnectCA.Sign to return a signed cert whose expiry we control
var resp *structs.IssuedCert
var idx uint64
rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
reply := args.Get(2).(*structs.IssuedCert)
reply.CreateIndex = atomic.AddUint64(&idx, 1)
reply.ModifyIndex = reply.CreateIndex
// This sets the validity to 0 on the first call, and
// 12 hours+ on subsequent calls. This means that our first
// cert expires immediately.
reply.ValidBefore = time.Now().Add((12 * time.Hour) *
time.Duration(reply.CreateIndex-1))
resp = reply
})
// We'll reuse the fetch options and request
opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second}
req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"}
// First fetch should return immediately
fetchCh := TestFetchCh(t, typ, opts, req)
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
require.Equal(cache.FetchResult{
Value: resp,
Index: 1,
}, result)
}
// Second fetch should return immediately despite there being
// no updated CA roots, because we issued an expired cert.
fetchCh = TestFetchCh(t, typ, opts, req)
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
require.Equal(cache.FetchResult{
Value: resp,
Index: 2,
}, result)
}
// Third fetch should block since the cert is not expiring and
// we also didn't update CA certs.
fetchCh = TestFetchCh(t, typ, opts, req)
select {
case result := <-fetchCh:
t.Fatalf("should not return: %#v", result)
case <-time.After(100 * time.Millisecond):
}
}
// testCALeafType returns a *ConnectCALeaf that is pre-configured to
// use the given RPC implementation for "ConnectCA.Sign" operations.
func testCALeafType(t *testing.T, rpc RPC) (*ConnectCALeaf, chan structs.IndexedCARoots) {
// This creates an RPC implementation that will block until the
// value is sent on the channel. This lets us control when the
// next values show up.
rootsCh := make(chan structs.IndexedCARoots, 10)
rootsRPC := &testGatedRootsRPC{ValueCh: rootsCh}
// Create a cache
c := cache.TestCache(t)
c.RegisterType(ConnectCARootName, &ConnectCARoot{RPC: rootsRPC}, &cache.RegisterOptions{
// Disable refresh so that the gated channel controls the
// request directly. Otherwise, we get background refreshes and
// it screws up the ordering of the channel reads of the
// testGatedRootsRPC implementation.
Refresh: false,
})
// Create the leaf type
return &ConnectCALeaf{RPC: rpc, Cache: c}, rootsCh
}
// testGatedRootsRPC will send each subsequent value on the channel as the
// RPC response, blocking if it is waiting for a value on the channel. This
// can be used to control when background fetches are returned and what they
// return.
//
// This should be used with Refresh = false for the registration options so
// automatic refreshes don't mess up the channel read ordering.
type testGatedRootsRPC struct {
ValueCh chan structs.IndexedCARoots
}
func (r *testGatedRootsRPC) RPC(method string, args interface{}, reply interface{}) error {
if method != "ConnectCA.Roots" {
return fmt.Errorf("invalid RPC method: %s", method)
}
replyReal := reply.(*structs.IndexedCARoots)
*replyReal = <-r.ValueCh
return nil
}


@ -0,0 +1,43 @@
package cachetype
import (
"fmt"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
)
// Recommended name for registration.
const ConnectCARootName = "connect-ca-root"
// ConnectCARoot supports fetching the Connect CA roots. This is a
// straightforward cache type since it only has to block on the given
// index and return the data.
type ConnectCARoot struct {
RPC RPC
}
func (c *ConnectCARoot) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
var result cache.FetchResult
// The request should be a DCSpecificRequest.
reqReal, ok := req.(*structs.DCSpecificRequest)
if !ok {
return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req)
}
// Set the minimum query index to our current index so we block
reqReal.QueryOptions.MinQueryIndex = opts.MinIndex
reqReal.QueryOptions.MaxQueryTime = opts.Timeout
// Fetch
var reply structs.IndexedCARoots
if err := c.RPC.RPC("ConnectCA.Roots", reqReal, &reply); err != nil {
return result, err
}
result.Value = &reply
result.Index = reply.QueryMeta.Index
return result, nil
}


@ -0,0 +1,57 @@
package cachetype
import (
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestConnectCARoot(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &ConnectCARoot{RPC: rpc}
// Expect the proper RPC call. This also sets the expected value
// since that is return-by-pointer in the arguments.
var resp *structs.IndexedCARoots
rpc.On("RPC", "ConnectCA.Roots", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.DCSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
reply := args.Get(2).(*structs.IndexedCARoots)
reply.QueryMeta.Index = 48
resp = reply
})
// Fetch
result, err := typ.Fetch(cache.FetchOptions{
MinIndex: 24,
Timeout: 1 * time.Second,
}, &structs.DCSpecificRequest{Datacenter: "dc1"})
require.Nil(err)
require.Equal(cache.FetchResult{
Value: resp,
Index: 48,
}, result)
}
func TestConnectCARoot_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &ConnectCARoot{RPC: rpc}
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.NotNil(err)
require.Contains(err.Error(), "wrong type")
}


@ -0,0 +1,41 @@
package cachetype
import (
"fmt"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
)
// Recommended name for registration.
const IntentionMatchName = "intention-match"
// IntentionMatch supports fetching the intentions via match queries.
type IntentionMatch struct {
RPC RPC
}
func (c *IntentionMatch) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
var result cache.FetchResult
// The request should be an IntentionQueryRequest.
reqReal, ok := req.(*structs.IntentionQueryRequest)
if !ok {
return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req)
}
// Set the minimum query index to our current index so we block
reqReal.MinQueryIndex = opts.MinIndex
reqReal.MaxQueryTime = opts.Timeout
// Fetch
var reply structs.IndexedIntentionMatches
if err := c.RPC.RPC("Intention.Match", reqReal, &reply); err != nil {
return result, err
}
result.Value = &reply
result.Index = reply.Index
return result, nil
}


@ -0,0 +1,57 @@
package cachetype
import (
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestIntentionMatch(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &IntentionMatch{RPC: rpc}
// Expect the proper RPC call. This also sets the expected value
// since that is return-by-pointer in the arguments.
var resp *structs.IndexedIntentionMatches
rpc.On("RPC", "Intention.Match", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.IntentionQueryRequest)
require.Equal(uint64(24), req.MinQueryIndex)
require.Equal(1*time.Second, req.MaxQueryTime)
reply := args.Get(2).(*structs.IndexedIntentionMatches)
reply.Index = 48
resp = reply
})
// Fetch
result, err := typ.Fetch(cache.FetchOptions{
MinIndex: 24,
Timeout: 1 * time.Second,
}, &structs.IntentionQueryRequest{Datacenter: "dc1"})
require.NoError(err)
require.Equal(cache.FetchResult{
Value: resp,
Index: 48,
}, result)
}
func TestIntentionMatch_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &IntentionMatch{RPC: rpc}
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err)
require.Contains(err.Error(), "wrong type")
}


@ -0,0 +1,23 @@
// Code generated by mockery v1.0.0
package cachetype
import mock "github.com/stretchr/testify/mock"
// MockRPC is an autogenerated mock type for the RPC type
type MockRPC struct {
mock.Mock
}
// RPC provides a mock function with given fields: method, args, reply
func (_m *MockRPC) RPC(method string, args interface{}, reply interface{}) error {
ret := _m.Called(method, args, reply)
var r0 error
if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok {
r0 = rf(method, args, reply)
} else {
r0 = ret.Error(0)
}
return r0
}

10
agent/cache-types/rpc.go Normal file

@ -0,0 +1,10 @@
package cachetype
//go:generate mockery -all -inpkg
// RPC is an interface that an RPC client must implement. This is a helper
// interface that is implemented by the agent delegate so that Type
// implementations can request RPC access.
type RPC interface {
RPC(method string, args interface{}, reply interface{}) error
}
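// As a sketch of how this is satisfied in practice: the agent's Consul
// client and server delegates already expose a method with this exact
// signature, so a cache type can be constructed directly against either,
// e.g. &ConnectCARoot{RPC: delegate}.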


@ -0,0 +1,60 @@
package cachetype
import (
"reflect"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/mitchellh/go-testing-interface"
)
// TestRPC returns a mock implementation of the RPC interface.
func TestRPC(t testing.T) *MockRPC {
// This function does little right now, but it gives us a place to
// perform shared initialization later.
return &MockRPC{}
}
// TestFetchCh returns a channel that returns the result of the Fetch call.
// This is useful for testing timing and concurrency with Fetch calls.
// Errors will show up as an error type on the resulting channel so a
// type switch should be used.
func TestFetchCh(
t testing.T,
typ cache.Type,
opts cache.FetchOptions,
req cache.Request) <-chan interface{} {
resultCh := make(chan interface{})
go func() {
result, err := typ.Fetch(opts, req)
if err != nil {
resultCh <- err
return
}
resultCh <- result
}()
return resultCh
}
// TestFetchChResult tests that the result from TestFetchCh matches
// within a reasonable period of time (it expects it to be "immediate" but
// waits some milliseconds).
func TestFetchChResult(t testing.T, ch <-chan interface{}, expected interface{}) {
t.Helper()
select {
case result := <-ch:
if err, ok := result.(error); ok {
t.Fatalf("Result was error: %s", err)
return
}
if !reflect.DeepEqual(result, expected) {
t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected)
}
case <-time.After(50 * time.Millisecond):
t.Fatalf("timed out waiting for fetch result")
}
}

536
agent/cache/cache.go vendored Normal file

@ -0,0 +1,536 @@
// Package cache provides caching features for data from a Consul server.
//
// While this is similar in some ways to the "agent/ae" package, a key
// difference is that with anti-entropy, the agent is the authoritative
// source so it resolves differences the server may have. With caching (this
// package), the server is the authoritative source and we do our best to
// balance performance and correctness, depending on the type of data being
// requested.
//
// The types of data that can be cached are configurable via the Type interface.
// This allows specialized behavior for certain types of data. Each type of
// Consul data (CA roots, leaf certs, intentions, KV, catalog, etc.) will
// have to be manually implemented. This usually is not much work, see
// the "agent/cache-types" package.
package cache
import (
"container/heap"
"fmt"
"sync"
"time"
"github.com/armon/go-metrics"
)
//go:generate mockery -all -inpkg
// Constants related to refresh backoff. We probably don't ever need to
// make these configurable knobs since they primarily exist to lower load.
const (
CacheRefreshBackoffMin = 3 // 3 attempts before backing off
CacheRefreshMaxWait = 1 * time.Minute // maximum backoff wait time
)
// Cache is an agent-local cache of Consul data. Create a Cache using the
// New function. A zero-value Cache is not ready for usage and will result
// in a panic.
//
// The types of data to be cached must be registered via RegisterType. Then,
// calls to Get specify the type and a Request implementation. The
// implementation of Request is usually done directly on the standard RPC
// struct in agent/structs. This API makes cache usage a mostly drop-in
// replacement for non-cached RPC calls.
//
// The cache is partitioned by ACL and datacenter. This allows the cache
// to be safe for multi-DC queries and for queries where the data is modified
// due to ACLs all without the cache having to have any clever logic, at
// the slight expense of a less perfect cache.
//
// The Cache exposes various metrics via go-metrics. Please view the source
// searching for "metrics." to see the various metrics exposed. These can be
// used to explore the performance of the cache.
type Cache struct {
// types stores the list of data types that the cache knows how to service.
// These can be dynamically registered with RegisterType.
typesLock sync.RWMutex
types map[string]typeEntry
// entries contains the actual cache data. Access to entries and
// entriesExpiryHeap must be protected by entriesLock.
//
// entriesExpiryHeap is a heap of *cacheEntry values ordered by
// expiry, with the soonest to expire being first in the list (index 0).
//
// NOTE(mitchellh): The entry map key is currently a string in the format
// of "<DC>/<ACL token>/<Request key>" in order to properly partition
// requests to different datacenters and ACL tokens. This format has some
// big drawbacks: we can't evict by datacenter, ACL token, etc. For an
// initial implementation this works and the tests are agnostic to the
// internal storage format so changing this should be possible safely.
entriesLock sync.RWMutex
entries map[string]cacheEntry
entriesExpiryHeap *expiryHeap
}
// typeEntry is a single type that is registered with a Cache.
type typeEntry struct {
Type Type
Opts *RegisterOptions
}
// ResultMeta is returned from Get calls along with the value and can be used
// to expose information about the cache status for debugging or testing.
type ResultMeta struct {
// Hit indicates whether or not the request was a cache hit
Hit bool
}
// Options are options for the Cache.
type Options struct {
// Nothing currently, reserved.
}
// New creates a new cache with the given options and reasonable defaults.
// Further settings can be tweaked on the returned value.
func New(*Options) *Cache {
// Initialize the heap. The buffer of 1 is really important because
// it's possible for the expiry loop to trigger the heap to update
// itself and it'd block forever otherwise.
h := &expiryHeap{NotifyCh: make(chan struct{}, 1)}
heap.Init(h)
c := &Cache{
types: make(map[string]typeEntry),
entries: make(map[string]cacheEntry),
entriesExpiryHeap: h,
}
// Start the expiry watcher
go c.runExpiryLoop()
return c
}
// RegisterOptions are options that can be associated with a type being
// registered for the cache. This changes the behavior of the cache for
// this type.
type RegisterOptions struct {
// LastGetTTL is the time that the values returned by this type remain
// in the cache after the last get operation. If a value isn't accessed
// within this duration, the value is purged from the cache and
// background refreshing will cease.
LastGetTTL time.Duration
// Refresh configures whether the data is actively refreshed or if
// the data is only refreshed on an explicit Get. The default (false)
// is to only request data on explicit Get.
Refresh bool
// RefreshTimer is the time between attempting to refresh data.
// If this is zero, then data is refreshed immediately when a fetch
// is returned.
//
// RefreshTimeout determines the maximum query time for a refresh
// operation. This is specified as part of the query options and is
// expected to be implemented by the Type itself.
//
// Using these values, various "refresh" mechanisms can be implemented:
//
// * With a high timer duration and a low timeout, a timer-based
// refresh can be set that minimizes load on the Consul servers.
//
// * With a low timer and high timeout duration, a blocking-query-based
// refresh can be set so that changes in server data are recognized
// within the cache very quickly.
//
RefreshTimer time.Duration
RefreshTimeout time.Duration
}
// RegisterType registers a cacheable type.
//
// This makes the type available for Get but does not automatically perform
// any prefetching. In order to populate the cache, Get must be called.
func (c *Cache) RegisterType(n string, typ Type, opts *RegisterOptions) {
if opts == nil {
opts = &RegisterOptions{}
}
if opts.LastGetTTL == 0 {
opts.LastGetTTL = 72 * time.Hour // reasonable default is days
}
c.typesLock.Lock()
defer c.typesLock.Unlock()
c.types[n] = typeEntry{Type: typ, Opts: opts}
}
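// For example (illustrative values, written from a caller's perspective),
// the two refresh styles described on RegisterOptions look like:
//
//	// Timer-based: poll at most once per minute with short queries,
//	// minimizing load on the Consul servers.
//	c.RegisterType("example-timer", myType, &cache.RegisterOptions{
//		Refresh:        true,
//		RefreshTimer:   1 * time.Minute,
//		RefreshTimeout: 1 * time.Second,
//	})
//
//	// Blocking-query-based: re-issue long blocking queries immediately so
//	// server-side changes show up in the cache quickly.
//	c.RegisterType("example-blocking", myType, &cache.RegisterOptions{
//		Refresh:        true,
//		RefreshTimer:   0,
//		RefreshTimeout: 10 * time.Minute,
//	})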
// Get loads the data for the given type and request. If data satisfying the
// minimum index is present in the cache, it is returned immediately. Otherwise,
// this will block until the data is available or the request timeout is
// reached.
//
// Multiple Get calls for the same Request (matching CacheInfo Key value)
// will block on a single network request.
//
// The timeout specified by the Request will be the timeout on the cache
// Get, and does not correspond to the timeout of any background data
// fetching. If the timeout is reached before data satisfying the minimum
// index is retrieved, the last known value (maybe nil) is returned. No
// error is returned on timeout. This matches the behavior of Consul blocking
// queries.
func (c *Cache) Get(t string, r Request) (interface{}, ResultMeta, error) {
info := r.CacheInfo()
if info.Key == "" {
metrics.IncrCounter([]string{"consul", "cache", "bypass"}, 1)
// If no key is specified, then we do not cache this request.
// Pass directly through to the backend.
return c.fetchDirect(t, r)
}
// Get the actual key for our entry
key := c.entryKey(&info)
// First time through
first := true
// timeoutCh for watching our timeout
var timeoutCh <-chan time.Time
RETRY_GET:
// Get the current value
c.entriesLock.RLock()
entry, ok := c.entries[key]
c.entriesLock.RUnlock()
// If we have a current value and its stored index is greater than the
// requested minimum index then we return it right away. If the requested
// index is zero and we have something in the cache we accept whatever
// we have.
if ok && entry.Valid {
if info.MinIndex == 0 || info.MinIndex < entry.Index {
meta := ResultMeta{}
if first {
metrics.IncrCounter([]string{"consul", "cache", t, "hit"}, 1)
meta.Hit = true
}
// Touch the expiration and fix the heap.
c.entriesLock.Lock()
entry.Expiry.Reset()
c.entriesExpiryHeap.Fix(entry.Expiry)
c.entriesLock.Unlock()
// We purposely do not return an error here since the cache
// only works with fetching values that either have a value
// or have an error, but not both. The entry's Error may still be
// non-nil; it records errors from later fetches.
return entry.Value, meta, nil
}
}
// If this isn't our first time through and our last value has an error,
// then we return the error. This has the behavior that we don't sit in
// a retry loop getting the same error for the entire duration of the
// timeout. Instead, we make one effort to fetch a new value, and if
// there was an error, we return.
if !first && entry.Error != nil {
return entry.Value, ResultMeta{}, entry.Error
}
if first {
// We increment two different counters for cache misses depending on
// whether we're missing because we didn't have the data at all,
// or if we're missing because we're blocking on a set index.
if info.MinIndex == 0 {
metrics.IncrCounter([]string{"consul", "cache", t, "miss_new"}, 1)
} else {
metrics.IncrCounter([]string{"consul", "cache", t, "miss_block"}, 1)
}
}
// No longer our first time through
first = false
// Set our timeout channel if we must
if info.Timeout > 0 && timeoutCh == nil {
timeoutCh = time.After(info.Timeout)
}
// At this point, we know we either don't have a value at all or the
// value we have is too old. We need to wait for new data.
waiterCh, err := c.fetch(t, key, r, true, 0)
if err != nil {
return nil, ResultMeta{}, err
}
select {
case <-waiterCh:
// Our fetch returned, retry the get from the cache
goto RETRY_GET
case <-timeoutCh:
// Timeout on the cache read, just return whatever we have.
return entry.Value, ResultMeta{}, nil
}
}
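// As a usage sketch (hypothetical caller; any request type implementing
// Request works the same way), a blocking watch loop built on Get looks
// like:
//
//	var index uint64
//	for {
//		raw, _, err := c.Get("connect-ca-root", &structs.DCSpecificRequest{
//			Datacenter: "dc1",
//			QueryOptions: structs.QueryOptions{
//				MinQueryIndex: index,
//				MaxQueryTime:  10 * time.Minute,
//			},
//		})
//		if err != nil {
//			break
//		}
//		roots := raw.(*structs.IndexedCARoots)
//		index = roots.QueryMeta.Index // block for anything newer next time
//		// react to the updated value here
//	}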
// entryKey returns the key for the entry in the cache. See the note
// about the entry key format in the structure docs for Cache.
func (c *Cache) entryKey(r *RequestInfo) string {
return fmt.Sprintf("%s/%s/%s", r.Datacenter, r.Token, r.Key)
}
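// For example, a request with Datacenter "dc1", Token "abc123", and Key
// "leaf-web" is stored under the entry key "dc1/abc123/leaf-web".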
// fetch triggers a new background fetch for the given Request. If a
// background fetch is already running for a matching Request, the waiter
// channel for that request is returned. The effect of this is that there
// is only ever one blocking query for any matching requests.
//
// If allowNew is true then the fetch should create the cache entry
// if it doesn't exist. If this is false, then fetch will do nothing
// if the entry doesn't exist. This latter case is to support refreshing.
func (c *Cache) fetch(t, key string, r Request, allowNew bool, attempt uint) (<-chan struct{}, error) {
// Get the type that we're fetching
c.typesLock.RLock()
tEntry, ok := c.types[t]
c.typesLock.RUnlock()
if !ok {
return nil, fmt.Errorf("unknown type in cache: %s", t)
}
// We acquire a write lock because we may have to set Fetching to true.
c.entriesLock.Lock()
defer c.entriesLock.Unlock()
entry, ok := c.entries[key]
// If we aren't allowing new values and we don't have an existing value,
// return immediately. We return an immediately-closed channel so nothing
// blocks.
if !ok && !allowNew {
ch := make(chan struct{})
close(ch)
return ch, nil
}
// If we already have an entry and it is actively fetching, then return
// the currently active waiter.
if ok && entry.Fetching {
return entry.Waiter, nil
}
// If we don't have an entry, then create it. The entry must be marked
// as invalid so that it isn't returned as a valid value for a zero index.
if !ok {
entry = cacheEntry{Valid: false, Waiter: make(chan struct{})}
}
// Set that we're fetching to true, which makes it so that future
// identical calls to fetch will return the same waiter rather than
// perform multiple fetches.
entry.Fetching = true
c.entries[key] = entry
metrics.SetGauge([]string{"consul", "cache", "entries_count"}, float32(len(c.entries)))
// The actual Fetch must be performed in a goroutine.
go func() {
// Start building the new entry by blocking on the fetch.
result, err := tEntry.Type.Fetch(FetchOptions{
MinIndex: entry.Index,
Timeout: tEntry.Opts.RefreshTimeout,
}, r)
// Copy the existing entry to start.
newEntry := entry
newEntry.Fetching = false
if result.Value != nil {
// A new value was given, so we create a brand new entry.
newEntry.Value = result.Value
newEntry.Index = result.Index
if newEntry.Index < 1 {
// Less than one is invalid unless there was an error and in this case
// there wasn't since a value was returned. If a badly behaved RPC
// returns 0 when it has no data, we might get into a busy loop here. We
// set this to a minimum of 1, which is safe because no valid user data can
// ever be written at raft index 1 due to the bootstrap process for
// raft. This ensures that any subsequent background refresh request will
// always block, but allows the initial request to return immediately
// even if there is no data.
newEntry.Index = 1
}
// This is a valid entry with a result
newEntry.Valid = true
}
// Error handling
if err == nil {
metrics.IncrCounter([]string{"consul", "cache", "fetch_success"}, 1)
metrics.IncrCounter([]string{"consul", "cache", t, "fetch_success"}, 1)
if result.Index > 0 {
// Reset the attempts counter so we don't have any backoff
attempt = 0
} else {
// Result having a zero index is an implicit error case. There was no
// actual error, but it implies the RPC found no data for that type
// (nothing written yet) and didn't take care to return a safe index of
// 1. We don't want to actually treat it like an error by setting
// newEntry.Error to something non-nil, but we should guard against 100%
// CPU burn hot loops caused by that case, which will never block but
// also won't back off either. So we treat it as a failed attempt so that
// at least the failure backoff will save our CPU while still
// periodically refreshing so normal service can resume when the servers
// actually have something to return from the RPC. If we ever get in this
// state it can be considered a bug in the RPC implementation (to ever
// return a zero index); however, since it can happen, this is a safety
// net for the future.
attempt++
}
} else {
metrics.IncrCounter([]string{"consul", "cache", "fetch_error"}, 1)
metrics.IncrCounter([]string{"consul", "cache", t, "fetch_error"}, 1)
// Increment attempt counter
attempt++
// Always set the error. We don't override the value here because
// if Valid is true, then we can reuse the Value in the case a
// specific index isn't requested. However, for blocking queries,
// we want Error to be set so that we can return early with the
// error.
newEntry.Error = err
}
// Create a new waiter that will be used for the next fetch.
newEntry.Waiter = make(chan struct{})
// Set our entry
c.entriesLock.Lock()
// If this is a new entry (not in the heap yet), then setup the
// initial expiry information and insert. If we're already in
// the heap we do nothing since we're reusing the same entry.
if newEntry.Expiry == nil || newEntry.Expiry.HeapIndex == -1 {
newEntry.Expiry = &cacheEntryExpiry{
Key: key,
TTL: tEntry.Opts.LastGetTTL,
}
newEntry.Expiry.Reset()
heap.Push(c.entriesExpiryHeap, newEntry.Expiry)
}
c.entries[key] = newEntry
c.entriesLock.Unlock()
// Trigger the old waiter
close(entry.Waiter)
// If refresh is enabled, run the refresh in due time. The refresh
// below might block, but saves us from spawning another goroutine.
if tEntry.Opts.Refresh {
c.refresh(tEntry.Opts, attempt, t, key, r)
}
}()
return entry.Waiter, nil
}
// fetchDirect fetches the given request with no caching. Because this
// bypasses the caching entirely, multiple matching requests will result
// in multiple actual RPC calls (unlike fetch).
func (c *Cache) fetchDirect(t string, r Request) (interface{}, ResultMeta, error) {
// Get the type that we're fetching
c.typesLock.RLock()
tEntry, ok := c.types[t]
c.typesLock.RUnlock()
if !ok {
return nil, ResultMeta{}, fmt.Errorf("unknown type in cache: %s", t)
}
// Fetch it with the min index specified directly by the request.
result, err := tEntry.Type.Fetch(FetchOptions{
MinIndex: r.CacheInfo().MinIndex,
}, r)
if err != nil {
return nil, ResultMeta{}, err
}
// Return the result and ignore the rest
return result.Value, ResultMeta{}, nil
}
// refresh triggers a fetch for a specific Request according to the
// registration options.
func (c *Cache) refresh(opts *RegisterOptions, attempt uint, t string, key string, r Request) {
// Sanity check: we should not schedule anything that has refresh disabled
if !opts.Refresh {
return
}
// If we're over the attempt minimum, start an exponential backoff.
if attempt > CacheRefreshBackoffMin {
waitTime := (1 << (attempt - CacheRefreshBackoffMin)) * time.Second
if waitTime > CacheRefreshMaxWait {
waitTime = CacheRefreshMaxWait
}
time.Sleep(waitTime)
}
// If we have a timer, wait for it
if opts.RefreshTimer > 0 {
time.Sleep(opts.RefreshTimer)
}
// Trigger. The "allowNew" parameter is false because in the time we were
// waiting to refresh we may have expired and been evicted. If that
// happened, we don't want to create a new entry.
c.fetch(t, key, r, false, attempt)
}
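
For a concrete feel of the schedule this backoff produces, here is a minimal standalone sketch. It assumes the package constants are CacheRefreshBackoffMin = 3 and CacheRefreshMaxWait = 1 * time.Minute; those values are assumptions here, since the constants themselves are not part of this hunk.

package main

import (
	"fmt"
	"time"
)

// Assumed stand-ins for CacheRefreshBackoffMin and CacheRefreshMaxWait.
const (
	backoffMin = 3
	maxWait    = 1 * time.Minute
)

func main() {
	for attempt := uint(1); attempt <= 9; attempt++ {
		var wait time.Duration
		if attempt > backoffMin {
			wait = (1 << (attempt - backoffMin)) * time.Second
			if wait > maxWait {
				wait = maxWait
			}
		}
		// Attempts 1-3 wait 0s, then 2s, 4s, 8s, ... capped at 1m.
		fmt.Printf("attempt %d -> wait %s\n", attempt, wait)
	}
}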
// runExpiryLoop is a blocking function that watches the expiration
// heap and invalidates entries that have expired.
func (c *Cache) runExpiryLoop() {
var expiryTimer *time.Timer
for {
// If we have a previous timer, stop it.
if expiryTimer != nil {
expiryTimer.Stop()
}
// Get the entry expiring soonest
var entry *cacheEntryExpiry
var expiryCh <-chan time.Time
c.entriesLock.RLock()
if len(c.entriesExpiryHeap.Entries) > 0 {
entry = c.entriesExpiryHeap.Entries[0]
expiryTimer = time.NewTimer(entry.Expires.Sub(time.Now()))
expiryCh = expiryTimer.C
}
c.entriesLock.RUnlock()
select {
case <-c.entriesExpiryHeap.NotifyCh:
// Entries changed, so the heap may have changed. Restart loop.
case <-expiryCh:
c.entriesLock.Lock()
// Entry expired! Remove it.
delete(c.entries, entry.Key)
heap.Remove(c.entriesExpiryHeap, entry.HeapIndex)
// This is subtle but important: if we race and simultaneously
// evict and fetch a new value, then we set this to -1 to
// have it treated as a new value so that the TTL is extended.
entry.HeapIndex = -1
// Set some metrics
metrics.IncrCounter([]string{"consul", "cache", "evict_expired"}, 1)
metrics.SetGauge([]string{"consul", "cache", "entries_count"}, float32(len(c.entries)))
c.entriesLock.Unlock()
}
}
}
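
The Waiter handoff in fetch above (allocate a fresh channel, publish the new entry, then close the old Waiter) is the standard close-to-broadcast pattern: closing a channel wakes every goroutine blocked on it at once, while the next generation of waiters blocks on the already-installed replacement. A tiny self-contained sketch of just that pattern, with illustrative names:

package main

import (
	"fmt"
	"sync"
)

func main() {
	waiter := make(chan struct{})

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			<-waiter // blocks until the "fetch" completes
			fmt.Printf("waiter %d woke up\n", id)
		}(i)
	}

	close(waiter) // broadcast: wakes every blocked goroutine at once
	wg.Wait()
}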

760
agent/cache/cache_test.go vendored Normal file

@@ -0,0 +1,760 @@
package cache
import (
"fmt"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// Test a basic Get with no indexes (and therefore no blocking queries).
func TestCacheGet_noIndex(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
typ.Static(FetchResult{Value: 42}, nil).Times(1)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Get, should not fetch since we already have a satisfying value
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test a basic Get with no index and a failed fetch.
func TestCacheGet_initError(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
fetcherr := fmt.Errorf("error")
typ.Static(FetchResult{}, fetcherr).Times(2)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.Error(err)
require.Nil(result)
require.False(meta.Hit)
// Get, should fetch again since our last fetch was an error
result, meta, err = c.Get("t", req)
require.Error(err)
require.Nil(result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the expected two calls
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test a Get with a request that returns a blank cache key. This should
// force a backend request and skip the cache entirely.
func TestCacheGet_blankCacheKey(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
typ.Static(FetchResult{Value: 42}, nil).Times(2)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: ""})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Get, should not fetch since we already have a satisfying value
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the expected two calls
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test that Get blocks on the initial value
func TestCacheGet_blockingInitSameKey(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
triggerCh := make(chan time.Time)
typ.Static(FetchResult{Value: 42}, nil).WaitUntil(triggerCh).Times(1)
// Perform multiple gets
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
// They should block
select {
case <-getCh1:
t.Fatal("should block (ch1)")
case <-getCh2:
t.Fatal("should block (ch2)")
case <-time.After(50 * time.Millisecond):
}
// Trigger it
close(triggerCh)
// Should return
TestCacheGetChResult(t, getCh1, 42)
TestCacheGetChResult(t, getCh2, 42)
}
// Test that Gets with different cache keys both block on the initial
// value and that both fetches are properly called.
func TestCacheGet_blockingInitDiffKeys(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Keep track of the keys
var keysLock sync.Mutex
var keys []string
// Configure the type
triggerCh := make(chan time.Time)
typ.Static(FetchResult{Value: 42}, nil).
WaitUntil(triggerCh).
Times(2).
Run(func(args mock.Arguments) {
keysLock.Lock()
defer keysLock.Unlock()
keys = append(keys, args.Get(1).(Request).CacheInfo().Key)
})
// Perform multiple gets
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "goodbye"}))
// They should block
select {
case <-getCh1:
t.Fatal("should block (ch1)")
case <-getCh2:
t.Fatal("should block (ch2)")
case <-time.After(50 * time.Millisecond):
}
// Trigger it
close(triggerCh)
// Should return both!
TestCacheGetChResult(t, getCh1, 42)
TestCacheGetChResult(t, getCh2, 42)
// Verify proper keys
sort.Strings(keys)
require.Equal([]string{"goodbye", "hello"}, keys)
}
// Test that a get with an index set will wait until a higher index
// is set in the cache.
func TestCacheGet_blockingIndex(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
triggerCh := make(chan time.Time)
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
typ.Static(FetchResult{Value: 42, Index: 6}, nil).WaitUntil(triggerCh)
// Fetch should block
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 5}))
// Should block
select {
case <-resultCh:
t.Fatal("should block")
case <-time.After(50 * time.Millisecond):
}
// Trigger it
close(triggerCh)
// Should return
TestCacheGetChResult(t, resultCh, 42)
}
// Test that a get with an index set will time out if the fetch doesn't
// return anything.
func TestCacheGet_blockingIndexTimeout(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
triggerCh := make(chan time.Time)
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
typ.Static(FetchResult{Value: 42, Index: 6}, nil).WaitUntil(triggerCh)
// Fetch should block
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 5, Timeout: 200 * time.Millisecond}))
// Should block
select {
case <-resultCh:
t.Fatal("should block")
case <-time.After(50 * time.Millisecond):
}
// Should return once the remainder of the timeout elapses
select {
case result := <-resultCh:
require.Equal(t, 12, result)
case <-time.After(300 * time.Millisecond):
t.Fatal("should've returned")
}
}
// Test that a get with an index set whose fetches return an error
// will return that error.
func TestCacheGet_blockingIndexError(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
var retries uint32
fetchErr := fmt.Errorf("test fetch error")
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: nil, Index: 5}, fetchErr).Run(func(args mock.Arguments) {
atomic.AddUint32(&retries, 1)
})
// First good fetch to populate the cache
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Fetch should not block and should return error
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 7, Timeout: 1 * time.Minute}))
TestCacheGetChResult(t, resultCh, nil)
// Wait a bit
time.Sleep(100 * time.Millisecond)
// Check the number
actual := atomic.LoadUint32(&retries)
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
}
// Test that if a Type returns an empty value on Fetch that the previous
// value is preserved.
func TestCacheGet_emptyFetchResult(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
typ.Static(FetchResult{Value: 42, Index: 1}, nil).Times(1)
typ.Static(FetchResult{Value: nil}, nil)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Get, should not fetch since we already have a satisfying value
req = TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test that a type registered with a periodic refresh will perform
// that refresh after the timer is up.
func TestCacheGet_periodicRefresh(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 100 * time.Millisecond,
RefreshTimeout: 5 * time.Minute,
})
// This is a bit weird, but we do this to ensure that the final
// call to the Fetch (whether it happens depends on timing) just blocks.
triggerCh := make(chan time.Time)
defer close(triggerCh)
// Configure the type
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).WaitUntil(triggerCh)
// Fetch should block
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Fetch again almost immediately should return old result
time.Sleep(5 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Wait for the timer
time.Sleep(200 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 12)
}
// Test that a type registered with a periodic refresh performs
// multiple refreshes as the timer repeatedly elapses.
func TestCacheGet_periodicRefreshMultiple(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 0 * time.Millisecond,
RefreshTimeout: 5 * time.Minute,
})
// This is a bit weird, but we do this to ensure that the final
// call to the Fetch (whether it happens depends on timing) just blocks.
trigger := make([]chan time.Time, 3)
for i := range trigger {
trigger[i] = make(chan time.Time)
}
// Configure the type
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once().WaitUntil(trigger[0])
typ.Static(FetchResult{Value: 24, Index: 6}, nil).Once().WaitUntil(trigger[1])
typ.Static(FetchResult{Value: 42, Index: 7}, nil).WaitUntil(trigger[2])
// Fetch should block
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Fetch again almost immediately should return old result
time.Sleep(5 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Trigger the next, sleep a bit, and verify we get the next result
close(trigger[0])
time.Sleep(100 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 12)
// Trigger the next, sleep a bit, and verify we get the next result
close(trigger[1])
time.Sleep(100 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 24)
}
// Test that a refresh performs a backoff.
func TestCacheGet_periodicRefreshErrorBackoff(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 0,
RefreshTimeout: 5 * time.Minute,
})
// Configure the type
var retries uint32
fetchErr := fmt.Errorf("test fetch error")
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: nil, Index: 5}, fetchErr).Run(func(args mock.Arguments) {
atomic.AddUint32(&retries, 1)
})
// Fetch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Sleep a bit. The refresh will quietly fail in the background. What we
// want to verify is that it doesn't retry too much. "Too much" is hard
// to measure exactly since it's CPU-dependent whether this test fails. But
// given the short sleep below, we can estimate roughly what we'd expect if
// backoff IS working.
time.Sleep(500 * time.Millisecond)
// Fetch should work, we should get a 1 still. Errors are ignored.
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Check the number
actual := atomic.LoadUint32(&retries)
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
}
// Test that a badly behaved RPC that returns 0 index will perform a backoff.
func TestCacheGet_periodicRefreshBadRPCZeroIndexErrorBackoff(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 0,
RefreshTimeout: 5 * time.Minute,
})
// Configure the type
var retries uint32
typ.Static(FetchResult{Value: 0, Index: 0}, nil).Run(func(args mock.Arguments) {
atomic.AddUint32(&retries, 1)
})
// Fetch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 0)
// Sleep a bit. The refresh will quietly fail in the background. What we
// want to verify is that it doesn't retry too much. "Too much" is hard
// to measure exactly since it's CPU-dependent whether this test fails. But
// given the short sleep below, we can estimate roughly what we'd expect if
// backoff IS working.
time.Sleep(500 * time.Millisecond)
// Fetch should work, we should get a 0 still. Errors are ignored.
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 0)
// Check the number
actual := atomic.LoadUint32(&retries)
require.True(t, actual < 10, fmt.Sprintf("%d retries, should be < 10", actual))
}
// Test that fetching with no index makes an initial request with no index, but
// then ensures all background refreshes use an index > 0. This ensures we don't
// end up in any index-0 loops from background refreshes while also returning
// immediately on the initial request if there is no data written to that table
// yet.
func TestCacheGet_noIndexSetsOne(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 0,
RefreshTimeout: 5 * time.Minute,
})
// Simulate "well behaved" RPC with no data yet but returning 1
{
first := int32(1)
typ.Static(FetchResult{Value: 0, Index: 1}, nil).Run(func(args mock.Arguments) {
opts := args.Get(0).(FetchOptions)
isFirst := atomic.SwapInt32(&first, 0)
if isFirst == 1 {
assert.Equal(t, uint64(0), opts.MinIndex)
} else {
assert.True(t, opts.MinIndex > 0, "minIndex > 0")
}
})
// Fetch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 0)
// Sleep a bit so background refresh happens
time.Sleep(100 * time.Millisecond)
}
// Same for "badly behaved" RPC that returns 0 index and no data
{
first := int32(1)
typ.Static(FetchResult{Value: 0, Index: 0}, nil).Run(func(args mock.Arguments) {
opts := args.Get(0).(FetchOptions)
isFirst := atomic.SwapInt32(&first, 0)
if isFirst == 1 {
assert.Equal(t, uint64(0), opts.MinIndex)
} else {
assert.True(t, opts.MinIndex > 0, "minIndex > 0")
}
})
// Fetch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 0)
// Sleep a bit so background refresh happens
time.Sleep(100 * time.Millisecond)
}
}
// Test that the backend fetch sets the proper timeout.
func TestCacheGet_fetchTimeout(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
// Register the type with a timeout
timeout := 10 * time.Minute
c.RegisterType("t", typ, &RegisterOptions{
RefreshTimeout: timeout,
})
// Configure the type
var actual time.Duration
typ.Static(FetchResult{Value: 42}, nil).Times(1).Run(func(args mock.Arguments) {
opts := args.Get(0).(FetchOptions)
actual = opts.Timeout
})
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Test the timeout
require.Equal(timeout, actual)
}
// Test that entries expire
func TestCacheGet_expire(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
// Register the type with a timeout
c.RegisterType("t", typ, &RegisterOptions{
LastGetTTL: 400 * time.Millisecond,
})
// Configure the type
typ.Static(FetchResult{Value: 42}, nil).Times(2)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Get, should not fetch, verified via the mock assertions above
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
// Sleep for the expiry
time.Sleep(500 * time.Millisecond)
// Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the expected two calls
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test that entries reset their TTL on Get
func TestCacheGet_expireResetGet(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
// Register the type with a timeout
c.RegisterType("t", typ, &RegisterOptions{
LastGetTTL: 150 * time.Millisecond,
})
// Configure the type
typ.Static(FetchResult{Value: 42}, nil).Times(2)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Fetch multiple times, where the total time is well beyond
// the TTL. We should not trigger any fetches during this time.
for i := 0; i < 5; i++ {
// Sleep a bit
time.Sleep(50 * time.Millisecond)
// Get, should not fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
}
time.Sleep(200 * time.Millisecond)
// Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the expected two calls
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test that Get partitions the caches based on DC so two equivalent requests
// to different datacenters are cached separately even if their keys are
// the same.
func TestCacheGet_partitionDC(t *testing.T) {
t.Parallel()
c := TestCache(t)
c.RegisterType("t", &testPartitionType{}, nil)
// Perform multiple gets
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Datacenter: "dc1", Key: "hello"}))
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Datacenter: "dc9", Key: "hello"}))
// Should return both!
TestCacheGetChResult(t, getCh1, "dc1")
TestCacheGetChResult(t, getCh2, "dc9")
}
// Test that Get partitions the caches based on token so two equivalent requests
// with different ACL tokens do not return the same result.
func TestCacheGet_partitionToken(t *testing.T) {
t.Parallel()
c := TestCache(t)
c.RegisterType("t", &testPartitionType{}, nil)
// Perform multiple gets
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Token: "", Key: "hello"}))
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Token: "foo", Key: "hello"}))
// Should return both!
TestCacheGetChResult(t, getCh1, "")
TestCacheGetChResult(t, getCh2, "foo")
}
// testPartitionType is a Type implementation for testing that simply returns
// a value composed of the request DC and ACL token, used for testing cache
// partitioning.
type testPartitionType struct{}
func (t *testPartitionType) Fetch(opts FetchOptions, r Request) (FetchResult, error) {
info := r.CacheInfo()
return FetchResult{
Value: fmt.Sprintf("%s%s", info.Datacenter, info.Token),
}, nil
}
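
The two partition tests above rely on the cache deriving its internal entry key from the ACL token and datacenter as well as the request key. The actual helper is not part of this hunk; the following standalone sketch shows a plausible derivation with that property:

package main

import "fmt"

// partitionInfo mirrors the RequestInfo fields relevant to partitioning.
type partitionInfo struct {
	Key, Token, Datacenter string
}

// entryKey sketches how partition values could be combined with the
// request key; the real helper is not shown in this diff.
func entryKey(info partitionInfo) string {
	return fmt.Sprintf("%s/%s/%s", info.Token, info.Datacenter, info.Key)
}

func main() {
	a := entryKey(partitionInfo{Key: "hello", Datacenter: "dc1"})
	b := entryKey(partitionInfo{Key: "hello", Datacenter: "dc9"})
	fmt.Println(a != b) // true: same Key, different partitions
}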

143
agent/cache/entry.go vendored Normal file

@@ -0,0 +1,143 @@
package cache
import (
"container/heap"
"time"
)
// cacheEntry stores a single cache entry.
//
// Note that this isn't a very optimized structure currently. There are
// a lot of improvements that can be made here in the long term.
type cacheEntry struct {
// Fields pertaining to the actual value
Value interface{}
Error error
Index uint64
// Metadata that is used for internal accounting
Valid bool // True if the Value is set
Fetching bool // True if a fetch is already active
Waiter chan struct{} // Closed when this entry is invalidated
// Expiry contains information about the expiration of this
// entry. This is a pointer as its shared as a value in the
// expiryHeap as well.
Expiry *cacheEntryExpiry
}
// cacheEntryExpiry contains the expiration information for a cache
// entry. Any modifications to this struct should be done only while
// the Cache entriesLock is held.
type cacheEntryExpiry struct {
Key string // Key in the cache map
Expires time.Time // Time when entry expires (monotonic clock)
TTL time.Duration // TTL for this entry to extend when resetting
HeapIndex int // Index in the heap
}
// Reset resets the expiration to be the ttl duration from now.
func (e *cacheEntryExpiry) Reset() {
e.Expires = time.Now().Add(e.TTL)
}
// expiryHeap is a heap implementation that stores information about
// when entries expire. Implements container/heap.Interface.
//
// All operations on the heap and read/write of the heap contents require
// the proper entriesLock to be held on Cache.
type expiryHeap struct {
Entries []*cacheEntryExpiry
// NotifyCh is sent a value whenever the 0 index value of the heap
// changes. This can be used to detect when the earliest value
// changes.
//
// There is a single edge case where the heap will not automatically
// send a notification: if heap.Fix is called manually on index
// 0 and the change doesn't result in any moves (the entry stays at index
// 0), then we won't detect the change. To work around this, please
// always call the expiryHeap.Fix method instead.
NotifyCh chan struct{}
}
// Identical to heap.Fix for this heap instance but will properly handle
// the edge case where idx == 0 and no heap modification is necessary,
// and still notify the NotifyCh.
//
// This is important for cache expiry since the expiry time may have been
// extended and if we don't send a message to the NotifyCh then we'll never
// reset the timer and the entry will be evicted early.
func (h *expiryHeap) Fix(entry *cacheEntryExpiry) {
idx := entry.HeapIndex
heap.Fix(h, idx)
// This is the edge case we handle: if the prev (idx) and current (HeapIndex)
// are both zero, it means the head-of-line didn't move while its value
// changed. Notify to reset our expiry worker.
if idx == 0 && entry.HeapIndex == 0 {
h.notify()
}
}
func (h *expiryHeap) Len() int { return len(h.Entries) }
func (h *expiryHeap) Swap(i, j int) {
h.Entries[i], h.Entries[j] = h.Entries[j], h.Entries[i]
h.Entries[i].HeapIndex = i
h.Entries[j].HeapIndex = j
// If we're moving the 0 index, update the channel since we need
// to re-update the timer we're waiting on for the soonest expiring
// value.
if i == 0 || j == 0 {
h.notify()
}
}
func (h *expiryHeap) Less(i, j int) bool {
// The usage of Before here is important (despite being obvious):
// this function uses the monotonic time that should be available
// on the time.Time value so the heap is immune to wall clock changes.
return h.Entries[i].Expires.Before(h.Entries[j].Expires)
}
// heap.Interface, this isn't expected to be called directly.
func (h *expiryHeap) Push(x interface{}) {
entry := x.(*cacheEntryExpiry)
// Set the initial heap index; if the entry is appended at the end then
// Swap won't be called, so we need to initialize it here
entry.HeapIndex = len(h.Entries)
// For the first entry, we need to trigger a channel send because
// Swap won't be called; nothing to swap! We can call it right away
// because all heap operations are within a lock.
if len(h.Entries) == 0 {
h.notify()
}
h.Entries = append(h.Entries, entry)
}
// heap.Interface, this isn't expected to be called directly.
func (h *expiryHeap) Pop() interface{} {
old := h.Entries
n := len(old)
x := old[n-1]
h.Entries = old[0 : n-1]
return x
}
func (h *expiryHeap) notify() {
select {
case h.NotifyCh <- struct{}{}:
// Good
default:
// If the send would've blocked, we just ignore it. The reason this
// is safe is that NotifyCh should always be a buffered channel.
// If the send would block, there is already a pending message anyway,
// so the receiver will restart regardless.
}
}
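
The edge case expiryHeap.Fix guards against is easy to reproduce with container/heap directly: when the root's value changes but it remains the minimum, heap.Fix performs no swaps, so a notification hook placed only in Swap would never fire. A toy demonstration of that behavior (these are not the cache's types):

package main

import (
	"container/heap"
	"fmt"
)

// intHeap is a toy min-heap standing in for expiryHeap, with a hook in
// Swap analogous to expiryHeap notifying on moves.
type intHeap struct{ v []int }

func (h *intHeap) Len() int           { return len(h.v) }
func (h *intHeap) Less(i, j int) bool { return h.v[i] < h.v[j] }
func (h *intHeap) Swap(i, j int) {
	h.v[i], h.v[j] = h.v[j], h.v[i]
	fmt.Println("swap -> would notify")
}
func (h *intHeap) Push(x interface{}) { h.v = append(h.v, x.(int)) }
func (h *intHeap) Pop() interface{} {
	x := h.v[len(h.v)-1]
	h.v = h.v[:len(h.v)-1]
	return x
}

func main() {
	h := &intHeap{v: []int{1, 5, 9}}
	heap.Init(h) // already a valid heap: no swaps occur

	// Grow the root while keeping it the minimum: heap.Fix moves nothing
	// and prints no "swap", so a Swap-only notification is missed. This
	// is exactly why expiryHeap.Fix notifies unconditionally when index 0
	// is fixed.
	h.v[0] = 3
	heap.Fix(h, 0)
}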

91
agent/cache/entry_test.go vendored Normal file

@@ -0,0 +1,91 @@
package cache
import (
"container/heap"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestExpiryHeap_impl(t *testing.T) {
var _ heap.Interface = new(expiryHeap)
}
func TestExpiryHeap(t *testing.T) {
require := require.New(t)
now := time.Now()
ch := make(chan struct{}, 10) // buffered to prevent blocking in tests
h := &expiryHeap{NotifyCh: ch}
// Init, shouldn't trigger anything
heap.Init(h)
testNoMessage(t, ch)
// Push an initial value, expect one message
entry := &cacheEntryExpiry{Key: "foo", HeapIndex: -1, Expires: now.Add(100)}
heap.Push(h, entry)
require.Equal(0, entry.HeapIndex)
testMessage(t, ch)
testNoMessage(t, ch) // exactly one asserted above
// Push another that goes earlier than entry
entry2 := &cacheEntryExpiry{Key: "bar", HeapIndex: -1, Expires: now.Add(50)}
heap.Push(h, entry2)
require.Equal(0, entry2.HeapIndex)
require.Equal(1, entry.HeapIndex)
testMessage(t, ch)
testNoMessage(t, ch) // exactly one asserted above
// Push another that goes at the end
entry3 := &cacheEntryExpiry{Key: "bar", HeapIndex: -1, Expires: now.Add(1000)}
heap.Push(h, entry3)
require.Equal(2, entry3.HeapIndex)
testNoMessage(t, ch) // no notify because index 0 stayed the same
// Remove the first entry (not Pop, since we don't use Pop, but that works too)
remove := h.Entries[0]
heap.Remove(h, remove.HeapIndex)
require.Equal(0, entry.HeapIndex)
require.Equal(1, entry3.HeapIndex)
testMessage(t, ch)
testMessage(t, ch) // we have two because two swaps happen
testNoMessage(t, ch)
// Let's change entry 3 to be early, and fix it
entry3.Expires = now.Add(10)
h.Fix(entry3)
require.Equal(1, entry.HeapIndex)
require.Equal(0, entry3.HeapIndex)
testMessage(t, ch)
testNoMessage(t, ch)
// Let's change entry 3 again, this is an edge case where if the 0th
// element changed, we didn't trigger the channel. Our Fix func should.
entry3.Expires = now.Add(20)
h.Fix(entry3)
require.Equal(1, entry.HeapIndex) // no move
require.Equal(0, entry3.HeapIndex)
testMessage(t, ch)
testNoMessage(t, ch) // one message
}
func testNoMessage(t *testing.T, ch <-chan struct{}) {
t.Helper()
select {
case <-ch:
t.Fatal("should not have a message")
default:
}
}
func testMessage(t *testing.T, ch <-chan struct{}) {
t.Helper()
select {
case <-ch:
default:
t.Fatal("should have a message")
}
}

23
agent/cache/mock_Request.go vendored Normal file

@@ -0,0 +1,23 @@
// Code generated by mockery v1.0.0
package cache
import mock "github.com/stretchr/testify/mock"
// MockRequest is an autogenerated mock type for the Request type
type MockRequest struct {
mock.Mock
}
// CacheInfo provides a mock function with given fields:
func (_m *MockRequest) CacheInfo() RequestInfo {
ret := _m.Called()
var r0 RequestInfo
if rf, ok := ret.Get(0).(func() RequestInfo); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(RequestInfo)
}
return r0
}

30
agent/cache/mock_Type.go vendored Normal file

@@ -0,0 +1,30 @@
// Code generated by mockery v1.0.0
package cache
import mock "github.com/stretchr/testify/mock"
// MockType is an autogenerated mock type for the Type type
type MockType struct {
mock.Mock
}
// Fetch provides a mock function with given fields: _a0, _a1
func (_m *MockType) Fetch(_a0 FetchOptions, _a1 Request) (FetchResult, error) {
ret := _m.Called(_a0, _a1)
var r0 FetchResult
if rf, ok := ret.Get(0).(func(FetchOptions, Request) FetchResult); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Get(0).(FetchResult)
}
var r1 error
if rf, ok := ret.Get(1).(func(FetchOptions, Request) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

51
agent/cache/request.go vendored Normal file

@@ -0,0 +1,51 @@
package cache
import (
"time"
)
// Request is a cacheable request.
//
// This interface is typically implemented by request structures in
// the agent/structs package.
type Request interface {
// CacheInfo returns information used for caching this request.
CacheInfo() RequestInfo
}
// RequestInfo represents cache information for a request. The caching
// framework uses this to control the behavior of caching and to determine
// cacheability.
type RequestInfo struct {
// Key is a unique cache key for this request. This key should
// be globally unique to identify this request, since any conflicting
// cache keys could result in invalid data being returned from the cache.
// The Key does not need to include ACL or DC information, since the
// cache already partitions by these values prior to using this key.
Key string
// Token is the ACL token associated with this request.
//
// Datacenter is the datacenter that the request is targeting.
//
// Both of these values are used to partition the cache. The cache framework
// today partitions data on these values to simplify behavior: by
// partitioning ACL tokens, the cache doesn't need to be smart about
// filtering results. By partitioning on datacenter, the cache can
// serve the multi-DC nature of Consul. This comes at the expense of
// working set size, but in general the effect is minimal.
Token string
Datacenter string
// MinIndex is the minimum index being queried. This is used to
// determine if we already have data satisfying the query or if we need
// to block until new data is available. If no index is available, the
// default value (zero) is acceptable.
MinIndex uint64
// Timeout is the timeout for waiting on a blocking query. When the
// timeout is reached, the last known value is returned (or maybe nil
// if there was no prior value). This "last known value" behavior matches
// normal Consul blocking queries.
Timeout time.Duration
}
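
Concretely, a request satisfies this interface by returning its RequestInfo. A minimal hypothetical implementation, written as if inside the cache package (the type name and fields are illustrative, and the time package is assumed to be imported):

// serviceNodesRequest is a hypothetical cacheable request type.
type serviceNodesRequest struct {
	Datacenter string
	Token      string
	Service    string
	MinIndex   uint64
}

func (r *serviceNodesRequest) CacheInfo() RequestInfo {
	return RequestInfo{
		Key:        r.Service,    // unique within this request type
		Token:      r.Token,      // partitions the cache
		Datacenter: r.Datacenter, // partitions the cache
		MinIndex:   r.MinIndex,   // blocking-query index; zero is fine
		Timeout:    10 * time.Minute,
	}
}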

78
agent/cache/testing.go vendored Normal file

@@ -0,0 +1,78 @@
package cache
import (
"reflect"
"time"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/mock"
)
// TestCache returns a Cache instance configured for testing.
func TestCache(t testing.T) *Cache {
// Simple but lets us do some fine-tuning later if we want to.
return New(nil)
}
// TestCacheGetCh returns a channel that returns the result of the Get call.
// This is useful for testing timing and concurrency with Get calls. Any
// error will be logged, so the result value should always be asserted.
func TestCacheGetCh(t testing.T, c *Cache, typ string, r Request) <-chan interface{} {
resultCh := make(chan interface{})
go func() {
result, _, err := c.Get(typ, r)
if err != nil {
t.Logf("Error: %s", err)
close(resultCh)
return
}
resultCh <- result
}()
return resultCh
}
// TestCacheGetChResult tests that the result from TestCacheGetCh matches
// within a reasonable period of time (it expects it to be "immediate" but
// waits some milliseconds).
func TestCacheGetChResult(t testing.T, ch <-chan interface{}, expected interface{}) {
t.Helper()
select {
case result := <-ch:
if !reflect.DeepEqual(result, expected) {
t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected)
}
case <-time.After(50 * time.Millisecond):
t.Fatalf("Result not sent on channel")
}
}
// TestRequest returns a Request that returns the given cache key and index.
// The Reset method can be called to reset it for custom usage.
func TestRequest(t testing.T, info RequestInfo) *MockRequest {
req := &MockRequest{}
req.On("CacheInfo").Return(info)
return req
}
// TestType returns a MockType that can be used to setup expectations
// on data fetching.
func TestType(t testing.T) *MockType {
typ := &MockType{}
return typ
}
// A bit weird, but we add methods to the auto-generated structs here so that
// they don't get clobbered when the mocks are regenerated. The helper methods
// are conveniences.
// Static sets a static value to return for a call to Fetch.
func (m *MockType) Static(r FetchResult, err error) *mock.Call {
return m.Mock.On("Fetch", mock.Anything, mock.Anything).Return(r, err)
}
func (m *MockRequest) Reset() {
m.Mock = mock.Mock{}
}

48
agent/cache/type.go vendored Normal file

@@ -0,0 +1,48 @@
package cache
import (
"time"
)
// Type implements the logic to fetch certain types of data.
type Type interface {
// Fetch fetches a single unique item.
//
// The FetchOptions contain the index and timeouts for blocking queries.
// The MinIndex value on the Request itself should NOT be used
// as the blocking index since a request may be reused multiple times
// as part of Refresh behavior.
//
// The return value is a FetchResult which contains information about
// the fetch. If an error is given, the FetchResult is ignored. The
// cache does not support backends that return partial values.
//
// On timeout, a Fetch can behave in one of two ways. First, it can
// return the last known value. This is the default behavior of blocking
// RPC calls in Consul so this allows cache types to be implemented with
// no extra logic. Second, the Fetch can return an unset value and index.
// In this case, the cache will reuse the last value automatically.
Fetch(FetchOptions, Request) (FetchResult, error)
}
// FetchOptions are various settable options when a Fetch is called.
type FetchOptions struct {
// MinIndex is the minimum index to be used for blocking queries.
// If blocking queries aren't supported for data being returned,
// this value can be ignored.
MinIndex uint64
// Timeout is the maximum time for the query. This must be implemented
// in the Fetch itself.
Timeout time.Duration
}
// FetchResult is the result of a Type Fetch operation and contains the
// data along with metadata gathered from that operation.
type FetchResult struct {
// Value is the result of the fetch.
Value interface{}
// Index is the corresponding index value for this data.
Index uint64
}
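
A minimal hypothetical Type, again written as if inside the cache package, that returns immediately with a constant value. A real implementation would issue a blocking RPC that honors opts.MinIndex and opts.Timeout:

// staticType is a hypothetical Type that ignores blocking semantics.
type staticType struct{}

func (s *staticType) Fetch(opts FetchOptions, r Request) (FetchResult, error) {
	// A real type would block until the backend's index exceeds
	// opts.MinIndex or opts.Timeout elapses, then return the new
	// value along with its index.
	return FetchResult{Value: "hello", Index: 1}, nil
}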


@@ -157,12 +157,27 @@ RETRY_ONCE:
return out.Services, nil
}
func (s *HTTPServer) CatalogConnectServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.catalogServiceNodes(resp, req, true)
}
func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_service_nodes"}, 1,
return s.catalogServiceNodes(resp, req, false)
}
func (s *HTTPServer) catalogServiceNodes(resp http.ResponseWriter, req *http.Request, connect bool) (interface{}, error) {
metricsKey := "catalog_service_nodes"
pathPrefix := "/v1/catalog/service/"
if connect {
metricsKey = "catalog_connect_service_nodes"
pathPrefix = "/v1/catalog/connect/"
}
metrics.IncrCounterWithLabels([]string{"client", "api", metricsKey}, 1,
[]metrics.Label{{Name: "node", Value: s.nodeName()}})
// Set default DC
args := structs.ServiceSpecificRequest{}
args := structs.ServiceSpecificRequest{Connect: connect}
s.parseSource(req, &args.Source)
args.NodeMetaFilters = s.parseMetaFilter(req)
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
@@ -177,7 +192,7 @@ func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Req
}
// Pull out the service name
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/catalog/service/")
args.ServiceName = strings.TrimPrefix(req.URL.Path, pathPrefix)
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")
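
With an agent running, the new endpoint is queried just like /v1/catalog/service/:service. A small sketch using only the standard library; the agent address and the service name "web" are assumptions:

package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
)

func main() {
	// Assumes a local agent on the default HTTP port and a Connect
	// target service named "web" with at least one registered proxy.
	resp, err := http.Get("http://127.0.0.1:8500/v1/catalog/connect/web")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := ioutil.ReadAll(resp.Body)
	fmt.Println(string(body)) // JSON array of proxy service nodes
}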


@@ -10,6 +10,7 @@ import (
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/testutil/retry"
"github.com/hashicorp/serf/coordinate"
"github.com/stretchr/testify/assert"
)
func TestCatalogRegister_Service_InvalidAddress(t *testing.T) {
@@ -750,6 +751,60 @@ func TestCatalogServiceNodes_DistanceSort(t *testing.T) {
}
}
// Test that connect proxies can be queried via /v1/catalog/service/:service
// directly and that their results contain the proxy fields.
func TestCatalogServiceNodes_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/service/%s", args.Service.Service), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1)
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
}
// Test that the Connect-compatible endpoints can be queried for a
// service via /v1/catalog/connect/:service.
func TestCatalogConnectServiceNodes_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
args.Service.Address = "127.0.0.55"
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/connect/%s", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1)
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(args.Service.Address, nodes[0].ServiceAddress)
}
func TestCatalogNodeServices(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
@@ -785,6 +840,33 @@ func TestCatalogNodeServices(t *testing.T) {
}
}
// Test that the services on a node contain all the Connect proxies on
// the node as well with their fields properly populated.
func TestCatalogNodeServices_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/node/%s", args.Node), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodeServices(resp, req)
assert.Nil(err)
assertIndex(t, resp)
ns := obj.(*structs.NodeServices)
assert.Len(ns.Services, 1)
v := ns.Services[args.Service.Service]
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
}
func TestCatalogNodeServices_WanTranslation(t *testing.T) {
t.Parallel()
a1 := NewTestAgent(t.Name(), `


@@ -14,9 +14,11 @@ import (
"strings"
"time"
"github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
multierror "github.com/hashicorp/go-multierror"
@@ -340,6 +342,12 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
serverPort := b.portVal("ports.server", c.Ports.Server)
serfPortLAN := b.portVal("ports.serf_lan", c.Ports.SerfLAN)
serfPortWAN := b.portVal("ports.serf_wan", c.Ports.SerfWAN)
proxyMinPort := b.portVal("ports.proxy_min_port", c.Ports.ProxyMinPort)
proxyMaxPort := b.portVal("ports.proxy_max_port", c.Ports.ProxyMaxPort)
if proxyMaxPort < proxyMinPort {
return RuntimeConfig{}, fmt.Errorf(
"proxy_min_port must be less than proxy_max_port. To disable, set both to zero.")
}
// determine the default bind and advertise address
//
@@ -520,6 +528,30 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
consulRaftHeartbeatTimeout := b.durationVal("consul.raft.heartbeat_timeout", c.Consul.Raft.HeartbeatTimeout) * time.Duration(performanceRaftMultiplier)
consulRaftLeaderLeaseTimeout := b.durationVal("consul.raft.leader_lease_timeout", c.Consul.Raft.LeaderLeaseTimeout) * time.Duration(performanceRaftMultiplier)
// Connect proxy defaults.
connectEnabled := b.boolVal(c.Connect.Enabled)
connectCAProvider := b.stringVal(c.Connect.CAProvider)
connectCAConfig := c.Connect.CAConfig
if connectCAConfig != nil {
TranslateKeys(connectCAConfig, map[string]string{
// Consul CA config
"private_key": "PrivateKey",
"root_cert": "RootCert",
"rotation_period": "RotationPeriod",
// Vault CA config
"address": "Address",
"token": "Token",
"root_pki_path": "RootPKIPath",
"intermediate_pki_path": "IntermediatePKIPath",
})
}
proxyDefaultExecMode := b.stringVal(c.Connect.ProxyDefaults.ExecMode)
proxyDefaultDaemonCommand := c.Connect.ProxyDefaults.DaemonCommand
proxyDefaultScriptCommand := c.Connect.ProxyDefaults.ScriptCommand
proxyDefaultConfig := c.Connect.ProxyDefaults.Config
// ----------------------------------------------------------------
// build runtime config
//
@@ -592,6 +624,7 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
DNSRecursors: dnsRecursors,
DNSServiceTTL: dnsServiceTTL,
DNSUDPAnswerLimit: b.intVal(c.DNS.UDPAnswerLimit),
DNSNodeMetaTXT: b.boolValWithDefault(c.DNS.NodeMetaTXT, true),
// HTTP
HTTPPort: httpPort,
@@ -602,120 +635,133 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
HTTPResponseHeaders: c.HTTPConfig.ResponseHeaders,
// Telemetry
TelemetryCirconusAPIApp: b.stringVal(c.Telemetry.CirconusAPIApp),
TelemetryCirconusAPIToken: b.stringVal(c.Telemetry.CirconusAPIToken),
TelemetryCirconusAPIURL: b.stringVal(c.Telemetry.CirconusAPIURL),
TelemetryCirconusBrokerID: b.stringVal(c.Telemetry.CirconusBrokerID),
TelemetryCirconusBrokerSelectTag: b.stringVal(c.Telemetry.CirconusBrokerSelectTag),
TelemetryCirconusCheckDisplayName: b.stringVal(c.Telemetry.CirconusCheckDisplayName),
TelemetryCirconusCheckForceMetricActivation: b.stringVal(c.Telemetry.CirconusCheckForceMetricActivation),
TelemetryCirconusCheckID: b.stringVal(c.Telemetry.CirconusCheckID),
TelemetryCirconusCheckInstanceID: b.stringVal(c.Telemetry.CirconusCheckInstanceID),
TelemetryCirconusCheckSearchTag: b.stringVal(c.Telemetry.CirconusCheckSearchTag),
TelemetryCirconusCheckTags: b.stringVal(c.Telemetry.CirconusCheckTags),
TelemetryCirconusSubmissionInterval: b.stringVal(c.Telemetry.CirconusSubmissionInterval),
TelemetryCirconusSubmissionURL: b.stringVal(c.Telemetry.CirconusSubmissionURL),
TelemetryDisableHostname: b.boolVal(c.Telemetry.DisableHostname),
TelemetryDogstatsdAddr: b.stringVal(c.Telemetry.DogstatsdAddr),
TelemetryDogstatsdTags: c.Telemetry.DogstatsdTags,
TelemetryPrometheusRetentionTime: b.durationVal("prometheus_retention_time", c.Telemetry.PrometheusRetentionTime),
TelemetryFilterDefault: b.boolVal(c.Telemetry.FilterDefault),
TelemetryAllowedPrefixes: telemetryAllowedPrefixes,
TelemetryBlockedPrefixes: telemetryBlockedPrefixes,
TelemetryMetricsPrefix: b.stringVal(c.Telemetry.MetricsPrefix),
TelemetryStatsdAddr: b.stringVal(c.Telemetry.StatsdAddr),
TelemetryStatsiteAddr: b.stringVal(c.Telemetry.StatsiteAddr),
Telemetry: lib.TelemetryConfig{
CirconusAPIApp: b.stringVal(c.Telemetry.CirconusAPIApp),
CirconusAPIToken: b.stringVal(c.Telemetry.CirconusAPIToken),
CirconusAPIURL: b.stringVal(c.Telemetry.CirconusAPIURL),
CirconusBrokerID: b.stringVal(c.Telemetry.CirconusBrokerID),
CirconusBrokerSelectTag: b.stringVal(c.Telemetry.CirconusBrokerSelectTag),
CirconusCheckDisplayName: b.stringVal(c.Telemetry.CirconusCheckDisplayName),
CirconusCheckForceMetricActivation: b.stringVal(c.Telemetry.CirconusCheckForceMetricActivation),
CirconusCheckID: b.stringVal(c.Telemetry.CirconusCheckID),
CirconusCheckInstanceID: b.stringVal(c.Telemetry.CirconusCheckInstanceID),
CirconusCheckSearchTag: b.stringVal(c.Telemetry.CirconusCheckSearchTag),
CirconusCheckTags: b.stringVal(c.Telemetry.CirconusCheckTags),
CirconusSubmissionInterval: b.stringVal(c.Telemetry.CirconusSubmissionInterval),
CirconusSubmissionURL: b.stringVal(c.Telemetry.CirconusSubmissionURL),
DisableHostname: b.boolVal(c.Telemetry.DisableHostname),
DogstatsdAddr: b.stringVal(c.Telemetry.DogstatsdAddr),
DogstatsdTags: c.Telemetry.DogstatsdTags,
PrometheusRetentionTime: b.durationVal("prometheus_retention_time", c.Telemetry.PrometheusRetentionTime),
FilterDefault: b.boolVal(c.Telemetry.FilterDefault),
AllowedPrefixes: telemetryAllowedPrefixes,
BlockedPrefixes: telemetryBlockedPrefixes,
MetricsPrefix: b.stringVal(c.Telemetry.MetricsPrefix),
StatsdAddr: b.stringVal(c.Telemetry.StatsdAddr),
StatsiteAddr: b.stringVal(c.Telemetry.StatsiteAddr),
},
// Agent
AdvertiseAddrLAN: advertiseAddrLAN,
AdvertiseAddrWAN: advertiseAddrWAN,
BindAddr: bindAddr,
Bootstrap: b.boolVal(c.Bootstrap),
BootstrapExpect: b.intVal(c.BootstrapExpect),
CAFile: b.stringVal(c.CAFile),
CAPath: b.stringVal(c.CAPath),
CertFile: b.stringVal(c.CertFile),
CheckUpdateInterval: b.durationVal("check_update_interval", c.CheckUpdateInterval),
Checks: checks,
ClientAddrs: clientAddrs,
DataDir: b.stringVal(c.DataDir),
Datacenter: strings.ToLower(b.stringVal(c.Datacenter)),
DevMode: b.boolVal(b.Flags.DevMode),
DisableAnonymousSignature: b.boolVal(c.DisableAnonymousSignature),
DisableCoordinates: b.boolVal(c.DisableCoordinates),
DisableHostNodeID: b.boolVal(c.DisableHostNodeID),
DisableKeyringFile: b.boolVal(c.DisableKeyringFile),
DisableRemoteExec: b.boolVal(c.DisableRemoteExec),
DisableUpdateCheck: b.boolVal(c.DisableUpdateCheck),
DiscardCheckOutput: b.boolVal(c.DiscardCheckOutput),
DiscoveryMaxStale: b.durationVal("discovery_max_stale", c.DiscoveryMaxStale),
EnableAgentTLSForChecks: b.boolVal(c.EnableAgentTLSForChecks),
EnableDebug: b.boolVal(c.EnableDebug),
EnableScriptChecks: b.boolVal(c.EnableScriptChecks),
EnableSyslog: b.boolVal(c.EnableSyslog),
EnableUI: b.boolVal(c.UI),
EncryptKey: b.stringVal(c.EncryptKey),
EncryptVerifyIncoming: b.boolVal(c.EncryptVerifyIncoming),
EncryptVerifyOutgoing: b.boolVal(c.EncryptVerifyOutgoing),
KeyFile: b.stringVal(c.KeyFile),
LeaveDrainTime: b.durationVal("performance.leave_drain_time", c.Performance.LeaveDrainTime),
LeaveOnTerm: leaveOnTerm,
LogLevel: b.stringVal(c.LogLevel),
NodeID: types.NodeID(b.stringVal(c.NodeID)),
NodeMeta: c.NodeMeta,
NodeName: b.nodeName(c.NodeName),
NonVotingServer: b.boolVal(c.NonVotingServer),
PidFile: b.stringVal(c.PidFile),
RPCAdvertiseAddr: rpcAdvertiseAddr,
RPCBindAddr: rpcBindAddr,
RPCHoldTimeout: b.durationVal("performance.rpc_hold_timeout", c.Performance.RPCHoldTimeout),
RPCMaxBurst: b.intVal(c.Limits.RPCMaxBurst),
RPCProtocol: b.intVal(c.RPCProtocol),
RPCRateLimit: rate.Limit(b.float64Val(c.Limits.RPCRate)),
RaftProtocol: b.intVal(c.RaftProtocol),
RaftSnapshotThreshold: b.intVal(c.RaftSnapshotThreshold),
RaftSnapshotInterval: b.durationVal("raft_snapshot_interval", c.RaftSnapshotInterval),
ReconnectTimeoutLAN: b.durationVal("reconnect_timeout", c.ReconnectTimeoutLAN),
ReconnectTimeoutWAN: b.durationVal("reconnect_timeout_wan", c.ReconnectTimeoutWAN),
RejoinAfterLeave: b.boolVal(c.RejoinAfterLeave),
RetryJoinIntervalLAN: b.durationVal("retry_interval", c.RetryJoinIntervalLAN),
RetryJoinIntervalWAN: b.durationVal("retry_interval_wan", c.RetryJoinIntervalWAN),
RetryJoinLAN: b.expandAllOptionalAddrs("retry_join", c.RetryJoinLAN),
RetryJoinMaxAttemptsLAN: b.intVal(c.RetryJoinMaxAttemptsLAN),
RetryJoinMaxAttemptsWAN: b.intVal(c.RetryJoinMaxAttemptsWAN),
RetryJoinWAN: b.expandAllOptionalAddrs("retry_join_wan", c.RetryJoinWAN),
SegmentName: b.stringVal(c.SegmentName),
Segments: segments,
SerfAdvertiseAddrLAN: serfAdvertiseAddrLAN,
SerfAdvertiseAddrWAN: serfAdvertiseAddrWAN,
SerfBindAddrLAN: serfBindAddrLAN,
SerfBindAddrWAN: serfBindAddrWAN,
SerfPortLAN: serfPortLAN,
SerfPortWAN: serfPortWAN,
ServerMode: b.boolVal(c.ServerMode),
ServerName: b.stringVal(c.ServerName),
ServerPort: serverPort,
Services: services,
SessionTTLMin: b.durationVal("session_ttl_min", c.SessionTTLMin),
SkipLeaveOnInt: skipLeaveOnInt,
StartJoinAddrsLAN: b.expandAllOptionalAddrs("start_join", c.StartJoinAddrsLAN),
StartJoinAddrsWAN: b.expandAllOptionalAddrs("start_join_wan", c.StartJoinAddrsWAN),
SyslogFacility: b.stringVal(c.SyslogFacility),
TLSCipherSuites: b.tlsCipherSuites("tls_cipher_suites", c.TLSCipherSuites),
TLSMinVersion: b.stringVal(c.TLSMinVersion),
TLSPreferServerCipherSuites: b.boolVal(c.TLSPreferServerCipherSuites),
TaggedAddresses: c.TaggedAddresses,
TranslateWANAddrs: b.boolVal(c.TranslateWANAddrs),
UIDir: b.stringVal(c.UIDir),
UnixSocketGroup: b.stringVal(c.UnixSocket.Group),
UnixSocketMode: b.stringVal(c.UnixSocket.Mode),
UnixSocketUser: b.stringVal(c.UnixSocket.User),
VerifyIncoming: b.boolVal(c.VerifyIncoming),
VerifyIncomingHTTPS: b.boolVal(c.VerifyIncomingHTTPS),
VerifyIncomingRPC: b.boolVal(c.VerifyIncomingRPC),
VerifyOutgoing: b.boolVal(c.VerifyOutgoing),
VerifyServerHostname: b.boolVal(c.VerifyServerHostname),
Watches: c.Watches,
AdvertiseAddrLAN: advertiseAddrLAN,
AdvertiseAddrWAN: advertiseAddrWAN,
BindAddr: bindAddr,
Bootstrap: b.boolVal(c.Bootstrap),
BootstrapExpect: b.intVal(c.BootstrapExpect),
CAFile: b.stringVal(c.CAFile),
CAPath: b.stringVal(c.CAPath),
CertFile: b.stringVal(c.CertFile),
CheckUpdateInterval: b.durationVal("check_update_interval", c.CheckUpdateInterval),
Checks: checks,
ClientAddrs: clientAddrs,
ConnectEnabled: connectEnabled,
ConnectCAProvider: connectCAProvider,
ConnectCAConfig: connectCAConfig,
ConnectProxyAllowManagedRoot: b.boolVal(c.Connect.Proxy.AllowManagedRoot),
ConnectProxyAllowManagedAPIRegistration: b.boolVal(c.Connect.Proxy.AllowManagedAPIRegistration),
ConnectProxyBindMinPort: proxyMinPort,
ConnectProxyBindMaxPort: proxyMaxPort,
ConnectProxyDefaultExecMode: proxyDefaultExecMode,
ConnectProxyDefaultDaemonCommand: proxyDefaultDaemonCommand,
ConnectProxyDefaultScriptCommand: proxyDefaultScriptCommand,
ConnectProxyDefaultConfig: proxyDefaultConfig,
DataDir: b.stringVal(c.DataDir),
Datacenter: strings.ToLower(b.stringVal(c.Datacenter)),
DevMode: b.boolVal(b.Flags.DevMode),
DisableAnonymousSignature: b.boolVal(c.DisableAnonymousSignature),
DisableCoordinates: b.boolVal(c.DisableCoordinates),
DisableHostNodeID: b.boolVal(c.DisableHostNodeID),
DisableKeyringFile: b.boolVal(c.DisableKeyringFile),
DisableRemoteExec: b.boolVal(c.DisableRemoteExec),
DisableUpdateCheck: b.boolVal(c.DisableUpdateCheck),
DiscardCheckOutput: b.boolVal(c.DiscardCheckOutput),
DiscoveryMaxStale: b.durationVal("discovery_max_stale", c.DiscoveryMaxStale),
EnableAgentTLSForChecks: b.boolVal(c.EnableAgentTLSForChecks),
EnableDebug: b.boolVal(c.EnableDebug),
EnableScriptChecks: b.boolVal(c.EnableScriptChecks),
EnableSyslog: b.boolVal(c.EnableSyslog),
EnableUI: b.boolVal(c.UI),
EncryptKey: b.stringVal(c.EncryptKey),
EncryptVerifyIncoming: b.boolVal(c.EncryptVerifyIncoming),
EncryptVerifyOutgoing: b.boolVal(c.EncryptVerifyOutgoing),
KeyFile: b.stringVal(c.KeyFile),
LeaveDrainTime: b.durationVal("performance.leave_drain_time", c.Performance.LeaveDrainTime),
LeaveOnTerm: leaveOnTerm,
LogLevel: b.stringVal(c.LogLevel),
NodeID: types.NodeID(b.stringVal(c.NodeID)),
NodeMeta: c.NodeMeta,
NodeName: b.nodeName(c.NodeName),
NonVotingServer: b.boolVal(c.NonVotingServer),
PidFile: b.stringVal(c.PidFile),
RPCAdvertiseAddr: rpcAdvertiseAddr,
RPCBindAddr: rpcBindAddr,
RPCHoldTimeout: b.durationVal("performance.rpc_hold_timeout", c.Performance.RPCHoldTimeout),
RPCMaxBurst: b.intVal(c.Limits.RPCMaxBurst),
RPCProtocol: b.intVal(c.RPCProtocol),
RPCRateLimit: rate.Limit(b.float64Val(c.Limits.RPCRate)),
RaftProtocol: b.intVal(c.RaftProtocol),
RaftSnapshotThreshold: b.intVal(c.RaftSnapshotThreshold),
RaftSnapshotInterval: b.durationVal("raft_snapshot_interval", c.RaftSnapshotInterval),
ReconnectTimeoutLAN: b.durationVal("reconnect_timeout", c.ReconnectTimeoutLAN),
ReconnectTimeoutWAN: b.durationVal("reconnect_timeout_wan", c.ReconnectTimeoutWAN),
RejoinAfterLeave: b.boolVal(c.RejoinAfterLeave),
RetryJoinIntervalLAN: b.durationVal("retry_interval", c.RetryJoinIntervalLAN),
RetryJoinIntervalWAN: b.durationVal("retry_interval_wan", c.RetryJoinIntervalWAN),
RetryJoinLAN: b.expandAllOptionalAddrs("retry_join", c.RetryJoinLAN),
RetryJoinMaxAttemptsLAN: b.intVal(c.RetryJoinMaxAttemptsLAN),
RetryJoinMaxAttemptsWAN: b.intVal(c.RetryJoinMaxAttemptsWAN),
RetryJoinWAN: b.expandAllOptionalAddrs("retry_join_wan", c.RetryJoinWAN),
SegmentName: b.stringVal(c.SegmentName),
Segments: segments,
SerfAdvertiseAddrLAN: serfAdvertiseAddrLAN,
SerfAdvertiseAddrWAN: serfAdvertiseAddrWAN,
SerfBindAddrLAN: serfBindAddrLAN,
SerfBindAddrWAN: serfBindAddrWAN,
SerfPortLAN: serfPortLAN,
SerfPortWAN: serfPortWAN,
ServerMode: b.boolVal(c.ServerMode),
ServerName: b.stringVal(c.ServerName),
ServerPort: serverPort,
Services: services,
SessionTTLMin: b.durationVal("session_ttl_min", c.SessionTTLMin),
SkipLeaveOnInt: skipLeaveOnInt,
StartJoinAddrsLAN: b.expandAllOptionalAddrs("start_join", c.StartJoinAddrsLAN),
StartJoinAddrsWAN: b.expandAllOptionalAddrs("start_join_wan", c.StartJoinAddrsWAN),
SyslogFacility: b.stringVal(c.SyslogFacility),
TLSCipherSuites: b.tlsCipherSuites("tls_cipher_suites", c.TLSCipherSuites),
TLSMinVersion: b.stringVal(c.TLSMinVersion),
TLSPreferServerCipherSuites: b.boolVal(c.TLSPreferServerCipherSuites),
TaggedAddresses: c.TaggedAddresses,
TranslateWANAddrs: b.boolVal(c.TranslateWANAddrs),
UIDir: b.stringVal(c.UIDir),
UnixSocketGroup: b.stringVal(c.UnixSocket.Group),
UnixSocketMode: b.stringVal(c.UnixSocket.Mode),
UnixSocketUser: b.stringVal(c.UnixSocket.User),
VerifyIncoming: b.boolVal(c.VerifyIncoming),
VerifyIncomingHTTPS: b.boolVal(c.VerifyIncomingHTTPS),
VerifyIncomingRPC: b.boolVal(c.VerifyIncomingRPC),
VerifyOutgoing: b.boolVal(c.VerifyOutgoing),
VerifyServerHostname: b.boolVal(c.VerifyServerHostname),
Watches: c.Watches,
}
if rt.BootstrapExpect == 1 {
@@ -881,6 +927,34 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
return b.err
}
// Check for errors in the service definitions
for _, s := range rt.Services {
if err := s.Validate(); err != nil {
return fmt.Errorf("service %q: %s", s.Name, err)
}
}
// Validate the given Connect CA provider config
validCAProviders := map[string]bool{
"": true,
structs.ConsulCAProvider: true,
structs.VaultCAProvider: true,
}
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)
}
switch rt.ConnectCAProvider {
case structs.ConsulCAProvider:
if _, err := ca.ParseConsulCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
case structs.VaultCAProvider:
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
}
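// For illustration only (hypothetical values, not part of this change): a
// configuration that passes the check above would be
//
//	connect {
//		ca_provider = "consul"
//		ca_config {
//			RotationPeriod = "90h"
//		}
//	}
//
// while e.g. ca_provider = "ec2" would be rejected as an invalid CA provider.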
// ----------------------------------------------------------------
// warnings
//
@ -1007,16 +1081,41 @@ func (b *Builder) serviceVal(v *ServiceDefinition) *structs.ServiceDefinition {
Token: b.stringVal(v.Token),
EnableTagOverride: b.boolVal(v.EnableTagOverride),
Checks: checks,
Connect: b.serviceConnectVal(v.Connect),
}
}
func (b *Builder) boolVal(v *bool) bool {
func (b *Builder) serviceConnectVal(v *ServiceConnect) *structs.ServiceConnect {
if v == nil {
return false
return nil
}
var proxy *structs.ServiceDefinitionConnectProxy
if v.Proxy != nil {
proxy = &structs.ServiceDefinitionConnectProxy{
ExecMode: b.stringVal(v.Proxy.ExecMode),
Command: v.Proxy.Command,
Config: v.Proxy.Config,
}
}
return &structs.ServiceConnect{
Proxy: proxy,
}
}
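// boolValWithDefault dereferences v, falling back to defaultVal when the
// pointer is nil, i.e. when the field was not set in any config source.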
func (b *Builder) boolValWithDefault(v *bool, defaultVal bool) bool {
if v == nil {
return defaultVal
}
return *v
}
func (b *Builder) boolVal(v *bool) bool {
return b.boolValWithDefault(v, false)
}
func (b *Builder) durationVal(name string, v *string) (d time.Duration) {
if v == nil {
return 0

View File

@ -84,6 +84,7 @@ func Parse(data string, format string) (c Config, err error) {
"services",
"services.checks",
"watches",
"service.connect.proxy.config.upstreams",
})
// There is a difference of representation of some fields depending on
@ -159,6 +160,7 @@ type Config struct {
CheckUpdateInterval *string `json:"check_update_interval,omitempty" hcl:"check_update_interval" mapstructure:"check_update_interval"`
Checks []CheckDefinition `json:"checks,omitempty" hcl:"checks" mapstructure:"checks"`
ClientAddr *string `json:"client_addr,omitempty" hcl:"client_addr" mapstructure:"client_addr"`
Connect Connect `json:"connect,omitempty" hcl:"connect" mapstructure:"connect"`
DNS DNS `json:"dns_config,omitempty" hcl:"dns_config" mapstructure:"dns_config"`
DNSDomain *string `json:"domain,omitempty" hcl:"domain" mapstructure:"domain"`
DNSRecursors []string `json:"recursors,omitempty" hcl:"recursors" mapstructure:"recursors"`
@ -324,6 +326,7 @@ type ServiceDefinition struct {
Checks []CheckDefinition `json:"checks,omitempty" hcl:"checks" mapstructure:"checks"`
Token *string `json:"token,omitempty" hcl:"token" mapstructure:"token"`
EnableTagOverride *bool `json:"enable_tag_override,omitempty" hcl:"enable_tag_override" mapstructure:"enable_tag_override"`
Connect *ServiceConnect `json:"connect,omitempty" hcl:"connect" mapstructure:"connect"`
}
type CheckDefinition struct {
@ -349,6 +352,58 @@ type CheckDefinition struct {
DeregisterCriticalServiceAfter *string `json:"deregister_critical_service_after,omitempty" hcl:"deregister_critical_service_after" mapstructure:"deregister_critical_service_after"`
}
// ServiceConnect is the connect block within a service registration
type ServiceConnect struct {
// TODO(banks) add way to specify that the app is connect-native
// Proxy configures a connect proxy instance for the service
Proxy *ServiceConnectProxy `json:"proxy,omitempty" hcl:"proxy" mapstructure:"proxy"`
}
type ServiceConnectProxy struct {
Command []string `json:"command,omitempty" hcl:"command" mapstructure:"command"`
ExecMode *string `json:"exec_mode,omitempty" hcl:"exec_mode" mapstructure:"exec_mode"`
Config map[string]interface{} `json:"config,omitempty" hcl:"config" mapstructure:"config"`
}
// Connect is the agent-global connect configuration.
type Connect struct {
// Enabled opts the agent into connect. It should be set on all clients and
// servers in a cluster for correct connect operation.
Enabled *bool `json:"enabled,omitempty" hcl:"enabled" mapstructure:"enabled"`
Proxy ConnectProxy `json:"proxy,omitempty" hcl:"proxy" mapstructure:"proxy"`
ProxyDefaults ConnectProxyDefaults `json:"proxy_defaults,omitempty" hcl:"proxy_defaults" mapstructure:"proxy_defaults"`
CAProvider *string `json:"ca_provider,omitempty" hcl:"ca_provider" mapstructure:"ca_provider"`
CAConfig map[string]interface{} `json:"ca_config,omitempty" hcl:"ca_config" mapstructure:"ca_config"`
}
// ConnectProxy is the agent-global connect proxy configuration.
type ConnectProxy struct {
// AllowManagedRoot allows Consul to execute managed proxies even when it
// is running as root (EUID 0); by default it refuses to. Enabling this is
// not recommended.
AllowManagedRoot *bool `json:"allow_managed_root" hcl:"allow_managed_root" mapstructure:"allow_managed_root"`
// AllowManagedAPIRegistration enables managed proxy registration
// via the agent HTTP API. If this is false, only file configurations
// can be used.
AllowManagedAPIRegistration *bool `json:"allow_managed_api_registration" hcl:"allow_managed_api_registration" mapstructure:"allow_managed_api_registration"`
}
// ConnectProxyDefaults is the agent-global defaults for managed Connect proxies.
type ConnectProxyDefaults struct {
// ExecMode is used where a registration doesn't include an exec_mode.
// Defaults to daemon.
ExecMode *string `json:"exec_mode,omitempty" hcl:"exec_mode" mapstructure:"exec_mode"`
// DaemonCommand is used to start proxy in exec_mode = daemon if not specified
// at registration time.
DaemonCommand []string `json:"daemon_command,omitempty" hcl:"daemon_command" mapstructure:"daemon_command"`
// ScriptCommand is used to start proxy in exec_mode = script if not specified
// at registration time.
ScriptCommand []string `json:"script_command,omitempty" hcl:"script_command" mapstructure:"script_command"`
// Config is merged into any Config specified at registration time.
Config map[string]interface{} `json:"config,omitempty" hcl:"config" mapstructure:"config"`
}
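// For illustration only (hypothetical registration, not part of this change):
// given agent-global defaults of
//
//	connect {
//		proxy_defaults {
//			exec_mode = "script"
//			config { connect_timeout_ms = 1000 }
//		}
//	}
//
// a service whose managed proxy registers with config { foo = "bar" } runs in
// script mode with both connect_timeout_ms and foo in its final config map.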
type DNS struct {
AllowStale *bool `json:"allow_stale,omitempty" hcl:"allow_stale" mapstructure:"allow_stale"`
ARecordLimit *int `json:"a_record_limit,omitempty" hcl:"a_record_limit" mapstructure:"a_record_limit"`
@ -360,6 +415,7 @@ type DNS struct {
RecursorTimeout *string `json:"recursor_timeout,omitempty" hcl:"recursor_timeout" mapstructure:"recursor_timeout"`
ServiceTTL map[string]string `json:"service_ttl,omitempty" hcl:"service_ttl" mapstructure:"service_ttl"`
UDPAnswerLimit *int `json:"udp_answer_limit,omitempty" hcl:"udp_answer_limit" mapstructure:"udp_answer_limit"`
NodeMetaTXT *bool `json:"enable_additional_node_meta_txt,omitempty" hcl:"enable_additional_node_meta_txt" mapstructure:"enable_additional_node_meta_txt"`
}
type HTTPConfig struct {
@ -399,12 +455,14 @@ type Telemetry struct {
}
type Ports struct {
DNS *int `json:"dns,omitempty" hcl:"dns" mapstructure:"dns"`
HTTP *int `json:"http,omitempty" hcl:"http" mapstructure:"http"`
HTTPS *int `json:"https,omitempty" hcl:"https" mapstructure:"https"`
SerfLAN *int `json:"serf_lan,omitempty" hcl:"serf_lan" mapstructure:"serf_lan"`
SerfWAN *int `json:"serf_wan,omitempty" hcl:"serf_wan" mapstructure:"serf_wan"`
Server *int `json:"server,omitempty" hcl:"server" mapstructure:"server"`
DNS *int `json:"dns,omitempty" hcl:"dns" mapstructure:"dns"`
HTTP *int `json:"http,omitempty" hcl:"http" mapstructure:"http"`
HTTPS *int `json:"https,omitempty" hcl:"https" mapstructure:"https"`
SerfLAN *int `json:"serf_lan,omitempty" hcl:"serf_lan" mapstructure:"serf_lan"`
SerfWAN *int `json:"serf_wan,omitempty" hcl:"serf_wan" mapstructure:"serf_wan"`
Server *int `json:"server,omitempty" hcl:"server" mapstructure:"server"`
ProxyMinPort *int `json:"proxy_min_port,omitempty" hcl:"proxy_min_port" mapstructure:"proxy_min_port"`
ProxyMaxPort *int `json:"proxy_max_port,omitempty" hcl:"proxy_max_port" mapstructure:"proxy_max_port"`
}
type UnixSocket struct {

View File

@ -85,6 +85,8 @@ func DefaultSource() Source {
serf_lan = ` + strconv.Itoa(consul.DefaultLANSerfPort) + `
serf_wan = ` + strconv.Itoa(consul.DefaultWANSerfPort) + `
server = ` + strconv.Itoa(consul.DefaultRPCPort) + `
proxy_min_port = 20000
proxy_max_port = 20255
}
telemetry = {
metrics_prefix = "consul"
@ -108,6 +110,10 @@ func DevSource() Source {
ui = true
log_level = "DEBUG"
server = true
connect = {
enabled = true
}
performance = {
raft_multiplier = 1
}

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
"golang.org/x/time/rate"
@ -281,6 +282,11 @@ type RuntimeConfig struct {
// hcl: dns_config { udp_answer_limit = int }
DNSUDPAnswerLimit int
// DNSNodeMetaTXT controls whether DNS queries will synthesize
// TXT records for the node metadata and add them when not specifically
// requested (i.e. when the query type is not TXT). If unset this defaults to true.
DNSNodeMetaTXT bool
// DNSRecursors can be set to allow the DNS servers to recursively
// resolve non-consul domains.
//
@ -299,177 +305,8 @@ type RuntimeConfig struct {
// hcl: http_config { response_headers = map[string]string }
HTTPResponseHeaders map[string]string
// TelemetryCirconus*: see https://github.com/circonus-labs/circonus-gometrics
// for more details on the various configuration options.
// Valid configuration combinations:
// - CirconusAPIToken
// metric management enabled (search for existing check or create a new one)
// - CirconusSubmissionURL
// metric management disabled (use check with specified submission_url,
// broker must be using a public SSL certificate)
// - CirconusAPIToken + CirconusCheckSubmissionURL
// metric management enabled (use check with specified submission_url)
// - CirconusAPIToken + CirconusCheckID
// metric management enabled (use check with specified id)
// TelemetryCirconusAPIApp is an app name associated with API token.
// Default: "consul"
//
// hcl: telemetry { circonus_api_app = string }
TelemetryCirconusAPIApp string
// TelemetryCirconusAPIToken is a valid API Token used to create/manage check. If provided,
// metric management is enabled.
// Default: none
//
// hcl: telemetry { circonus_api_token = string }
TelemetryCirconusAPIToken string
// TelemetryCirconusAPIURL is the base URL to use for contacting the Circonus API.
// Default: "https://api.circonus.com/v2"
//
// hcl: telemetry { circonus_api_url = string }
TelemetryCirconusAPIURL string
// TelemetryCirconusBrokerID is an explicit broker to use when creating a new check. The numeric portion
// of broker._cid. If metric management is enabled and neither a Submission URL nor Check ID
// is provided, an attempt will be made to search for an existing check using Instance ID and
// Search Tag. If one is not found, a new HTTPTRAP check will be created.
// Default: use Select Tag if provided; otherwise a random Enterprise Broker
// associated with the specified API token, or the default Circonus Broker.
//
// hcl: telemetry { circonus_broker_id = string }
TelemetryCirconusBrokerID string
// TelemetryCirconusBrokerSelectTag is a special tag which will be used to select a broker when
// a Broker ID is not provided. The best use of this is as a hint for which broker
// should be used based on *where* this particular instance is running.
// (e.g. a specific geo location or datacenter, dc:sfo)
// Default: none
//
// hcl: telemetry { circonus_broker_select_tag = string }
TelemetryCirconusBrokerSelectTag string
// TelemetryCirconusCheckDisplayName is the name for the check which will be displayed in the Circonus UI.
// Default: value of CirconusCheckInstanceID
//
// hcl: telemetry { circonus_check_display_name = string }
TelemetryCirconusCheckDisplayName string
// TelemetryCirconusCheckForceMetricActivation will force enabling metrics, as they are encountered,
// if the metric already exists and is NOT active. If check management is enabled, the default
// behavior is to add new metrics as they are encountered. If the metric already exists in the
// check, it will *NOT* be activated. This setting overrides that behavior.
// Default: "false"
//
// hcl: telemetry { circonus_check_force_metric_activation = (true|false) }
TelemetryCirconusCheckForceMetricActivation string
// TelemetryCirconusCheckID is the check id (not check bundle id) from a previously created
// HTTPTRAP check. The numeric portion of the check._cid field.
// Default: none
//
// hcl: telemetry { circonus_check_id = string }
TelemetryCirconusCheckID string
// TelemetryCirconusCheckInstanceID serves to uniquely identify the metrics coming from this "instance".
// It can be used to maintain metric continuity with transient or ephemeral instances as
// they move around within an infrastructure.
// Default: hostname:app
//
// hcl: telemetry { circonus_check_instance_id = string }
TelemetryCirconusCheckInstanceID string
// TelemetryCirconusCheckSearchTag is a special tag which, when coupled with the instance id, helps to
// narrow down the search results when neither a Submission URL nor a Check ID is provided.
// Default: service:app (e.g. service:consul)
//
// hcl: telemetry { circonus_check_search_tag = string }
TelemetryCirconusCheckSearchTag string
// TelemetryCirconusCheckTags is a comma separated list of additional tags to add
// to the check when it is created and metric management is enabled.
// Default: none
//
// hcl: telemetry { circonus_check_tags = string }
TelemetryCirconusCheckTags string
// TelemetryCirconusSubmissionInterval is the interval at which metrics are submitted to Circonus.
// Default: 10s
//
// hcl: telemetry { circonus_submission_interval = "duration" }
TelemetryCirconusSubmissionInterval string
// TelemetryCirconusSubmissionURL is the check.config.submission_url field from a
// previously created HTTPTRAP check.
// Default: none
//
// hcl: telemetry { circonus_submission_url = string }
TelemetryCirconusSubmissionURL string
// TelemetryDisableHostname will disable hostname prefixing for all metrics.
//
// hcl: telemetry { disable_hostname = (true|false) }
TelemetryDisableHostname bool
// TelemetryDogStatsdAddr is the address of a dogstatsd instance. If provided,
// metrics will be sent to that instance
//
// hcl: telemetry { dogstatsd_addr = string }
TelemetryDogstatsdAddr string
// TelemetryDogStatsdTags are the global tags that should be sent with each packet to dogstatsd
// It is a list of strings, where each string looks like "my_tag_name:my_tag_value"
//
// hcl: telemetry { dogstatsd_tags = []string }
TelemetryDogstatsdTags []string
// TelemetryPrometheusRetentionTime is the retention time for Prometheus metrics if greater than 0.
// A value of 0 disables Prometheus support. For Prometheus it is considered good
// practice to use a large value here (such as a few days), and at least the
// interval between Prometheus scrapes.
//
// hcl: telemetry { prometheus_retention_time = "duration" }
TelemetryPrometheusRetentionTime time.Duration
// TelemetryFilterDefault is the default for whether to allow a metric that's not
// covered by the filter.
//
// hcl: telemetry { filter_default = (true|false) }
TelemetryFilterDefault bool
// TelemetryAllowedPrefixes is a list of filter rules to apply for allowing metrics
// by prefix. Use the 'prefix_filter' option and prefix rules with '+' to be
// included.
//
// hcl: telemetry { prefix_filter = []string{"+<expr>", "+<expr>", ...} }
TelemetryAllowedPrefixes []string
// TelemetryBlockedPrefixes is a list of filter rules to apply for blocking metrics
// by prefix. Use the 'prefix_filter' option and prefix rules with '-' to be
// excluded.
//
// hcl: telemetry { prefix_filter = []string{"-<expr>", "-<expr>", ...} }
TelemetryBlockedPrefixes []string
// TelemetryMetricsPrefix is the prefix used to write stats values to.
// Default: "consul."
//
// hcl: telemetry { metrics_prefix = string }
TelemetryMetricsPrefix string
// TelemetryStatsdAddr is the address of a statsd instance. If provided,
// metrics will be sent to that instance.
//
// hcl: telemetry { statsd_addr = string }
TelemetryStatsdAddr string
// TelemetryStatsiteAddr is the address of a statsite instance. If provided,
// metrics will be streamed to that instance.
//
// hcl: telemetry { statsite_addr = string }
TelemetryStatsiteAddr string
// Embed Telemetry Config
Telemetry lib.TelemetryConfig
// Datacenter is the datacenter this node is in. Defaults to "dc1".
//
@ -616,6 +453,61 @@ type RuntimeConfig struct {
// flag: -client string
ClientAddrs []*net.IPAddr
// ConnectEnabled opts the agent into connect. It should be set on all clients
// and servers in a cluster for correct connect operation.
ConnectEnabled bool
// ConnectProxyBindMinPort is the inclusive start of the range of ports
// allocated to the agent for starting proxy listeners when no explicit
// port is specified.
ConnectProxyBindMinPort int
// ConnectProxyBindMaxPort is the inclusive end of the range of ports
// allocated to the agent for starting proxy listeners when no explicit
// port is specified.
ConnectProxyBindMaxPort int
// ConnectProxyAllowManagedRoot is true if Consul can execute managed
// proxies when running as root (EUID == 0).
ConnectProxyAllowManagedRoot bool
// ConnectProxyAllowManagedAPIRegistration enables managed proxy registration
// via the agent HTTP API. If this is false, only file configurations
// can be used.
ConnectProxyAllowManagedAPIRegistration bool
// ConnectProxyDefaultExecMode is used where a registration doesn't include an
// exec_mode. Defaults to daemon.
ConnectProxyDefaultExecMode string
// ConnectProxyDefaultDaemonCommand is used to start proxy in exec_mode =
// daemon if not specified at registration time.
ConnectProxyDefaultDaemonCommand []string
// ConnectProxyDefaultScriptCommand is used to start proxy in exec_mode =
// script if not specified at registration time.
ConnectProxyDefaultScriptCommand []string
// ConnectProxyDefaultConfig is merged with any config specified at
// registration time to allow global control of defaults.
ConnectProxyDefaultConfig map[string]interface{}
// ConnectCAProvider is the type of CA provider to use with Connect.
ConnectCAProvider string
// ConnectCAConfig is the config to use for the CA provider.
ConnectCAConfig map[string]interface{}
// ConnectTestDisableManagedProxies is not exposed to public config but is
// used by TestAgent to prevent self-executing the test binary in the
// background if a managed proxy is created for a test. The only place we
// actually want to test processes really being spun up and managed is in
// `agent/proxy` and it does it at a lower level. Note that this still allows
// registering managed proxies via API and other methods, and still creates
// all the agent state for them, just doesn't actually start external
// processes up.
ConnectTestDisableManagedProxies bool
// DNSAddrs contains the list of TCP and UDP addresses the DNS server will
// bind to. If the DNS endpoint is disabled (ports.dns <= 0) the list is
// empty.

View File

@ -19,10 +19,11 @@ import (
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/testutil"
"github.com/hashicorp/consul/types"
"github.com/pascaldekloe/goe/verify"
"github.com/sergi/go-diff/diffmatchpatch"
"github.com/stretchr/testify/require"
)
type configTest struct {
@ -31,6 +32,7 @@ type configTest struct {
pre, post func()
json, jsontail []string
hcl, hcltail []string
skipformat bool
privatev4 func() ([]*net.IPAddr, error)
publicv6 func() ([]*net.IPAddr, error)
patch func(rt *RuntimeConfig)
@ -263,6 +265,7 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
rt.AdvertiseAddrLAN = ipAddr("127.0.0.1")
rt.AdvertiseAddrWAN = ipAddr("127.0.0.1")
rt.BindAddr = ipAddr("127.0.0.1")
rt.ConnectEnabled = true
rt.DevMode = true
rt.DisableAnonymousSignature = true
rt.DisableKeyringFile = true
@ -1850,8 +1853,8 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
`},
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.TelemetryAllowedPrefixes = []string{"foo"}
rt.TelemetryBlockedPrefixes = []string{"bar"}
rt.Telemetry.AllowedPrefixes = []string{"foo"}
rt.Telemetry.BlockedPrefixes = []string{"bar"}
},
warns: []string{`Filter rule must begin with either '+' or '-': "nix"`},
},
@ -2068,6 +2071,127 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
rt.DataDir = dataDir
},
},
{
desc: "HCL service managed proxy 'upstreams'",
args: []string{
`-data-dir=` + dataDir,
},
hcl: []string{
`service {
name = "web"
port = 8080
connect {
proxy {
config {
upstreams {
local_bind_port = 1234
}
}
}
}
}`,
},
skipformat: true, // skip the JSON variant: encoding/json decodes numbers as float64, so the expected types differ slightly (okay)
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.Services = []*structs.ServiceDefinition{
&structs.ServiceDefinition{
Name: "web",
Port: 8080,
Connect: &structs.ServiceConnect{
Proxy: &structs.ServiceDefinitionConnectProxy{
Config: map[string]interface{}{
"upstreams": []map[string]interface{}{
map[string]interface{}{
"local_bind_port": 1234,
},
},
},
},
},
},
}
},
},
{
desc: "JSON service managed proxy 'upstreams'",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{
`{
"service": {
"name": "web",
"port": 8080,
"connect": {
"proxy": {
"config": {
"upstreams": [{
"local_bind_port": 1234
}]
}
}
}
}
}`,
},
skipformat: true, // skip the HCL variant: HCL decodes these numbers as int rather than float64, so the expected types differ slightly (okay)
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.Services = []*structs.ServiceDefinition{
&structs.ServiceDefinition{
Name: "web",
Port: 8080,
Connect: &structs.ServiceConnect{
Proxy: &structs.ServiceDefinitionConnectProxy{
Config: map[string]interface{}{
"upstreams": []interface{}{
map[string]interface{}{
"local_bind_port": float64(1234),
},
},
},
},
},
},
}
},
},
{
desc: "enabling Connect allow_managed_root",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{
`{ "connect": { "proxy": { "allow_managed_root": true } } }`,
},
hcl: []string{
`connect { proxy { allow_managed_root = true } }`,
},
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.ConnectProxyAllowManagedRoot = true
},
},
{
desc: "enabling Connect allow_managed_api_registration",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{
`{ "connect": { "proxy": { "allow_managed_api_registration": true } } }`,
},
hcl: []string{
`connect { proxy { allow_managed_api_registration = true } }`,
},
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.ConnectProxyAllowManagedAPIRegistration = true
},
},
}
testConfig(t, tests, dataDir)
@ -2089,7 +2213,7 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
// json and hcl sources need to be in sync
// to make sure we're generating the same config
if len(tt.json) != len(tt.hcl) {
if len(tt.json) != len(tt.hcl) && !tt.skipformat {
t.Fatal(tt.desc, ": JSON and HCL test case out of sync")
}
@ -2099,6 +2223,12 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
srcs, tails = tt.hcl, tt.hcltail
}
// If we're skipping a format and the current format is empty,
// then skip it!
if tt.skipformat && len(srcs) == 0 {
continue
}
// build the description
var desc []string
if !flagsOnly {
@ -2353,6 +2483,23 @@ func TestFullConfig(t *testing.T) {
],
"check_update_interval": "16507s",
"client_addr": "93.83.18.19",
"connect": {
"ca_provider": "consul",
"ca_config": {
"RotationPeriod": "90h"
},
"enabled": true,
"proxy_defaults": {
"exec_mode": "script",
"daemon_command": ["consul", "connect", "proxy"],
"script_command": ["proxyctl.sh"],
"config": {
"foo": "bar",
"connect_timeout_ms": 1000,
"pedantic_mode": true
}
}
},
"data_dir": "` + dataDir + `",
"datacenter": "rzo029wg",
"disable_anonymous_signature": true,
@ -2417,7 +2564,9 @@ func TestFullConfig(t *testing.T) {
"dns": 7001,
"http": 7999,
"https": 15127,
"server": 3757
"server": 3757,
"proxy_min_port": 2000,
"proxy_max_port": 3000
},
"protocol": 30793,
"raft_protocol": 19016,
@ -2613,7 +2762,16 @@ func TestFullConfig(t *testing.T) {
"ttl": "11222s",
"deregister_critical_service_after": "68482s"
}
]
],
"connect": {
"proxy": {
"exec_mode": "daemon",
"command": ["awesome-proxy"],
"config": {
"foo": "qux"
}
}
}
}
],
"session_ttl_min": "26627s",
@ -2786,6 +2944,25 @@ func TestFullConfig(t *testing.T) {
]
check_update_interval = "16507s"
client_addr = "93.83.18.19"
connect {
ca_provider = "consul"
ca_config {
"RotationPeriod" = "90h"
}
enabled = true
proxy_defaults {
exec_mode = "script"
daemon_command = ["consul", "connect", "proxy"]
script_command = ["proxyctl.sh"]
config = {
foo = "bar"
# hack float since json parses numbers as float and we have to
# assert against the same thing
connect_timeout_ms = 1000.0
pedantic_mode = true
}
}
}
data_dir = "` + dataDir + `"
datacenter = "rzo029wg"
disable_anonymous_signature = true
@ -2851,6 +3028,8 @@ func TestFullConfig(t *testing.T) {
http = 7999,
https = 15127
server = 3757
proxy_min_port = 2000
proxy_max_port = 3000
}
protocol = 30793
raft_protocol = 19016
@ -3047,6 +3226,15 @@ func TestFullConfig(t *testing.T) {
deregister_critical_service_after = "68482s"
}
]
connect {
proxy {
exec_mode = "daemon"
command = ["awesome-proxy"]
config = {
foo = "qux"
}
}
}
}
]
session_ttl_min = "26627s"
@ -3355,8 +3543,25 @@ func TestFullConfig(t *testing.T) {
DeregisterCriticalServiceAfter: 13209 * time.Second,
},
},
CheckUpdateInterval: 16507 * time.Second,
ClientAddrs: []*net.IPAddr{ipAddr("93.83.18.19")},
CheckUpdateInterval: 16507 * time.Second,
ClientAddrs: []*net.IPAddr{ipAddr("93.83.18.19")},
ConnectEnabled: true,
ConnectProxyBindMinPort: 2000,
ConnectProxyBindMaxPort: 3000,
ConnectCAProvider: "consul",
ConnectCAConfig: map[string]interface{}{
"RotationPeriod": "90h",
},
ConnectProxyAllowManagedRoot: false,
ConnectProxyAllowManagedAPIRegistration: false,
ConnectProxyDefaultExecMode: "script",
ConnectProxyDefaultDaemonCommand: []string{"consul", "connect", "proxy"},
ConnectProxyDefaultScriptCommand: []string{"proxyctl.sh"},
ConnectProxyDefaultConfig: map[string]interface{}{
"foo": "bar",
"connect_timeout_ms": float64(1000),
"pedantic_mode": true,
},
DNSAddrs: []net.Addr{tcpAddr("93.95.95.81:7001"), udpAddr("93.95.95.81:7001")},
DNSARecordLimit: 29907,
DNSAllowStale: true,
@ -3371,6 +3576,7 @@ func TestFullConfig(t *testing.T) {
DNSRecursors: []string{"63.38.39.58", "92.49.18.18"},
DNSServiceTTL: map[string]time.Duration{"*": 32030 * time.Second},
DNSUDPAnswerLimit: 29909,
DNSNodeMetaTXT: true,
DataDir: dataDir,
Datacenter: "rzo029wg",
DevMode: true,
@ -3529,6 +3735,15 @@ func TestFullConfig(t *testing.T) {
DeregisterCriticalServiceAfter: 68482 * time.Second,
},
},
Connect: &structs.ServiceConnect{
Proxy: &structs.ServiceDefinitionConnectProxy{
ExecMode: "daemon",
Command: []string{"awesome-proxy"},
Config: map[string]interface{}{
"foo": "qux",
},
},
},
},
{
ID: "dLOXpSCI",
@ -3606,41 +3821,43 @@ func TestFullConfig(t *testing.T) {
},
},
},
SerfAdvertiseAddrLAN: tcpAddr("17.99.29.16:8301"),
SerfAdvertiseAddrWAN: tcpAddr("78.63.37.19:8302"),
SerfBindAddrLAN: tcpAddr("99.43.63.15:8301"),
SerfBindAddrWAN: tcpAddr("67.88.33.19:8302"),
SessionTTLMin: 26627 * time.Second,
SkipLeaveOnInt: true,
StartJoinAddrsLAN: []string{"LR3hGDoG", "MwVpZ4Up"},
StartJoinAddrsWAN: []string{"EbFSc3nA", "kwXTh623"},
SyslogFacility: "hHv79Uia",
TelemetryCirconusAPIApp: "p4QOTe9j",
TelemetryCirconusAPIToken: "E3j35V23",
TelemetryCirconusAPIURL: "mEMjHpGg",
TelemetryCirconusBrokerID: "BHlxUhed",
TelemetryCirconusBrokerSelectTag: "13xy1gHm",
TelemetryCirconusCheckDisplayName: "DRSlQR6n",
TelemetryCirconusCheckForceMetricActivation: "Ua5FGVYf",
TelemetryCirconusCheckID: "kGorutad",
TelemetryCirconusCheckInstanceID: "rwoOL6R4",
TelemetryCirconusCheckSearchTag: "ovT4hT4f",
TelemetryCirconusCheckTags: "prvO4uBl",
TelemetryCirconusSubmissionInterval: "DolzaflP",
TelemetryCirconusSubmissionURL: "gTcbS93G",
TelemetryDisableHostname: true,
TelemetryDogstatsdAddr: "0wSndumK",
TelemetryDogstatsdTags: []string{"3N81zSUB", "Xtj8AnXZ"},
TelemetryFilterDefault: true,
TelemetryAllowedPrefixes: []string{"oJotS8XJ"},
TelemetryBlockedPrefixes: []string{"cazlEhGn"},
TelemetryMetricsPrefix: "ftO6DySn",
TelemetryPrometheusRetentionTime: 15 * time.Second,
TelemetryStatsdAddr: "drce87cy",
TelemetryStatsiteAddr: "HpFwKB8R",
TLSCipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
TLSMinVersion: "pAOWafkR",
TLSPreferServerCipherSuites: true,
SerfAdvertiseAddrLAN: tcpAddr("17.99.29.16:8301"),
SerfAdvertiseAddrWAN: tcpAddr("78.63.37.19:8302"),
SerfBindAddrLAN: tcpAddr("99.43.63.15:8301"),
SerfBindAddrWAN: tcpAddr("67.88.33.19:8302"),
SessionTTLMin: 26627 * time.Second,
SkipLeaveOnInt: true,
StartJoinAddrsLAN: []string{"LR3hGDoG", "MwVpZ4Up"},
StartJoinAddrsWAN: []string{"EbFSc3nA", "kwXTh623"},
SyslogFacility: "hHv79Uia",
Telemetry: lib.TelemetryConfig{
CirconusAPIApp: "p4QOTe9j",
CirconusAPIToken: "E3j35V23",
CirconusAPIURL: "mEMjHpGg",
CirconusBrokerID: "BHlxUhed",
CirconusBrokerSelectTag: "13xy1gHm",
CirconusCheckDisplayName: "DRSlQR6n",
CirconusCheckForceMetricActivation: "Ua5FGVYf",
CirconusCheckID: "kGorutad",
CirconusCheckInstanceID: "rwoOL6R4",
CirconusCheckSearchTag: "ovT4hT4f",
CirconusCheckTags: "prvO4uBl",
CirconusSubmissionInterval: "DolzaflP",
CirconusSubmissionURL: "gTcbS93G",
DisableHostname: true,
DogstatsdAddr: "0wSndumK",
DogstatsdTags: []string{"3N81zSUB", "Xtj8AnXZ"},
FilterDefault: true,
AllowedPrefixes: []string{"oJotS8XJ"},
BlockedPrefixes: []string{"cazlEhGn"},
MetricsPrefix: "ftO6DySn",
PrometheusRetentionTime: 15 * time.Second,
StatsdAddr: "drce87cy",
StatsiteAddr: "HpFwKB8R",
},
TLSCipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
TLSMinVersion: "pAOWafkR",
TLSPreferServerCipherSuites: true,
TaggedAddresses: map[string]string{
"7MYgHrYH": "dALJAhLD",
"h6DdBy6K": "ebrr9zZ8",
@ -4018,6 +4235,18 @@ func TestSanitize(t *testing.T) {
}
],
"ClientAddrs": [],
"ConnectCAConfig": {},
"ConnectCAProvider": "",
"ConnectEnabled": false,
"ConnectProxyAllowManagedAPIRegistration": false,
"ConnectProxyAllowManagedRoot": false,
"ConnectProxyBindMaxPort": 0,
"ConnectProxyBindMinPort": 0,
"ConnectProxyDefaultConfig": {},
"ConnectProxyDefaultDaemonCommand": [],
"ConnectProxyDefaultExecMode": "",
"ConnectProxyDefaultScriptCommand": [],
"ConnectTestDisableManagedProxies": false,
"ConsulCoordinateUpdateBatchSize": 0,
"ConsulCoordinateUpdateMaxBatches": 0,
"ConsulCoordinateUpdatePeriod": "15s",
@ -4043,6 +4272,7 @@ func TestSanitize(t *testing.T) {
"DNSDomain": "",
"DNSEnableTruncate": false,
"DNSMaxStale": "0s",
"DNSNodeMetaTXT": false,
"DNSNodeTTL": "0s",
"DNSOnlyPassing": false,
"DNSPort": 0,
@ -4148,11 +4378,14 @@ func TestSanitize(t *testing.T) {
"Timeout": "0s"
},
"Checks": [],
"Connect": null,
"EnableTagOverride": false,
"ID": "",
"Kind": "",
"Meta": {},
"Name": "foo",
"Port": 0,
"ProxyDestination": "",
"Tags": [],
"Token": "hidden"
}
@ -4168,29 +4401,31 @@ func TestSanitize(t *testing.T) {
"TLSMinVersion": "",
"TLSPreferServerCipherSuites": false,
"TaggedAddresses": {},
"TelemetryAllowedPrefixes": [],
"TelemetryBlockedPrefixes": [],
"TelemetryCirconusAPIApp": "",
"TelemetryCirconusAPIToken": "hidden",
"TelemetryCirconusAPIURL": "",
"TelemetryCirconusBrokerID": "",
"TelemetryCirconusBrokerSelectTag": "",
"TelemetryCirconusCheckDisplayName": "",
"TelemetryCirconusCheckForceMetricActivation": "",
"TelemetryCirconusCheckID": "",
"TelemetryCirconusCheckInstanceID": "",
"TelemetryCirconusCheckSearchTag": "",
"TelemetryCirconusCheckTags": "",
"TelemetryCirconusSubmissionInterval": "",
"TelemetryCirconusSubmissionURL": "",
"TelemetryDisableHostname": false,
"TelemetryDogstatsdAddr": "",
"TelemetryDogstatsdTags": [],
"TelemetryFilterDefault": false,
"TelemetryMetricsPrefix": "",
"TelemetryPrometheusRetentionTime": "0s",
"TelemetryStatsdAddr": "",
"TelemetryStatsiteAddr": "",
"Telemetry":{
"AllowedPrefixes": [],
"BlockedPrefixes": [],
"CirconusAPIApp": "",
"CirconusAPIToken": "hidden",
"CirconusAPIURL": "",
"CirconusBrokerID": "",
"CirconusBrokerSelectTag": "",
"CirconusCheckDisplayName": "",
"CirconusCheckForceMetricActivation": "",
"CirconusCheckID": "",
"CirconusCheckInstanceID": "",
"CirconusCheckSearchTag": "",
"CirconusCheckTags": "",
"CirconusSubmissionInterval": "",
"CirconusSubmissionURL": "",
"DisableHostname": false,
"DogstatsdAddr": "",
"DogstatsdTags": [],
"FilterDefault": false,
"MetricsPrefix": "",
"PrometheusRetentionTime": "0s",
"StatsdAddr": "",
"StatsiteAddr": ""
},
"TranslateWANAddrs": false,
"UIDir": "",
"UnixSocketGroup": "",
@ -4210,11 +4445,7 @@ func TestSanitize(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if got, want := string(b), rtJSON; got != want {
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(want, got, false)
t.Fatal(dmp.DiffPrettyText(diffs))
}
require.JSONEq(t, rtJSON, string(b))
}
func splitIPPort(hostport string) (net.IP, int) {

View File

@ -0,0 +1,49 @@
package ca
import (
"crypto/x509"
)
// Provider is the interface for Consul to interact with
// an external CA that provides leaf certificate signing for
// given SpiffeIDServices.
type Provider interface {
// ActiveRoot returns the currently active root CA for this
// provider. This should be a parent of the certificate returned by
// ActiveIntermediate().
ActiveRoot() (string, error)
// ActiveIntermediate returns the current signing cert used by this provider
// for generating SPIFFE leaf certs. Note that this must not change except
// when Consul requests the change via GenerateIntermediate. Changing the
// signing cert will break Consul's assumptions about which validation paths
// are active.
ActiveIntermediate() (string, error)
// GenerateIntermediate returns a new intermediate signing cert and sets it to
// the active intermediate. If multiple intermediates are needed to complete
// the chain from the signing certificate back to the active root, they should
// all be bundled here.
GenerateIntermediate() (string, error)
// Sign signs a leaf certificate used by Connect proxies from a CSR. The PEM
// returned should include only the leaf certificate as all Intermediates
// needed to validate it will be added by Consul based on the active
// intermediate and any cross-signed intermediates managed by Consul.
Sign(*x509.CertificateRequest) (string, error)
// CrossSignCA must accept a CA certificate from another CA provider
// and cross-sign it exactly as it is such that it forms a chain back to the
// CAProvider's current root. Specifically, the Distinguished Name, Subject
// Alternative Name, SubjectKeyID and other relevant extensions must be kept.
// The resulting certificate must have a distinct Serial Number and the
// AuthorityKeyID set to the CAProvider's current signing key as well as the
// Issuer related fields changed as necessary. The resulting certificate is
// returned as a PEM formatted string.
CrossSignCA(*x509.Certificate) (string, error)
// Cleanup performs any necessary cleanup that should happen when the provider
// is shut down permanently, such as removing a temporary PKI backend in Vault
// created for an intermediate CA.
Cleanup() error
}
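// An illustrative sketch (not part of this change) of a minimal type
// satisfying Provider: a fixed single-certificate CA that refuses to sign.
// All names below are hypothetical; a real implementation would also need
// the crypto/x509 and fmt imports.
//
//	type staticProvider struct{ caPEM string }
//
//	func (p *staticProvider) ActiveRoot() (string, error)           { return p.caPEM, nil }
//	func (p *staticProvider) ActiveIntermediate() (string, error)   { return p.caPEM, nil }
//	func (p *staticProvider) GenerateIntermediate() (string, error) { return p.caPEM, nil }
//	func (p *staticProvider) Sign(*x509.CertificateRequest) (string, error) {
//		return "", fmt.Errorf("static CA cannot sign leaf certificates")
//	}
//	func (p *staticProvider) CrossSignCA(*x509.Certificate) (string, error) {
//		return "", fmt.Errorf("static CA cannot cross-sign")
//	}
//	func (p *staticProvider) Cleanup() error { return nil }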

View File

@ -0,0 +1,379 @@
package ca
import (
"bytes"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/url"
"sync"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
)
type ConsulProvider struct {
config *structs.ConsulCAProviderConfig
id string
delegate ConsulProviderStateDelegate
sync.RWMutex
}
type ConsulProviderStateDelegate interface {
State() *state.Store
ApplyCARequest(*structs.CARequest) error
}
// NewConsulProvider returns a new instance of the Consul CA provider,
// bootstrapping its state in the state store as necessary.
func NewConsulProvider(rawConfig map[string]interface{}, delegate ConsulProviderStateDelegate) (*ConsulProvider, error) {
conf, err := ParseConsulCAConfig(rawConfig)
if err != nil {
return nil, err
}
provider := &ConsulProvider{
config: conf,
delegate: delegate,
id: fmt.Sprintf("%s,%s", conf.PrivateKey, conf.RootCert),
}
// Check if this configuration of the provider has already been
// initialized in the state store.
state := delegate.State()
_, providerState, err := state.CAProviderState(provider.id)
if err != nil {
return nil, err
}
// Exit early if the state store has already been populated for this config.
if providerState != nil {
return provider, nil
}
newState := structs.CAConsulProviderState{
ID: provider.id,
}
// Write the initial provider state to get the index to use for the
// CA serial number.
{
args := &structs.CARequest{
Op: structs.CAOpSetProviderState,
ProviderState: &newState,
}
if err := delegate.ApplyCARequest(args); err != nil {
return nil, err
}
}
idx, _, err := state.CAProviderState(provider.id)
if err != nil {
return nil, err
}
// Generate a private key if needed
if conf.PrivateKey == "" {
_, pk, err := connect.GeneratePrivateKey()
if err != nil {
return nil, err
}
newState.PrivateKey = pk
} else {
newState.PrivateKey = conf.PrivateKey
}
// Generate the root CA if necessary
if conf.RootCert == "" {
ca, err := provider.generateCA(newState.PrivateKey, idx+1)
if err != nil {
return nil, fmt.Errorf("error generating CA: %v", err)
}
newState.RootCert = ca
} else {
newState.RootCert = conf.RootCert
}
// Write the provider state
args := &structs.CARequest{
Op: structs.CAOpSetProviderState,
ProviderState: &newState,
}
if err := delegate.ApplyCARequest(args); err != nil {
return nil, err
}
return provider, nil
}
// ActiveRoot returns the currently active root CA certificate.
func (c *ConsulProvider) ActiveRoot() (string, error) {
state := c.delegate.State()
_, providerState, err := state.CAProviderState(c.id)
if err != nil {
return "", err
}
return providerState.RootCert, nil
}
// ActiveIntermediate returns the active root: we aren't maintaining
// separate root/intermediate CAs for the builtin provider.
func (c *ConsulProvider) ActiveIntermediate() (string, error) {
return c.ActiveRoot()
}
// GenerateIntermediate returns the active intermediate, which is the root:
// we aren't maintaining separate root/intermediate CAs for the builtin provider.
func (c *ConsulProvider) GenerateIntermediate() (string, error) {
return c.ActiveIntermediate()
}
// Cleanup removes the state store entry for this provider instance.
func (c *ConsulProvider) Cleanup() error {
args := &structs.CARequest{
Op: structs.CAOpDeleteProviderState,
ProviderState: &structs.CAConsulProviderState{ID: c.id},
}
if err := c.delegate.ApplyCARequest(args); err != nil {
return err
}
return nil
}
// Sign returns a new certificate valid for the given SpiffeIDService
// using the current CA.
func (c *ConsulProvider) Sign(csr *x509.CertificateRequest) (string, error) {
// Lock during the signing so we don't use the same index twice
// for different cert serial numbers.
c.Lock()
defer c.Unlock()
// Get the provider state
state := c.delegate.State()
idx, providerState, err := state.CAProviderState(c.id)
if err != nil {
return "", err
}
// Create the keyId for the cert from the signing private key.
signer, err := connect.ParseSigner(providerState.PrivateKey)
if err != nil {
return "", err
}
if signer == nil {
return "", fmt.Errorf("error signing cert: Consul CA not initialized yet")
}
keyId, err := connect.KeyId(signer.Public())
if err != nil {
return "", err
}
// Parse the SPIFFE ID
spiffeId, err := connect.ParseCertURI(csr.URIs[0])
if err != nil {
return "", err
}
serviceId, ok := spiffeId.(*connect.SpiffeIDService)
if !ok {
return "", fmt.Errorf("SPIFFE ID in CSR must be a service ID")
}
// Parse the CA cert
caCert, err := connect.ParseCert(providerState.RootCert)
if err != nil {
return "", fmt.Errorf("error parsing CA cert: %s", err)
}
// Cert template for generation
sn := &big.Int{}
sn.SetUint64(idx + 1)
// Sign the certificate valid from 1 minute in the past, this helps it be
// accepted right away even when nodes are not in close time sync across the
// cluster. A minute is more than enough for typical DC clock drift.
effectiveNow := time.Now().Add(-1 * time.Minute)
template := x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{CommonName: serviceId.Service},
URIs: csr.URIs,
Signature: csr.Signature,
SignatureAlgorithm: csr.SignatureAlgorithm,
PublicKeyAlgorithm: csr.PublicKeyAlgorithm,
PublicKey: csr.PublicKey,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageDataEncipherment |
x509.KeyUsageKeyAgreement |
x509.KeyUsageDigitalSignature |
x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
// todo(kyhavlov): add a way to set the cert lifetime here from the CA config
NotAfter: effectiveNow.Add(3 * 24 * time.Hour),
NotBefore: effectiveNow,
AuthorityKeyId: keyId,
SubjectKeyId: keyId,
}
// Create the certificate, PEM encode it and return that value.
var buf bytes.Buffer
bs, err := x509.CreateCertificate(
rand.Reader, &template, caCert, csr.PublicKey, signer)
if err != nil {
return "", fmt.Errorf("error generating certificate: %s", err)
}
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
return "", fmt.Errorf("error encoding certificate: %s", err)
}
err = c.incrementProviderIndex(providerState)
if err != nil {
return "", err
}
// Set the response
return buf.String(), nil
}
// CrossSignCA returns the given CA cert signed by the current active root.
func (c *ConsulProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
c.Lock()
defer c.Unlock()
// Get the provider state
state := c.delegate.State()
idx, providerState, err := state.CAProviderState(c.id)
if err != nil {
return "", err
}
privKey, err := connect.ParseSigner(providerState.PrivateKey)
if err != nil {
return "", fmt.Errorf("error parsing private key %q: %s", providerState.PrivateKey, err)
}
rootCA, err := connect.ParseCert(providerState.RootCert)
if err != nil {
return "", err
}
keyId, err := connect.KeyId(privKey.Public())
if err != nil {
return "", err
}
// Create the cross-signing template from the existing root CA
serialNum := &big.Int{}
serialNum.SetUint64(idx + 1)
template := *cert
template.SerialNumber = serialNum
template.SignatureAlgorithm = rootCA.SignatureAlgorithm
template.AuthorityKeyId = keyId
// Sign the certificate valid from 1 minute in the past, this helps it be
// accepted right away even when nodes are not in close time sync across the
// cluster. A minute is more than enough for typical DC clock drift.
effectiveNow := time.Now().Add(-1 * time.Minute)
template.NotBefore = effectiveNow
// This cross-signed cert is only needed during rotation, and only while old
// leaf certs are still in use. They expire within 3 days currently so 7 is
// safe. TODO(banks): make this be based on leaf expiry time when that is
// configurable.
template.NotAfter = effectiveNow.Add(7 * 24 * time.Hour)
bs, err := x509.CreateCertificate(
rand.Reader, &template, rootCA, cert.PublicKey, privKey)
if err != nil {
return "", fmt.Errorf("error generating CA certificate: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
return "", fmt.Errorf("error encoding private key: %s", err)
}
err = c.incrementProviderIndex(providerState)
if err != nil {
return "", err
}
return buf.String(), nil
}
// incrementProviderIndex does a write to increment the provider state store table index
// used for serial numbers when generating certificates.
func (c *ConsulProvider) incrementProviderIndex(providerState *structs.CAConsulProviderState) error {
newState := *providerState
args := &structs.CARequest{
Op: structs.CAOpSetProviderState,
ProviderState: &newState,
}
if err := c.delegate.ApplyCARequest(args); err != nil {
return err
}
return nil
}
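// For illustration (hypothetical numbers): if the provider state row sits at
// index n when Sign runs, the issued certificate gets serial number n+1, and
// this write advances the row's index so a later call observes a larger n.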
// generateCA makes a new root CA using the current private key
func (c *ConsulProvider) generateCA(privateKey string, sn uint64) (string, error) {
state := c.delegate.State()
_, config, err := state.CAConfig()
if err != nil {
return "", err
}
privKey, err := connect.ParseSigner(privateKey)
if err != nil {
return "", fmt.Errorf("error parsing private key %q: %s", privateKey, err)
}
name := fmt.Sprintf("Consul CA %d", sn)
// The URI (SPIFFE compatible) for the cert
id := connect.SpiffeIDSigningForCluster(config)
keyId, err := connect.KeyId(privKey.Public())
if err != nil {
return "", err
}
// Create the CA cert
serialNum := &big.Int{}
serialNum.SetUint64(sn)
template := x509.Certificate{
SerialNumber: serialNum,
Subject: pkix.Name{CommonName: name},
URIs: []*url.URL{id.URI()},
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageCertSign |
x509.KeyUsageCRLSign |
x509.KeyUsageDigitalSignature,
IsCA: true,
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
NotBefore: time.Now(),
AuthorityKeyId: keyId,
SubjectKeyId: keyId,
}
bs, err := x509.CreateCertificate(
rand.Reader, &template, &template, privKey.Public(), privKey)
if err != nil {
return "", fmt.Errorf("error generating CA certificate: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
return "", fmt.Errorf("error encoding private key: %s", err)
}
return buf.String(), nil
}

View File

@ -0,0 +1,77 @@
package ca
import (
"fmt"
"reflect"
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/mitchellh/mapstructure"
)
func ParseConsulCAConfig(raw map[string]interface{}) (*structs.ConsulCAProviderConfig, error) {
var config structs.ConsulCAProviderConfig
decodeConf := &mapstructure.DecoderConfig{
DecodeHook: ParseDurationFunc(),
ErrorUnused: true,
Result: &config,
WeaklyTypedInput: true,
}
decoder, err := mapstructure.NewDecoder(decodeConf)
if err != nil {
return nil, err
}
if err := decoder.Decode(raw); err != nil {
return nil, fmt.Errorf("error decoding config: %s", err)
}
if config.PrivateKey == "" && config.RootCert != "" {
return nil, fmt.Errorf("must provide a private key when providing a root cert")
}
return &config, nil
}
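// For illustration only (keyPEM and certPEM are hypothetical placeholders):
// an empty map decodes to a config whose key and root cert Consul generates
// itself, while
//
//	ParseConsulCAConfig(map[string]interface{}{
//		"PrivateKey":     keyPEM,
//		"RootCert":       certPEM,
//		"RotationPeriod": "90h",
//	})
//
// supplies both explicitly; RootCert without PrivateKey is rejected by the
// check above.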
// ParseDurationFunc is a mapstructure hook for decoding a string or
// []uint8 into a time.Duration value.
func ParseDurationFunc() mapstructure.DecodeHookFunc {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
var v time.Duration
if t != reflect.TypeOf(v) {
return data, nil
}
switch {
case f.Kind() == reflect.String:
if dur, err := time.ParseDuration(data.(string)); err != nil {
return nil, err
} else {
v = dur
}
return v, nil
case f == reflect.SliceOf(reflect.TypeOf(uint8(0))):
s := Uint8ToString(data.([]uint8))
if dur, err := time.ParseDuration(s); err != nil {
return nil, err
} else {
v = dur
}
return v, nil
default:
return data, nil
}
}
}
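// An illustrative usage sketch (not part of this change); exampleConf and
// decodeExample are hypothetical names:
//
//	type exampleConf struct {
//		RotationPeriod time.Duration
//	}
//
//	func decodeExample(raw map[string]interface{}) (*exampleConf, error) {
//		var out exampleConf
//		dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
//			DecodeHook: ParseDurationFunc(),
//			Result:     &out,
//		})
//		if err != nil {
//			return nil, err
//		}
//		// A "RotationPeriod" of "90h" (or its []uint8 form) becomes 90 * time.Hour.
//		if err := dec.Decode(raw); err != nil {
//			return nil, err
//		}
//		return &out, nil
//	}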
func Uint8ToString(bs []uint8) string {
b := make([]byte, len(bs))
for i, v := range bs {
b[i] = byte(v)
}
return string(b)
}

View File

@ -0,0 +1,266 @@
package ca
import (
"crypto/x509"
"fmt"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type consulCAMockDelegate struct {
state *state.Store
}
func (c *consulCAMockDelegate) State() *state.Store {
return c.state
}
func (c *consulCAMockDelegate) ApplyCARequest(req *structs.CARequest) error {
idx, _, err := c.state.CAConfig()
if err != nil {
return err
}
switch req.Op {
case structs.CAOpSetProviderState:
_, err := c.state.CASetProviderState(idx+1, req.ProviderState)
if err != nil {
return err
}
return nil
case structs.CAOpDeleteProviderState:
if err := c.state.CADeleteProviderState(req.ProviderState.ID); err != nil {
return err
}
return nil
default:
return fmt.Errorf("Invalid CA operation '%s'", req.Op)
}
}
func newMockDelegate(t *testing.T, conf *structs.CAConfiguration) *consulCAMockDelegate {
s, err := state.NewStateStore(nil)
if err != nil {
t.Fatalf("err: %s", err)
}
if s == nil {
t.Fatalf("missing state store")
}
if err := s.CASetConfig(conf.RaftIndex.CreateIndex, conf); err != nil {
t.Fatalf("err: %s", err)
}
return &consulCAMockDelegate{s}
}
func testConsulCAConfig() *structs.CAConfiguration {
return &structs.CAConfiguration{
ClusterID: "asdf",
Provider: "consul",
Config: map[string]interface{}{},
}
}
func TestConsulCAProvider_Bootstrap(t *testing.T) {
t.Parallel()
assert := assert.New(t)
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
provider, err := NewConsulProvider(conf.Config, delegate)
assert.NoError(err)
root, err := provider.ActiveRoot()
assert.NoError(err)
// Intermediate should be the same cert.
inter, err := provider.ActiveIntermediate()
assert.NoError(err)
assert.Equal(root, inter)
// Should be a valid cert
parsed, err := connect.ParseCert(root)
assert.NoError(err)
assert.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", conf.ClusterID))
}
func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) {
t.Parallel()
// Make sure setting a custom private key/root cert works.
assert := assert.New(t)
rootCA := connect.TestCA(t, nil)
conf := testConsulCAConfig()
conf.Config = map[string]interface{}{
"PrivateKey": rootCA.SigningKey,
"RootCert": rootCA.RootCert,
}
delegate := newMockDelegate(t, conf)
provider, err := NewConsulProvider(conf.Config, delegate)
assert.NoError(err)
root, err := provider.ActiveRoot()
assert.NoError(err)
assert.Equal(root, rootCA.RootCert)
}
func TestConsulCAProvider_SignLeaf(t *testing.T) {
t.Parallel()
assert := assert.New(t)
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
provider, err := NewConsulProvider(conf.Config, delegate)
assert.NoError(err)
spiffeService := &connect.SpiffeIDService{
Host: "node1",
Namespace: "default",
Datacenter: "dc1",
Service: "foo",
}
// Generate a leaf cert for the service.
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
assert.NoError(err)
cert, err := provider.Sign(csr)
assert.NoError(err)
parsed, err := connect.ParseCert(cert)
assert.NoError(err)
assert.Equal(parsed.URIs[0], spiffeService.URI())
assert.Equal(parsed.Subject.CommonName, "foo")
assert.Equal(uint64(2), parsed.SerialNumber.Uint64())
// Ensure the cert is valid now and expires within the correct limit.
assert.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
assert.True(parsed.NotBefore.Before(time.Now()))
}
// Generate a new cert for another service and make sure
// it is issued and parses correctly as well.
spiffeService.Service = "bar"
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
assert.NoError(err)
cert, err := provider.Sign(csr)
assert.NoError(err)
parsed, err := connect.ParseCert(cert)
assert.NoError(err)
assert.Equal(parsed.URIs[0], spiffeService.URI())
assert.Equal(parsed.Subject.CommonName, "bar")
assert.Equal(parsed.SerialNumber.Uint64(), uint64(2))
// Ensure the cert is valid now and expires within the correct limit.
assert.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
assert.True(parsed.NotBefore.Before(time.Now()))
}
}
func TestConsulCAProvider_CrossSignCA(t *testing.T) {
t.Parallel()
conf1 := testConsulCAConfig()
delegate1 := newMockDelegate(t, conf1)
provider1, err := NewConsulProvider(conf1.Config, delegate1)
require.NoError(t, err)
conf2 := testConsulCAConfig()
conf2.CreateIndex = 10
delegate2 := newMockDelegate(t, conf2)
provider2, err := NewConsulProvider(conf2.Config, delegate2)
require.NoError(t, err)
testCrossSignProviders(t, provider1, provider2)
}
func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
require := require.New(t)
// Get the root from the new provider to be cross-signed.
newRootPEM, err := provider2.ActiveRoot()
require.NoError(err)
newRoot, err := connect.ParseCert(newRootPEM)
require.NoError(err)
oldSubject := newRoot.Subject.CommonName
newInterPEM, err := provider2.ActiveIntermediate()
require.NoError(err)
newIntermediate, err := connect.ParseCert(newInterPEM)
require.NoError(err)
// Have provider1 cross sign our new root cert.
xcPEM, err := provider1.CrossSignCA(newRoot)
require.NoError(err)
xc, err := connect.ParseCert(xcPEM)
require.NoError(err)
oldRootPEM, err := provider1.ActiveRoot()
require.NoError(err)
oldRoot, err := connect.ParseCert(oldRootPEM)
require.NoError(err)
// AuthorityKeyID should now be the signing root's, SubjectKeyId should be kept.
require.Equal(oldRoot.AuthorityKeyId, xc.AuthorityKeyId)
require.Equal(newRoot.SubjectKeyId, xc.SubjectKeyId)
// Subject name should not have changed.
require.Equal(oldSubject, xc.Subject.CommonName)
// Issuer should be the signing root.
require.Equal(oldRoot.Issuer.CommonName, xc.Issuer.CommonName)
// Get a leaf cert so we can verify against the cross-signed cert.
spiffeService := &connect.SpiffeIDService{
Host: "node1",
Namespace: "default",
Datacenter: "dc1",
Service: "foo",
}
raw, _ := connect.TestCSR(t, spiffeService)
leafCsr, err := connect.ParseCSR(raw)
require.NoError(err)
leafPEM, err := provider2.Sign(leafCsr)
require.NoError(err)
cert, err := connect.ParseCert(leafPEM)
require.NoError(err)
// Check that the leaf signed by the new cert can be verified by either root
// certificate by using the new intermediate + cross-signed cert.
intermediatePool := x509.NewCertPool()
intermediatePool.AddCert(newIntermediate)
intermediatePool.AddCert(xc)
for _, root := range []*x509.Certificate{oldRoot, newRoot} {
rootPool := x509.NewCertPool()
rootPool.AddCert(root)
_, err = cert.Verify(x509.VerifyOptions{
Intermediates: intermediatePool,
Roots: rootPool,
})
require.NoError(err)
}
}

View File

@ -0,0 +1,322 @@
package ca
import (
"bytes"
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"net/http"
"strings"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-uuid"
vaultapi "github.com/hashicorp/vault/api"
"github.com/mitchellh/mapstructure"
)
const VaultCALeafCertRole = "leaf-cert"
var ErrBackendNotMounted = fmt.Errorf("backend not mounted")
var ErrBackendNotInitialized = fmt.Errorf("backend not initialized")
type VaultProvider struct {
config *structs.VaultCAProviderConfig
client *vaultapi.Client
clusterId string
}
// NewVaultProvider returns a vault provider with its root and intermediate PKI
// backends mounted and initialized. If the root backend is not set up already,
// it will be mounted/generated as needed, but any existing state will not be
// overwritten.
func NewVaultProvider(rawConfig map[string]interface{}, clusterId string) (*VaultProvider, error) {
conf, err := ParseVaultCAConfig(rawConfig)
if err != nil {
return nil, err
}
// todo(kyhavlov): figure out the right way to pass the TLS config
clientConf := &vaultapi.Config{
Address: conf.Address,
}
client, err := vaultapi.NewClient(clientConf)
if err != nil {
return nil, err
}
client.SetToken(conf.Token)
provider := &VaultProvider{
config: conf,
client: client,
clusterId: clusterId,
}
// Set up the root PKI backend if necessary.
_, err = provider.ActiveRoot()
switch err {
case ErrBackendNotMounted:
err := client.Sys().Mount(conf.RootPKIPath, &vaultapi.MountInput{
Type: "pki",
Description: "root CA backend for Consul Connect",
Config: vaultapi.MountConfigInput{
MaxLeaseTTL: "8760h",
},
})
if err != nil {
return nil, err
}
fallthrough
case ErrBackendNotInitialized:
spiffeID := connect.SpiffeIDSigning{ClusterID: clusterId, Domain: "consul"}
uuid, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
_, err = client.Logical().Write(conf.RootPKIPath+"root/generate/internal", map[string]interface{}{
"common_name": fmt.Sprintf("Vault CA Root Authority %s", uuid),
"uri_sans": spiffeID.URI().String(),
})
if err != nil {
return nil, err
}
default:
if err != nil {
return nil, err
}
}
// Set up the intermediate backend.
if _, err := provider.GenerateIntermediate(); err != nil {
return nil, err
}
return provider, nil
}
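// For illustration only (hypothetical values): the raw map handed to this
// constructor, i.e. the contents of connect { ca_config { ... } }, might be
//
//	raw := map[string]interface{}{
//		"Address":             "https://vault.example.com:8200",
//		"Token":               "<vault token>",
//		"RootPKIPath":         "connect-root/",
//		"IntermediatePKIPath": "connect-intermediate/",
//	}
//	provider, err := NewVaultProvider(raw, clusterID)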
func (v *VaultProvider) ActiveRoot() (string, error) {
return v.getCA(v.config.RootPKIPath)
}
func (v *VaultProvider) ActiveIntermediate() (string, error) {
return v.getCA(v.config.IntermediatePKIPath)
}
// getCA returns the raw CA cert for the given endpoint if there is one.
// We have to use the raw NewRequest call here instead of Logical().Read
// because the endpoint only returns the raw PEM contents of the CA cert
// and not the typical format of the secrets endpoints.
func (v *VaultProvider) getCA(path string) (string, error) {
req := v.client.NewRequest("GET", "/v1/"+path+"/ca/pem")
resp, err := v.client.RawRequest(req)
if resp != nil {
defer resp.Body.Close()
}
if resp != nil && resp.StatusCode == http.StatusNotFound {
return "", ErrBackendNotMounted
}
if err != nil {
return "", err
}
bytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
root := string(bytes)
if root == "" {
return "", ErrBackendNotInitialized
}
return root, nil
}
// GenerateIntermediate mounts the configured intermediate PKI backend if
// necessary, then generates and signs a new CA CSR using the root PKI backend
// and updates the intermediate backend to use that new certificate.
func (v *VaultProvider) GenerateIntermediate() (string, error) {
mounts, err := v.client.Sys().ListMounts()
if err != nil {
return "", err
}
// Mount the backend if it isn't mounted already.
if _, ok := mounts[v.config.IntermediatePKIPath]; !ok {
err := v.client.Sys().Mount(v.config.IntermediatePKIPath, &vaultapi.MountInput{
Type: "pki",
Description: "intermediate CA backend for Consul Connect",
Config: vaultapi.MountConfigInput{
MaxLeaseTTL: "2160h",
},
})
if err != nil {
return "", err
}
}
// Create the role for issuing leaf certs if it doesn't exist yet
rolePath := v.config.IntermediatePKIPath + "roles/" + VaultCALeafCertRole
role, err := v.client.Logical().Read(rolePath)
if err != nil {
return "", err
}
spiffeID := connect.SpiffeIDSigning{ClusterID: v.clusterId, Domain: "consul"}
if role == nil {
_, err := v.client.Logical().Write(rolePath, map[string]interface{}{
"allow_any_name": true,
"allowed_uri_sans": "spiffe://*",
"key_type": "any",
"max_ttl": "72h",
"require_cn": false,
})
if err != nil {
return "", err
}
}
// Generate a new intermediate CSR for the root to sign.
csr, err := v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/generate/internal", map[string]interface{}{
"common_name": "Vault CA Intermediate Authority",
"uri_sans": spiffeID.URI().String(),
})
if err != nil {
return "", err
}
if csr == nil || csr.Data["csr"] == "" {
return "", fmt.Errorf("got empty value when generating intermediate CSR")
}
// Sign the CSR with the root backend.
intermediate, err := v.client.Logical().Write(v.config.RootPKIPath+"root/sign-intermediate", map[string]interface{}{
"csr": csr.Data["csr"],
"format": "pem_bundle",
})
if err != nil {
return "", err
}
if intermediate == nil || intermediate.Data["certificate"] == "" {
return "", fmt.Errorf("got empty value when generating intermediate certificate")
}
// Set the intermediate backend to use the new certificate.
_, err = v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/set-signed", map[string]interface{}{
"certificate": intermediate.Data["certificate"],
})
if err != nil {
return "", err
}
return v.ActiveIntermediate()
}
// Sign calls the configured role in the intermediate PKI backend to issue
// a new leaf certificate based on the provided CSR, with the issuing
// intermediate CA cert attached.
func (v *VaultProvider) Sign(csr *x509.CertificateRequest) (string, error) {
var pemBuf bytes.Buffer
if err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
return "", err
}
// Use the leaf cert role to sign a new cert for this CSR.
response, err := v.client.Logical().Write(v.config.IntermediatePKIPath+"sign/"+VaultCALeafCertRole, map[string]interface{}{
"csr": pemBuf.String(),
})
if err != nil {
return "", fmt.Errorf("error issuing cert: %v", err)
}
if response == nil || response.Data["certificate"] == "" || response.Data["issuing_ca"] == "" {
return "", fmt.Errorf("certificate info returned from Vault was blank")
}
cert, ok := response.Data["certificate"].(string)
if !ok {
return "", fmt.Errorf("certificate was not a string")
}
ca, ok := response.Data["issuing_ca"].(string)
if !ok {
return "", fmt.Errorf("issuing_ca was not a string")
}
return fmt.Sprintf("%s\n%s", cert, ca), nil
}
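// Illustrative sketch (not part of the original change): one plausible way a
// caller could drive the full leaf-signing flow with this provider, assuming
// it has already been bootstrapped. The identity values are made up for the
// example and error handling is collapsed into placeholders.
//
//	id := &connect.SpiffeIDService{Host: "node1", Namespace: "default", Datacenter: "dc1", Service: "web"}
//	signer, _, err := connect.GeneratePrivateKey()
//	if err != nil { /* handle */ }
//	csrPEM, err := connect.CreateCSR(id, signer)
//	if err != nil { /* handle */ }
//	csr, err := connect.ParseCSR(csrPEM)
//	if err != nil { /* handle */ }
//	leafPEM, err := provider.Sign(csr) // leaf cert with the issuing CA appended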
// CrossSignCA takes a CA certificate and cross-signs it to form a trust chain
// back to our active root.
func (v *VaultProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
var pemBuf bytes.Buffer
err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if err != nil {
return "", err
}
// Have the root PKI backend sign this cert.
response, err := v.client.Logical().Write(v.config.RootPKIPath+"root/sign-self-issued", map[string]interface{}{
"certificate": pemBuf.String(),
})
if err != nil {
return "", fmt.Errorf("error having Vault cross-sign cert: %v", err)
}
if response == nil || response.Data["certificate"] == "" {
return "", fmt.Errorf("certificate info returned from Vault was blank")
}
xcCert, ok := response.Data["certificate"].(string)
if !ok {
return "", fmt.Errorf("certificate was not a string")
}
return xcCert, nil
}
// Cleanup unmounts the configured intermediate PKI backend. It's fine to tear
// this down and recreate it on small config changes because the intermediate
// certs get bundled with the leaf certs, so there's no cost to the CA changing.
func (v *VaultProvider) Cleanup() error {
return v.client.Sys().Unmount(v.config.IntermediatePKIPath)
}
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
var config structs.VaultCAProviderConfig
decodeConf := &mapstructure.DecoderConfig{
ErrorUnused: true,
Result: &config,
WeaklyTypedInput: true,
}
decoder, err := mapstructure.NewDecoder(decodeConf)
if err != nil {
return nil, err
}
if err := decoder.Decode(raw); err != nil {
return nil, fmt.Errorf("error decoding config: %s", err)
}
if config.Token == "" {
return nil, fmt.Errorf("must provide a Vault token")
}
if config.RootPKIPath == "" {
return nil, fmt.Errorf("must provide a valid path to a root PKI backend")
}
if !strings.HasSuffix(config.RootPKIPath, "/") {
config.RootPKIPath += "/"
}
if config.IntermediatePKIPath == "" {
return nil, fmt.Errorf("must provide a valid path for the intermediate PKI backend")
}
if !strings.HasSuffix(config.IntermediatePKIPath, "/") {
config.IntermediatePKIPath += "/"
}
return &config, nil
}
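// Illustrative sketch (not part of the original change): constructing a
// provider from a raw config map, mirroring the keys used in the tests below.
// The address, token, mount paths and cluster ID are assumptions for the
// example (Vault's default dev address is shown).
//
//	provider, err := NewVaultProvider(map[string]interface{}{
//		"Address":             "http://127.0.0.1:8200",
//		"Token":               "some-vault-token",
//		"RootPKIPath":         "pki-root/",
//		"IntermediatePKIPath": "pki-intermediate/",
//	}, "11111111-2222-3333-4444-555555555555")
//	if err != nil { /* handle */ }
//	rootPEM, err := provider.ActiveRoot()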

View File

@ -0,0 +1,162 @@
package ca
import (
"fmt"
"io/ioutil"
"net"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
vaultapi "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require"
)
func testVaultCluster(t *testing.T) (*VaultProvider, *vault.Core, net.Listener) {
if err := vault.AddTestLogicalBackend("pki", pki.Factory); err != nil {
t.Fatal(err)
}
core, _, token := vault.TestCoreUnsealedRaw(t)
ln, addr := vaulthttp.TestServer(t, core)
provider, err := NewVaultProvider(map[string]interface{}{
"Address": addr,
"Token": token,
"RootPKIPath": "pki-root/",
"IntermediatePKIPath": "pki-intermediate/",
}, "asdf")
if err != nil {
t.Fatal(err)
}
return provider, core, ln
}
func TestVaultCAProvider_Bootstrap(t *testing.T) {
t.Parallel()
require := require.New(t)
provider, core, listener := testVaultCluster(t)
defer core.Shutdown()
defer listener.Close()
client, err := vaultapi.NewClient(&vaultapi.Config{
Address: "http://" + listener.Addr().String(),
})
require.NoError(err)
client.SetToken(provider.config.Token)
cases := []struct {
certFunc func() (string, error)
backendPath string
}{
{
certFunc: provider.ActiveRoot,
backendPath: "pki-root/",
},
{
certFunc: provider.ActiveIntermediate,
backendPath: "pki-intermediate/",
},
}
// Verify the root and intermediate certs match the ones in the vault backends
for _, tc := range cases {
cert, err := tc.certFunc()
require.NoError(err)
req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem")
resp, err := client.RawRequest(req)
require.NoError(err)
bytes, err := ioutil.ReadAll(resp.Body)
require.NoError(err)
require.Equal(cert, string(bytes))
// Should be a valid CA cert
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.True(parsed.IsCA)
require.Len(parsed.URIs, 1)
require.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", provider.clusterId))
}
}
func TestVaultCAProvider_SignLeaf(t *testing.T) {
t.Parallel()
require := require.New(t)
provider, core, listener := testVaultCluster(t)
defer core.Shutdown()
defer listener.Close()
client, err := vaultapi.NewClient(&vaultapi.Config{
Address: "http://" + listener.Addr().String(),
})
require.NoError(err)
client.SetToken(provider.config.Token)
spiffeService := &connect.SpiffeIDService{
Host: "node1",
Namespace: "default",
Datacenter: "dc1",
Service: "foo",
}
// Generate a leaf cert for the service.
var firstSerial uint64
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
cert, err := provider.Sign(csr)
require.NoError(err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(parsed.URIs[0], spiffeService.URI())
firstSerial = parsed.SerialNumber.Uint64()
// Ensure the cert is valid now and expires within the correct limit.
require.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
require.True(parsed.NotBefore.Before(time.Now()))
}
// Generate a new cert for another service and make sure
// the serial number is unique.
spiffeService.Service = "bar"
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
cert, err := provider.Sign(csr)
require.NoError(err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(parsed.URIs[0], spiffeService.URI())
require.NotEqual(firstSerial, parsed.SerialNumber.Uint64())
// Ensure the cert is valid now and expires within the correct limit.
require.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
require.True(parsed.NotBefore.Before(time.Now()))
}
}
func TestVaultCAProvider_CrossSignCA(t *testing.T) {
t.Parallel()
provider1, core1, listener1 := testVaultCluster(t)
defer core1.Shutdown()
defer listener1.Close()
provider2, core2, listener2 := testVaultCluster(t)
defer core2.Shutdown()
defer listener2.Close()
testCrossSignProviders(t, provider1, provider2)
}

33
agent/connect/csr.go Normal file
View File

@ -0,0 +1,33 @@
package connect
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"net/url"
)
// CreateCSR returns a CSR to sign the given service along with the PEM-encoded
// private key for this certificate.
func CreateCSR(uri CertURI, privateKey crypto.Signer) (string, error) {
template := &x509.CertificateRequest{
URIs: []*url.URL{uri.URI()},
SignatureAlgorithm: x509.ECDSAWithSHA256,
}
// Create the CSR itself
var csrBuf bytes.Buffer
bs, err := x509.CreateCertificateRequest(rand.Reader, template, privateKey)
if err != nil {
return "", err
}
err = pem.Encode(&csrBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: bs})
if err != nil {
return "", err
}
return csrBuf.String(), nil
}
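// Illustrative sketch (not part of the original change): pairing CreateCSR
// with GeneratePrivateKey from this package. The identity values are
// assumptions for the example.
//
//	signer, keyPEM, err := GeneratePrivateKey()
//	if err != nil { /* handle */ }
//	id := &SpiffeIDService{Host: "1234.consul", Namespace: "default", Datacenter: "dc1", Service: "web"}
//	csrPEM, err := CreateCSR(id, signer)
//	_ = keyPEM // keep the PEM private key alongside the CSR for later use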

35
agent/connect/generate.go Normal file
View File

@ -0,0 +1,35 @@
package connect
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"fmt"
)
// GeneratePrivateKey generates a new Private key
func GeneratePrivateKey() (crypto.Signer, string, error) {
var pk *ecdsa.PrivateKey
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, "", fmt.Errorf("error generating private key: %s", err)
}
bs, err := x509.MarshalECPrivateKey(pk)
if err != nil {
return nil, "", fmt.Errorf("error generating private key: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: bs})
if err != nil {
return nil, "", fmt.Errorf("error encoding private key: %s", err)
}
return pk, buf.String(), nil
}
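// Illustrative sketch (not part of the original change): the PEM form
// returned here round-trips through ParseSigner in this package, which is
// how a persisted key could be reloaded.
//
//	signer, keyPEM, err := GeneratePrivateKey()
//	if err != nil { /* handle */ }
//	reloaded, err := ParseSigner(keyPEM)
//	if err != nil { /* handle */ }
//	_ = reloaded // an *ecdsa.PrivateKey implementing crypto.Signer, like signer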

121
agent/connect/parsing.go Normal file
View File

@ -0,0 +1,121 @@
package connect
import (
"crypto"
"crypto/ecdsa"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"fmt"
"strings"
)
// ParseCert parses the x509 certificate from a PEM-encoded value.
func ParseCert(pemValue string) (*x509.Certificate, error) {
// The _ result below is not an error but the remaining PEM bytes.
block, _ := pem.Decode([]byte(pemValue))
if block == nil {
return nil, fmt.Errorf("no PEM-encoded data found")
}
if block.Type != "CERTIFICATE" {
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE type")
}
return x509.ParseCertificate(block.Bytes)
}
// CalculateCertFingerprint parses the x509 certificate from a PEM-encoded value
// and calculates the SHA-1 fingerprint.
func CalculateCertFingerprint(pemValue string) (string, error) {
// The _ result below is not an error but the remaining PEM bytes.
block, _ := pem.Decode([]byte(pemValue))
if block == nil {
return "", fmt.Errorf("no PEM-encoded data found")
}
if block.Type != "CERTIFICATE" {
return "", fmt.Errorf("first PEM-block should be CERTIFICATE type")
}
hash := sha1.Sum(block.Bytes)
return HexString(hash[:]), nil
}
// ParseSigner parses a crypto.Signer from a PEM-encoded key. The private key
// is expected to be the first block in the PEM value.
func ParseSigner(pemValue string) (crypto.Signer, error) {
// The _ result below is not an error but the remaining PEM bytes.
block, _ := pem.Decode([]byte(pemValue))
if block == nil {
return nil, fmt.Errorf("no PEM-encoded data found")
}
switch block.Type {
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(block.Bytes)
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes)
case "PRIVATE KEY":
signer, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
pk, ok := signer.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("private key is not a valid format")
}
return pk, nil
default:
return nil, fmt.Errorf("unknown PEM block type for signing key: %s", block.Type)
}
}
// ParseCSR parses a CSR from a PEM-encoded value. The certificate request
// must be the first block in the PEM value.
func ParseCSR(pemValue string) (*x509.CertificateRequest, error) {
// The _ result below is not an error but the remaining PEM bytes.
block, _ := pem.Decode([]byte(pemValue))
if block == nil {
return nil, fmt.Errorf("no PEM-encoded data found")
}
if block.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE REQUEST type")
}
return x509.ParseCertificateRequest(block.Bytes)
}
// KeyId returns an x509 KeyId from the given signing key. The key must
// currently be an *ecdsa.PublicKey, but more types may be supported in the
// future.
func KeyId(raw interface{}) ([]byte, error) {
switch raw.(type) {
case *ecdsa.PublicKey:
default:
return nil, fmt.Errorf("invalid key type: %T", raw)
}
// This is not standard; the RFC allows any unique identifier as long as it
// matches in subject/authority chains, but suggests hashing the DER
// bytes of the public key, including DER tags.
bs, err := x509.MarshalPKIXPublicKey(raw)
if err != nil {
return nil, err
}
// SHA-256 the DER bytes and format as a colon-separated hex string.
kID := sha256.Sum256(bs)
return []byte(strings.Replace(fmt.Sprintf("% x", kID), " ", ":", -1)), nil
}
// HexString returns a standard colon-separated hex value for the input
// byte slice. This should be used with cert serial numbers and so on.
func HexString(input []byte) string {
return strings.Replace(fmt.Sprintf("% x", input), " ", ":", -1)
}
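// Illustrative sketch (not part of the original change): deriving a key ID
// for a freshly generated key and fingerprinting a certificate. certPEM is a
// hypothetical PEM-encoded certificate, not defined here.
//
//	signer, _, err := GeneratePrivateKey()
//	if err != nil { /* handle */ }
//	kid, err := KeyId(signer.Public()) // colon-separated SHA-256 of the public key
//	if err != nil { /* handle */ }
//	fp, err := CalculateCertFingerprint(certPEM) // colon-separated SHA-1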

332
agent/connect/testing_ca.go Normal file
View File

@ -0,0 +1,332 @@
package connect
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/url"
"sync/atomic"
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-uuid"
"github.com/mitchellh/go-testing-interface"
)
// TestClusterID is the Consul cluster ID for testing.
const TestClusterID = "11111111-2222-3333-4444-555555555555"
// testCACounter is just an atomically incremented counter for creating
// unique names for the CA certs.
var testCACounter uint64
// TestCA creates a test CA certificate and signing key and returns it
// in the CARoot structure format. The returned CA will be set as Active = true.
//
// If xc is non-nil, then the returned certificate will have a signing cert
// that is cross-signed with the previous cert, and this will be set as
// SigningCert.
func TestCA(t testing.T, xc *structs.CARoot) *structs.CARoot {
var result structs.CARoot
result.Active = true
result.Name = fmt.Sprintf("Test CA %d", atomic.AddUint64(&testCACounter, 1))
// Create the private key we'll use for this CA cert.
signer, keyPEM := testPrivateKey(t)
result.SigningKey = keyPEM
// The serial number for the cert
sn, err := testSerialNumber()
if err != nil {
t.Fatalf("error generating serial number: %s", err)
}
// The URI (SPIFFE compatible) for the cert
id := &SpiffeIDSigning{ClusterID: TestClusterID, Domain: "consul"}
// Create the CA cert
template := x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{CommonName: result.Name},
URIs: []*url.URL{id.URI()},
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageCertSign |
x509.KeyUsageCRLSign |
x509.KeyUsageDigitalSignature,
IsCA: true,
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
NotBefore: time.Now(),
AuthorityKeyId: testKeyID(t, signer.Public()),
SubjectKeyId: testKeyID(t, signer.Public()),
}
bs, err := x509.CreateCertificate(
rand.Reader, &template, &template, signer.Public(), signer)
if err != nil {
t.Fatalf("error generating CA certificate: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
t.Fatalf("error encoding private key: %s", err)
}
result.RootCert = buf.String()
result.ID, err = CalculateCertFingerprint(result.RootCert)
if err != nil {
t.Fatalf("error generating CA ID fingerprint: %s", err)
}
// If there is a prior CA to cross-sign with, then we need to create that
// and set it as the signing cert.
if xc != nil {
xccert, err := ParseCert(xc.RootCert)
if err != nil {
t.Fatalf("error parsing CA cert: %s", err)
}
xcsigner, err := ParseSigner(xc.SigningKey)
if err != nil {
t.Fatalf("error parsing signing key: %s", err)
}
// Set the authority key to be the previous one.
// NOTE(mitchellh): From Paul Banks: if we have to cross-sign a cert
// that came from outside (e.g. vault) we can't rely on them using the
// same KeyID hashing algo we do so we'd need to actually copy this
// from the xc cert's subjectKeyIdentifier extension.
template.AuthorityKeyId = testKeyID(t, xcsigner.Public())
// Create the new certificate where the parent is the previous
// CA, the public key is the new public key, and the signing private
// key is the old private key.
bs, err := x509.CreateCertificate(
rand.Reader, &template, xccert, signer.Public(), xcsigner)
if err != nil {
t.Fatalf("error generating CA certificate: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
t.Fatalf("error encoding private key: %s", err)
}
result.SigningCert = buf.String()
}
return &result
}
// TestLeaf returns a valid leaf certificate and its private key for the named
// service with the given CA Root.
func TestLeaf(t testing.T, service string, root *structs.CARoot) (string, string) {
// Parse the CA cert and signing key from the root
cert := root.SigningCert
if cert == "" {
cert = root.RootCert
}
caCert, err := ParseCert(cert)
if err != nil {
t.Fatalf("error parsing CA cert: %s", err)
}
caSigner, err := ParseSigner(root.SigningKey)
if err != nil {
t.Fatalf("error parsing signing key: %s", err)
}
// Build the SPIFFE ID
spiffeId := &SpiffeIDService{
Host: fmt.Sprintf("%s.consul", TestClusterID),
Namespace: "default",
Datacenter: "dc1",
Service: service,
}
// The serial number for the cert
sn, err := testSerialNumber()
if err != nil {
t.Fatalf("error generating serial number: %s", err)
}
// Generate fresh private key
pkSigner, pkPEM, err := GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key: %s", err)
}
// Cert template for generation
template := x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{CommonName: service},
URIs: []*url.URL{spiffeId.URI()},
SignatureAlgorithm: x509.ECDSAWithSHA256,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageDataEncipherment |
x509.KeyUsageKeyAgreement |
x509.KeyUsageDigitalSignature |
x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
NotBefore: time.Now(),
AuthorityKeyId: testKeyID(t, caSigner.Public()),
SubjectKeyId: testKeyID(t, pkSigner.Public()),
}
// Create the certificate, PEM encode it and return that value.
var buf bytes.Buffer
bs, err := x509.CreateCertificate(
rand.Reader, &template, caCert, pkSigner.Public(), caSigner)
if err != nil {
t.Fatalf("error generating certificate: %s", err)
}
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
t.Fatalf("error encoding private key: %s", err)
}
return buf.String(), pkPEM
}
// TestCSR returns a CSR to sign the given service along with the PEM-encoded
// private key for this certificate.
func TestCSR(t testing.T, uri CertURI) (string, string) {
template := &x509.CertificateRequest{
URIs: []*url.URL{uri.URI()},
SignatureAlgorithm: x509.ECDSAWithSHA256,
}
// Create the private key we'll use
signer, pkPEM := testPrivateKey(t)
// Create the CSR itself
var csrBuf bytes.Buffer
bs, err := x509.CreateCertificateRequest(rand.Reader, template, signer)
if err != nil {
t.Fatalf("error creating CSR: %s", err)
}
err = pem.Encode(&csrBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: bs})
if err != nil {
t.Fatalf("error encoding CSR: %s", err)
}
return csrBuf.String(), pkPEM
}
// testKeyID returns a KeyID from the given public key. This just calls
// KeyId but handles errors for tests.
func testKeyID(t testing.T, raw interface{}) []byte {
result, err := KeyId(raw)
if err != nil {
t.Fatalf("KeyId error: %s", err)
}
return result
}
// testPrivateKey creates an ECDSA based private key. Both a crypto.Signer and
// the key in PEM form are returned.
//
// NOTE(banks): this was memoized to save entropy during tests, but it turns
// out crypto/rand will never block and always reads from /dev/urandom on
// Unix OSes, which does not consume entropy.
//
// If we find by profiling that it's taking a lot of cycles we could
// optimise/cache again, but we at least need to use different keys for each
// distinct CA (when multiple CAs are generated at once, e.g. to test
// cross-signing) and a different one again for the leaves, otherwise we risk
// tests with false positives since signatures from different logical certs'
// keys are indistinguishable. Worse, we build validation chains using
// AuthorityKeyID, which would be the same for multiple CAs/leaves. Also note
// that our UUID generator also reads from crypto rand and is called far more
// often during tests than this will be.
func testPrivateKey(t testing.T) (crypto.Signer, string) {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("error generating private key: %s", err)
}
bs, err := x509.MarshalECPrivateKey(pk)
if err != nil {
t.Fatalf("error generating private key: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: bs})
if err != nil {
t.Fatalf("error encoding private key: %s", err)
}
return pk, buf.String()
}
// testSerialNumber generates a serial number suitable for a certificate.
// For testing, this just sets it to a random number.
//
// This function is taken directly from the Vault implementation.
func testSerialNumber() (*big.Int, error) {
return rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil))
}
// testUUID generates a UUID for testing.
func testUUID(t testing.T) string {
ret, err := uuid.GenerateUUID()
if err != nil {
t.Fatalf("Unable to generate a UUID, %s", err)
}
return ret
}
// TestAgentRPC is an interface that an RPC client must implement. This is a
// helper interface that is implemented by the agent delegate so that test
// helpers can make RPCs without introducing an import cycle on `agent`.
type TestAgentRPC interface {
RPC(method string, args interface{}, reply interface{}) error
}
// TestCAConfigSet sets a CARoot returned by TestCA into the TestAgent state. It
// requires that TestAgent had connect enabled in its config. If ca is nil, a
// new CA is created.
//
// It returns the CARoot passed or created.
//
// Note that we have to use an interface for the TestAgent.RPC method since we
// can't introduce an import cycle by importing `agent.TestAgent` here directly.
// It also means this will work in a few other places we mock that method.
func TestCAConfigSet(t testing.T, a TestAgentRPC,
ca *structs.CARoot) *structs.CARoot {
t.Helper()
if ca == nil {
ca = TestCA(t, nil)
}
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": ca.SigningKey,
"RootCert": ca.RootCert,
"RotationPeriod": 180 * 24 * time.Hour,
},
}
args := &structs.CARequest{
Datacenter: "dc1",
Config: newConfig,
}
var reply interface{}
err := a.RPC("ConnectCA.ConfigurationSet", args, &reply)
if err != nil {
t.Fatalf("failed to set test CA config: %s", err)
}
return ca
}
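// Illustrative sketch (not part of the original change): the intended shape
// of a test that uses these helpers, along the lines of the OpenSSL-based
// tests below. TestMyFeature is a hypothetical test name.
//
//	func TestMyFeature(t *testing.T) {
//		ca := TestCA(t, nil)                      // self-signed test root
//		leafPEM, keyPEM := TestLeaf(t, "web", ca) // leaf signed by that root
//		_, _ = leafPEM, keyPEM
//	}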

View File

@ -0,0 +1,100 @@
package connect
import (
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
// hasOpenSSL is used to determine if the openssl CLI exists for unit tests.
var hasOpenSSL bool
func init() {
_, err := exec.LookPath("openssl")
hasOpenSSL = err == nil
}
// Test that the TestCA and TestLeaf functions generate valid certificates.
func TestTestCAAndLeaf(t *testing.T) {
if !hasOpenSSL {
t.Skip("openssl not found")
return
}
assert := assert.New(t)
// Create the certs
ca := TestCA(t, nil)
leaf, _ := TestLeaf(t, "web", ca)
// Create a temporary directory for storing the certs
td, err := ioutil.TempDir("", "consul")
assert.Nil(err)
defer os.RemoveAll(td)
// Write the cert
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), []byte(ca.RootCert), 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf.pem"), []byte(leaf), 0644))
// Use OpenSSL to verify so we have an external, known-working process
// that can verify this outside of our own implementations.
cmd := exec.Command(
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf.pem")
cmd.Dir = td
output, err := cmd.Output()
t.Log(string(output))
assert.Nil(err)
}
// Test cross-signing.
func TestTestCAAndLeaf_xc(t *testing.T) {
if !hasOpenSSL {
t.Skip("openssl not found")
return
}
assert := assert.New(t)
// Create the certs
ca1 := TestCA(t, nil)
ca2 := TestCA(t, ca1)
leaf1, _ := TestLeaf(t, "web", ca1)
leaf2, _ := TestLeaf(t, "web", ca2)
// Create a temporary directory for storing the certs
td, err := ioutil.TempDir("", "consul")
assert.Nil(err)
defer os.RemoveAll(td)
// Write the cert
xcbundle := []byte(ca1.RootCert)
xcbundle = append(xcbundle, '\n')
xcbundle = append(xcbundle, []byte(ca2.SigningCert)...)
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), xcbundle, 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf1.pem"), []byte(leaf1), 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf2.pem"), []byte(leaf2), 0644))
// OpenSSL verify the cross-signed leaf (leaf2)
{
cmd := exec.Command(
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf2.pem")
cmd.Dir = td
output, err := cmd.Output()
t.Log(string(output))
assert.Nil(err)
}
// OpenSSL verify the old leaf (leaf1)
{
cmd := exec.Command(
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf1.pem")
cmd.Dir = td
output, err := cmd.Output()
t.Log(string(output))
assert.Nil(err)
}
}

View File

@ -0,0 +1,21 @@
package connect
import (
"github.com/mitchellh/go-testing-interface"
)
// TestSpiffeIDService returns a SPIFFE ID representing a service.
func TestSpiffeIDService(t testing.T, service string) *SpiffeIDService {
return TestSpiffeIDServiceWithHost(t, service, TestClusterID+".consul")
}
// TestSpiffeIDServiceWithHost returns a SPIFFE ID representing a service with
// the specified trust domain.
func TestSpiffeIDServiceWithHost(t testing.T, service, host string) *SpiffeIDService {
return &SpiffeIDService{
Host: host,
Namespace: "default",
Datacenter: "dc1",
Service: service,
}
}

91
agent/connect/uri.go Normal file
View File

@ -0,0 +1,91 @@
package connect
import (
"fmt"
"net/url"
"regexp"
"strings"
"github.com/hashicorp/consul/agent/structs"
)
// CertURI represents a Connect-valid URI value for a TLS certificate.
// The user should type switch on the various implementations in this
// package to determine the type of URI and the data encoded within it.
//
// Note that the current implementations of this are all also SPIFFE IDs.
// However, we anticipate that we may accept URIs that are also not SPIFFE
// compliant and therefore the interface is named as such.
type CertURI interface {
// Authorize tests the authorization for this URI as a client
// for the given intention. The return value `auth` is only valid if
// the second value `match` is true. If the second value `match` is
// false, then the intention doesn't match this client and any
// result should be ignored.
Authorize(*structs.Intention) (auth bool, match bool)
// URI is the valid URI value used in the cert.
URI() *url.URL
}
var (
spiffeIDServiceRegexp = regexp.MustCompile(
`^/ns/([^/]+)/dc/([^/]+)/svc/([^/]+)$`)
)
// ParseCertURI parses the URI value from a TLS certificate.
func ParseCertURI(input *url.URL) (CertURI, error) {
if input.Scheme != "spiffe" {
return nil, fmt.Errorf("SPIFFE ID must have 'spiffe' scheme")
}
// Path is the raw value of the path without url decoding values.
// RawPath is empty if there were no encoded values so we must
// check both.
path := input.Path
if input.RawPath != "" {
path = input.RawPath
}
// Test for service IDs
if v := spiffeIDServiceRegexp.FindStringSubmatch(path); v != nil {
// Determine the values. We assume they're sane to save cycles,
// but if the raw path is not empty that means that something is
// URL encoded so we go to the slow path.
ns := v[1]
dc := v[2]
service := v[3]
if input.RawPath != "" {
var err error
if ns, err = url.PathUnescape(v[1]); err != nil {
return nil, fmt.Errorf("Invalid namespace: %s", err)
}
if dc, err = url.PathUnescape(v[2]); err != nil {
return nil, fmt.Errorf("Invalid datacenter: %s", err)
}
if service, err = url.PathUnescape(v[3]); err != nil {
return nil, fmt.Errorf("Invalid service: %s", err)
}
}
return &SpiffeIDService{
Host: input.Host,
Namespace: ns,
Datacenter: dc,
Service: service,
}, nil
}
// Test for signing ID
if input.Path == "" {
idx := strings.Index(input.Host, ".")
if idx > 0 {
return &SpiffeIDSigning{
ClusterID: input.Host[:idx],
Domain: input.Host[idx+1:],
}, nil
}
}
return nil, fmt.Errorf("SPIFFE ID is not in the expected format")
}
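// Illustrative sketch (not part of the original change): parsing a service
// SPIFFE ID, using one of the URI forms exercised in the tests below.
//
//	u, _ := url.Parse("spiffe://1234.consul/ns/default/dc/dc01/svc/web")
//	id, err := ParseCertURI(u)
//	if err != nil { /* handle */ }
//	svc := id.(*SpiffeIDService) // Host "1234.consul", Service "web"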

View File

@ -0,0 +1,42 @@
package connect
import (
"fmt"
"net/url"
"github.com/hashicorp/consul/agent/structs"
)
// SpiffeIDService is the structure to represent the SPIFFE ID for a service.
type SpiffeIDService struct {
Host string
Namespace string
Datacenter string
Service string
}
// URI returns the *url.URL for this SPIFFE ID.
func (id *SpiffeIDService) URI() *url.URL {
var result url.URL
result.Scheme = "spiffe"
result.Host = id.Host
result.Path = fmt.Sprintf("/ns/%s/dc/%s/svc/%s",
id.Namespace, id.Datacenter, id.Service)
return &result
}
// CertURI impl.
func (id *SpiffeIDService) Authorize(ixn *structs.Intention) (bool, bool) {
if ixn.SourceNS != structs.IntentionWildcard && ixn.SourceNS != id.Namespace {
// Non-matching namespace
return false, false
}
if ixn.SourceName != structs.IntentionWildcard && ixn.SourceName != id.Service {
// Non-matching name
return false, false
}
// Match, return allow value
return ixn.Action == structs.IntentionActionAllow, true
}
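// Illustrative sketch (not part of the original change): building a service
// ID and rendering its URI; the field values are assumptions for the example.
//
//	id := &SpiffeIDService{Host: "1234.consul", Namespace: "default", Datacenter: "dc01", Service: "web"}
//	fmt.Println(id.URI().String())
//	// Output: spiffe://1234.consul/ns/default/dc/dc01/svc/web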

View File

@ -0,0 +1,104 @@
package connect
import (
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
)
func TestSpiffeIDServiceAuthorize(t *testing.T) {
ns := structs.IntentionDefaultNamespace
serviceWeb := &SpiffeIDService{
Host: "1234.consul",
Namespace: structs.IntentionDefaultNamespace,
Datacenter: "dc01",
Service: "web",
}
cases := []struct {
Name string
URI *SpiffeIDService
Ixn *structs.Intention
Auth bool
Match bool
}{
{
"exact source, not matching namespace",
serviceWeb,
&structs.Intention{
SourceNS: "different",
SourceName: "db",
},
false,
false,
},
{
"exact source, not matching name",
serviceWeb,
&structs.Intention{
SourceNS: ns,
SourceName: "db",
},
false,
false,
},
{
"exact source, allow",
serviceWeb,
&structs.Intention{
SourceNS: serviceWeb.Namespace,
SourceName: serviceWeb.Service,
Action: structs.IntentionActionAllow,
},
true,
true,
},
{
"exact source, deny",
serviceWeb,
&structs.Intention{
SourceNS: serviceWeb.Namespace,
SourceName: serviceWeb.Service,
Action: structs.IntentionActionDeny,
},
false,
true,
},
{
"exact namespace, wildcard service, deny",
serviceWeb,
&structs.Intention{
SourceNS: serviceWeb.Namespace,
SourceName: structs.IntentionWildcard,
Action: structs.IntentionActionDeny,
},
false,
true,
},
{
"exact namespace, wildcard service, allow",
serviceWeb,
&structs.Intention{
SourceNS: serviceWeb.Namespace,
SourceName: structs.IntentionWildcard,
Action: structs.IntentionActionAllow,
},
true,
true,
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
auth, match := tc.URI.Authorize(tc.Ixn)
assert.Equal(t, tc.Auth, auth)
assert.Equal(t, tc.Match, match)
})
}
}

View File

@ -0,0 +1,75 @@
package connect
import (
"fmt"
"net/url"
"strings"
"github.com/hashicorp/consul/agent/structs"
)
// SpiffeIDSigning is the structure to represent the SPIFFE ID for a
// signing certificate (not a leaf service).
type SpiffeIDSigning struct {
ClusterID string // Unique cluster ID
Domain string // The domain, usually "consul"
}
// URI returns the *url.URL for this SPIFFE ID.
func (id *SpiffeIDSigning) URI() *url.URL {
var result url.URL
result.Scheme = "spiffe"
result.Host = id.Host()
return &result
}
// Host is the canonical representation as a DNS-compatible hostname.
func (id *SpiffeIDSigning) Host() string {
return strings.ToLower(fmt.Sprintf("%s.%s", id.ClusterID, id.Domain))
}
// CertURI impl.
func (id *SpiffeIDSigning) Authorize(ixn *structs.Intention) (bool, bool) {
// Never authorize as a client.
return false, true
}
// CanSign takes any CertURI and returns whether or not this signing entity is
// allowed to sign CSRs for that entity (i.e. represents the trust domain for
// that entity).
//
// I chose to make this a fixed, centralised method here for now rather than a
// method on the CertURI interface since we don't intend this to be extensible
// outside this package, and it's easier to reason about the security properties
// when they are all in one place with "whitelist" semantics.
func (id *SpiffeIDSigning) CanSign(cu CertURI) bool {
switch other := cu.(type) {
case *SpiffeIDSigning:
// We can only sign other CA certificates for the same trust domain. Note
// that we could open this up later for example to support external
// federation of roots and cross-signing external roots that have different
// URI structure but it's simpler to start off restrictive.
return id == other
case *SpiffeIDService:
// The host component of the service must be an exact match for now under
// ascii case folding (since hostnames are case-insensitive). Later we might
// worry about Unicode domains if we start allowing customisation beyond the
// built-in cluster ids.
return strings.ToLower(other.Host) == id.Host()
default:
return false
}
}
// SpiffeIDSigningForCluster returns the SPIFFE signing identifier (trust
// domain) representation of the given CA config. If config is nil this function
// will panic.
//
// NOTE(banks): we intentionally fix the TLD `.consul` for now rather than tie
// this to the `domain` config used for DNS, so that changing the DNS domain
// can't break all certificate validation. That does mean the DNS prefix might
// not match the identity URIs and so the trust domain might not actually
// resolve, which we would like but don't actually need.
func SpiffeIDSigningForCluster(config *structs.CAConfiguration) *SpiffeIDSigning {
return &SpiffeIDSigning{ClusterID: config.ClusterID, Domain: "consul"}
}
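// Illustrative sketch (not part of the original change): checking whether a
// trust domain may sign for a service identity, as the tests below do. The
// cluster ID is an arbitrary example value.
//
//	signing := &SpiffeIDSigning{ClusterID: "11111111-2222-3333-4444-555555555555", Domain: "consul"}
//	svc := &SpiffeIDService{Host: signing.Host(), Namespace: "default", Datacenter: "dc1", Service: "web"}
//	ok := signing.CanSign(svc) // true: the service is in the same trust domain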

View File

@ -0,0 +1,123 @@
package connect
import (
"net/url"
"strings"
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
)
// Signing ID should never authorize
func TestSpiffeIDSigningAuthorize(t *testing.T) {
var id SpiffeIDSigning
auth, ok := id.Authorize(nil)
assert.False(t, auth)
assert.True(t, ok)
}
func TestSpiffeIDSigningForCluster(t *testing.T) {
// For now it should just append .consul to the ID.
config := &structs.CAConfiguration{
ClusterID: TestClusterID,
}
id := SpiffeIDSigningForCluster(config)
assert.Equal(t, id.URI().String(), "spiffe://"+TestClusterID+".consul")
}
// fakeCertURI is a CertURI implementation that our implementation doesn't know
// about
type fakeCertURI string
func (f fakeCertURI) Authorize(*structs.Intention) (auth bool, match bool) {
return false, false
}
func (f fakeCertURI) URI() *url.URL {
u, _ := url.Parse(string(f))
return u
}
func TestSpiffeIDSigning_CanSign(t *testing.T) {
testSigning := &SpiffeIDSigning{
ClusterID: TestClusterID,
Domain: "consul",
}
tests := []struct {
name string
id *SpiffeIDSigning
input CertURI
want bool
}{
{
name: "same signing ID",
id: testSigning,
input: testSigning,
want: true,
},
{
name: "other signing ID",
id: testSigning,
input: &SpiffeIDSigning{
ClusterID: "fakedomain",
Domain: "consul",
},
want: false,
},
{
name: "different TLD signing ID",
id: testSigning,
input: &SpiffeIDSigning{
ClusterID: TestClusterID,
Domain: "evil",
},
want: false,
},
{
name: "nil",
id: testSigning,
input: nil,
want: false,
},
{
name: "unrecognised CertURI implementation",
id: testSigning,
input: fakeCertURI("spiffe://foo.bar/baz"),
want: false,
},
{
name: "service - good",
id: testSigning,
input: &SpiffeIDService{TestClusterID + ".consul", "default", "dc1", "web"},
want: true,
},
{
name: "service - good midex case",
id: testSigning,
input: &SpiffeIDService{strings.ToUpper(TestClusterID) + ".CONsuL", "defAUlt", "dc1", "WEB"},
want: true,
},
{
name: "service - different cluster",
id: testSigning,
input: &SpiffeIDService{"55555555-4444-3333-2222-111111111111.consul", "default", "dc1", "web"},
want: false,
},
{
name: "service - different TLD",
id: testSigning,
input: &SpiffeIDService{TestClusterID + ".fake", "default", "dc1", "web"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.id.CanSign(tt.input)
assert.Equal(t, tt.want, got)
})
}
}

82
agent/connect/uri_test.go Normal file
View File

@ -0,0 +1,82 @@
package connect
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
// testCertURICases contains the test cases for parsing and encoding
// the SPIFFE IDs. This is a global since it is used in multiple test functions.
var testCertURICases = []struct {
Name string
URI string
Struct interface{}
ParseError string
}{
{
"invalid scheme",
"http://google.com/",
nil,
"scheme",
},
{
"basic service ID",
"spiffe://1234.consul/ns/default/dc/dc01/svc/web",
&SpiffeIDService{
Host: "1234.consul",
Namespace: "default",
Datacenter: "dc01",
Service: "web",
},
"",
},
{
"service with URL-encoded values",
"spiffe://1234.consul/ns/foo%2Fbar/dc/bar%2Fbaz/svc/baz%2Fqux",
&SpiffeIDService{
Host: "1234.consul",
Namespace: "foo/bar",
Datacenter: "bar/baz",
Service: "baz/qux",
},
"",
},
{
"signing ID",
"spiffe://1234.consul",
&SpiffeIDSigning{
ClusterID: "1234",
Domain: "consul",
},
"",
},
}
func TestParseCertURI(t *testing.T) {
for _, tc := range testCertURICases {
t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
// Parse the URI, should always be valid
uri, err := url.Parse(tc.URI)
assert.Nil(err)
// Parse the ID and check the error/return value
actual, err := ParseCertURI(uri)
if err != nil {
t.Logf("parse error: %s", err.Error())
}
assert.Equal(tc.ParseError != "", err != nil, "error value")
if err != nil {
assert.Contains(err.Error(), tc.ParseError)
return
}
assert.Equal(tc.Struct, actual)
})
}
}

View File

@ -0,0 +1,98 @@
package agent
import (
"fmt"
"net/http"
"github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/structs"
)
// GET /v1/connect/ca/roots
func (s *HTTPServer) ConnectCARoots(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var args structs.DCSpecificRequest
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
var reply structs.IndexedCARoots
defer setMeta(resp, &reply.QueryMeta)
if err := s.agent.RPC("ConnectCA.Roots", &args, &reply); err != nil {
return nil, err
}
return reply, nil
}
// /v1/connect/ca/configuration
func (s *HTTPServer) ConnectCAConfiguration(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
switch req.Method {
case "GET":
return s.ConnectCAConfigurationGet(resp, req)
case "PUT":
return s.ConnectCAConfigurationSet(resp, req)
default:
return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT"}}
}
}
// GET /v1/connect/ca/configuration
func (s *HTTPServer) ConnectCAConfigurationGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in ConnectCAConfiguration
var args structs.DCSpecificRequest
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
var reply structs.CAConfiguration
err := s.agent.RPC("ConnectCA.ConfigurationGet", &args, &reply)
if err != nil {
return nil, err
}
fixupConfig(&reply)
return reply, nil
}
// PUT /v1/connect/ca/configuration
func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in ConnectCAConfiguration
var args structs.CARequest
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req, &args.Config, nil); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
}
var reply interface{}
err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply)
return nil, err
}
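// Illustrative sketch (not part of the original change): a PUT body this
// handler accepts, mirroring the HTTP test below. RotationPeriod is in
// nanoseconds (3600000000000 == 1h) and the agent address assumes the
// default HTTP port.
//
//	body := strings.NewReader(`{"Provider": "consul", "Config": {"RotationPeriod": 3600000000000}}`)
//	req, _ := http.NewRequest("PUT", "http://127.0.0.1:8500/v1/connect/ca/configuration", body)
//	resp, err := http.DefaultClient.Do(req)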
// fixupConfig is a hack to fix up the config types inside the
// map[string]interface{} so that they get formatted correctly during
// json.Marshal. Without this, string values that get converted to []uint8
// end up being returned to the user in base64-encoded form.
func fixupConfig(conf *structs.CAConfiguration) {
for k, v := range conf.Config {
if raw, ok := v.([]uint8); ok {
strVal := ca.Uint8ToString(raw)
conf.Config[k] = strVal
switch conf.Provider {
case structs.ConsulCAProvider:
if k == "PrivateKey" && strVal != "" {
conf.Config["PrivateKey"] = "hidden"
}
case structs.VaultCAProvider:
if k == "Token" && strVal != "" {
conf.Config["Token"] = "hidden"
}
}
}
}
}
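// Illustrative sketch (not part of the original change): reading the roots
// endpoint from a plain Go client, assuming a local agent on the default
// HTTP port; encoding/json is assumed to be imported.
//
//	resp, err := http.Get("http://127.0.0.1:8500/v1/connect/ca/roots")
//	if err != nil { /* handle */ }
//	defer resp.Body.Close()
//	var roots structs.IndexedCARoots
//	err = json.NewDecoder(resp.Body).Decode(&roots)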

View File

@ -0,0 +1,115 @@
package agent
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
)
func TestConnectCARoots_empty(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "connect { enabled = false }")
defer a.Shutdown()
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCARoots(resp, req)
assert.Nil(err)
value := obj.(structs.IndexedCARoots)
assert.Equal(value.ActiveRootID, "")
assert.Len(value.Roots, 0)
}
func TestConnectCARoots_list(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Set some CAs. Note that NewTestAgent already bootstraps one CA so this just
// adds a second and makes it active.
ca2 := connect.TestCAConfigSet(t, a, nil)
// List
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCARoots(resp, req)
assert.NoError(err)
value := obj.(structs.IndexedCARoots)
assert.Equal(value.ActiveRootID, ca2.ID)
assert.Len(value.Roots, 2)
// We should never have the secret information
for _, r := range value.Roots {
assert.Equal("", r.SigningCert)
assert.Equal("", r.SigningKey)
}
}
func TestConnectCAConfig(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
expected := &structs.ConsulCAProviderConfig{
RotationPeriod: 90 * 24 * time.Hour,
}
// Get the initial config.
{
req, _ := http.NewRequest("GET", "/v1/connect/ca/configuration", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCAConfiguration(resp, req)
assert.NoError(err)
value := obj.(structs.CAConfiguration)
parsed, err := ca.ParseConsulCAConfig(value.Config)
assert.NoError(err)
assert.Equal("consul", value.Provider)
assert.Equal(expected, parsed)
}
// Set the config.
{
body := bytes.NewBuffer([]byte(`
{
"Provider": "consul",
"Config": {
"RotationPeriod": 3600000000000
}
}`))
req, _ := http.NewRequest("PUT", "/v1/connect/ca/configuration", body)
resp := httptest.NewRecorder()
_, err := a.srv.ConnectCAConfiguration(resp, req)
assert.NoError(err)
}
// The config should be updated now.
{
expected.RotationPeriod = time.Hour
req, _ := http.NewRequest("GET", "/v1/connect/ca/configuration", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCAConfiguration(resp, req)
assert.NoError(err)
value := obj.(structs.CAConfiguration)
parsed, err := ca.ParseConsulCAConfig(value.Config)
assert.NoError(err)
assert.Equal("consul", value.Provider)
assert.Equal(expected, parsed)
}
}

View File

@ -454,6 +454,33 @@ func (f *aclFilter) filterCoordinates(coords *structs.Coordinates) {
*coords = c
}
// filterIntentions is used to filter intentions based on ACL rules.
// We prune entries the user doesn't have access to, and we redact any tokens
// if the user doesn't have a management token.
func (f *aclFilter) filterIntentions(ixns *structs.Intentions) {
// Management tokens can see everything with no filtering.
if f.acl.ACLList() {
return
}
// Otherwise, we need to see what the token has access to.
ret := make(structs.Intentions, 0, len(*ixns))
for _, ixn := range *ixns {
// If no prefix ACL applies to this then filter it, since
// we know at this point the user doesn't have a management
// token, otherwise see what the policy says.
prefix, ok := ixn.GetACLPrefix()
if !ok || !f.acl.IntentionRead(prefix) {
f.logger.Printf("[DEBUG] consul: dropping intention %q from result due to ACLs", ixn.ID)
continue
}
ret = append(ret, ixn)
}
*ixns = ret
}
// filterNodeDump is used to filter through all parts of a node dump and
// remove elements the provided ACL token cannot access.
func (f *aclFilter) filterNodeDump(dump *structs.NodeDump) {
@ -598,6 +625,9 @@ func (s *Server) filterACL(token string, subj interface{}) error {
case *structs.IndexedHealthChecks:
filt.filterHealthChecks(&v.HealthChecks)
case *structs.IndexedIntentions:
filt.filterIntentions(&v.Intentions)
case *structs.IndexedNodeDump:
filt.filterNodeDump(&v.Dump)

View File

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/testutil/retry"
"github.com/stretchr/testify/assert"
)
var testACLPolicy = `
@ -847,6 +848,58 @@ node "node1" {
}
}
func TestACL_filterIntentions(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fill := func() structs.Intentions {
return structs.Intentions{
&structs.Intention{
ID: "f004177f-2c28-83b7-4229-eacc25fe55d1",
DestinationName: "bar",
},
&structs.Intention{
ID: "f004177f-2c28-83b7-4229-eacc25fe55d2",
DestinationName: "foo",
},
}
}
// Try permissive filtering.
{
ixns := fill()
filt := newACLFilter(acl.AllowAll(), nil, false)
filt.filterIntentions(&ixns)
assert.Len(ixns, 2)
}
// Try restrictive filtering.
{
ixns := fill()
filt := newACLFilter(acl.DenyAll(), nil, false)
filt.filterIntentions(&ixns)
assert.Len(ixns, 0)
}
// Policy to see one
policy, err := acl.Parse(`
service "foo" {
policy = "read"
}
`, nil)
assert.Nil(err)
perms, err := acl.New(acl.DenyAll(), policy, nil)
assert.Nil(err)
// Filter
{
ixns := fill()
filt := newACLFilter(perms, nil, false)
filt.filterIntentions(&ixns)
assert.Len(ixns, 1)
}
}
func TestACL_filterServices(t *testing.T) {
t.Parallel()
// Create some services

View File

@ -47,6 +47,13 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
// Handle a service registration.
if args.Service != nil {
// Validate the service. This is in addition to the checks below since
// those older checks just haven't been moved into Validate yet. We
// should move them over in time.
if err := args.Service.Validate(); err != nil {
return err
}
// If no service id, but service name, use default
if args.Service.ID == "" && args.Service.Service != "" {
args.Service.ID = args.Service.Service
@ -73,6 +80,13 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
return acl.ErrPermissionDenied
}
}
// Proxies must have write permission on their destination
if args.Service.Kind == structs.ServiceKindConnectProxy {
if rule != nil && !rule.ServiceWrite(args.Service.ProxyDestination, nil) {
return acl.ErrPermissionDenied
}
}
}
// Move the old format single check into the slice, and fixup IDs.
@ -244,24 +258,52 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
return fmt.Errorf("Must provide service name")
}
// Determine the function we'll call
var f func(memdb.WatchSet, *state.Store) (uint64, structs.ServiceNodes, error)
switch {
case args.Connect:
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
return s.ConnectServiceNodes(ws, args.ServiceName)
}
default:
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
if args.ServiceAddress != "" {
return s.ServiceAddressNodes(ws, args.ServiceAddress)
}
if args.TagFilter {
return s.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
}
return s.ServiceNodes(ws, args.ServiceName)
}
}
// If we're doing a connect query, we need read access to the service
// we're trying to find proxies for, so check that.
if args.Connect {
// Fetch the ACL token, if any.
rule, err := c.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.ServiceRead(args.ServiceName) {
// Just return nil, which will return an empty response (tested)
return nil
}
}
err := c.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
var index uint64
var services structs.ServiceNodes
var err error
if args.TagFilter {
index, services, err = state.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
} else {
index, services, err = state.ServiceNodes(ws, args.ServiceName)
}
if args.ServiceAddress != "" {
index, services, err = state.ServiceAddressNodes(ws, args.ServiceAddress)
}
index, services, err := f(ws, state)
if err != nil {
return err
}
reply.Index, reply.ServiceNodes = index, services
if len(args.NodeMetaFilters) > 0 {
var filtered structs.ServiceNodes
@ -280,17 +322,24 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
// Provide some metrics
if err == nil {
metrics.IncrCounterWithLabels([]string{"catalog", "service", "query"}, 1,
// For metrics, we separate Connect-based lookups from non-Connect lookups.
key := "service"
if args.Connect {
key = "connect"
}
metrics.IncrCounterWithLabels([]string{"catalog", key, "query"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
if args.ServiceTag != "" {
metrics.IncrCounterWithLabels([]string{"catalog", "service", "query-tag"}, 1,
metrics.IncrCounterWithLabels([]string{"catalog", key, "query-tag"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}})
}
if len(reply.ServiceNodes) == 0 {
metrics.IncrCounterWithLabels([]string{"catalog", "service", "not-found"}, 1,
metrics.IncrCounterWithLabels([]string{"catalog", key, "not-found"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
}
}
return err
}

View File

@ -16,6 +16,8 @@ import (
"github.com/hashicorp/consul/testutil/retry"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCatalog_Register(t *testing.T) {
@ -332,6 +334,147 @@ func TestCatalog_Register_ForwardDC(t *testing.T) {
}
}
func TestCatalog_Register_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
args := structs.TestRegisterRequestProxy(t)
// Register
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// List
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
}
// Test an invalid ConnectProxy. We don't need to exhaustively test because
// this is all tested in structs on the Validate method.
func TestCatalog_Register_ConnectProxy_invalid(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
args := structs.TestRegisterRequestProxy(t)
args.Service.ProxyDestination = ""
// Register
var out struct{}
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.NotNil(err)
assert.Contains(err.Error(), "ProxyDestination")
}
// Test that write is required for the proxy destination to register a proxy.
func TestCatalog_Register_ConnectProxy_ACLProxyDestination(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLMasterToken = "root"
c.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create the ACL.
arg := structs.ACLRequest{
Datacenter: "dc1",
Op: structs.ACLSet,
ACL: structs.ACL{
Name: "User token",
Type: structs.ACLTypeClient,
Rules: `
service "foo" {
policy = "write"
}
`,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var token string
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &token))
// Register should fail because we don't have permission on the destination
args := structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo"
args.Service.ProxyDestination = "bar"
args.WriteRequest.Token = token
var out struct{}
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.True(acl.IsErrPermissionDenied(err))
// Register should fail with the right destination but wrong name
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "bar"
args.Service.ProxyDestination = "foo"
args.WriteRequest.Token = token
err = msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.True(acl.IsErrPermissionDenied(err))
// Register should work with the right destination
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo"
args.Service.ProxyDestination = "foo"
args.WriteRequest.Token = token
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
}
func TestCatalog_Register_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true
// Register
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// List
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindTypical, v.ServiceKind)
assert.True(v.ServiceConnect.Native)
}
func TestCatalog_Deregister(t *testing.T) {
t.Parallel()
dir1, s1 := testServer(t)
@ -1599,6 +1742,246 @@ func TestCatalog_ListServiceNodes_DistanceSort(t *testing.T) {
}
}
func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
TagFilter: false,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
}
func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the proxy service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// Register the service
{
dst := args.Service.ProxyDestination
args := structs.TestRegisterRequest(t)
args.Service.Service = dst
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
}
// List
req := structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: args.Service.ProxyDestination,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
// List by non-Connect
req = structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.ProxyDestination,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v = resp.ServiceNodes[0]
assert.Equal(args.Service.ProxyDestination, v.ServiceName)
assert.Equal("", v.ServiceProxyDestination)
}
// Test that calling ServiceNodes with Connect: true will return
// Connect native services.
func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the native service
args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true
var out struct{}
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
require.Equal(args.Service.Service, v.ServiceName)
// List by non-Connect
req = structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(resp.ServiceNodes, 1)
v = resp.ServiceNodes[0]
require.Equal(args.Service.Service, v.ServiceName)
}
func TestCatalog_ListServiceNodes_ConnectProxy_ACL(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLMasterToken = "root"
c.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create the ACL.
arg := structs.ACLRequest{
Datacenter: "dc1",
Op: structs.ACLSet,
ACL: structs.ACL{
Name: "User token",
Type: structs.ACLTypeClient,
Rules: `
service "foo" {
policy = "write"
}
`,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var token string
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &token))
{
// Register a proxy
args := structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo-proxy"
args.Service.ProxyDestination = "bar"
args.WriteRequest.Token = "root"
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a proxy
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo-proxy"
args.Service.ProxyDestination = "foo"
args.WriteRequest.Token = "root"
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a proxy
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "another-proxy"
args.Service.ProxyDestination = "foo"
args.WriteRequest.Token = "root"
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
}
// List w/ token. This should return an empty response because we don't
// have permission to read "bar"
req := structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: "bar",
QueryOptions: structs.QueryOptions{Token: token},
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 0)
// List w/ token. This should work since we're requesting "foo", but should
// also only contain the proxies with names that adhere to our ACL.
req = structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: "foo",
QueryOptions: structs.QueryOptions{Token: token},
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal("foo-proxy", v.ServiceName)
}
func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
TagFilter: false,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(args.Service.Connect.Native, v.ServiceConnect.Native)
}
func TestCatalog_NodeServices(t *testing.T) {
t.Parallel()
dir1, s1 := testServer(t)
@ -1649,6 +2032,67 @@ func TestCatalog_NodeServices(t *testing.T) {
}
}
func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.NodeSpecificRequest{
Datacenter: "dc1",
Node: args.Node,
}
var resp structs.IndexedNodeServices
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Len(resp.NodeServices.Services, 1)
v := resp.NodeServices.Services[args.Service.Service]
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
assert.Equal(args.Service.ProxyDestination, v.ProxyDestination)
}
func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequest(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.NodeSpecificRequest{
Datacenter: "dc1",
Node: args.Node,
}
var resp structs.IndexedNodeServices
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Len(resp.NodeServices.Services, 1)
v := resp.NodeServices.Services[args.Service.Service]
assert.Equal(args.Service.Connect.Native, v.Connect.Native)
}
// Used to check for a regression against a known bug
func TestCatalog_Register_FailedCase1(t *testing.T) {
t.Parallel()


@ -7,6 +7,7 @@ import (
"os"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
@ -56,7 +57,7 @@ type Client struct {
// rpcLimiter is used to rate limit the total number of RPCs initiated
// from an agent.
rpcLimiter *rate.Limiter
rpcLimiter atomic.Value
// eventCh is used to receive events from the
// serf cluster in the datacenter
@ -128,12 +129,13 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) {
c := &Client{
config: config,
connPool: connPool,
rpcLimiter: rate.NewLimiter(config.RPCRate, config.RPCMaxBurst),
eventCh: make(chan serf.Event, serfEventBacklog),
logger: logger,
shutdownCh: make(chan struct{}),
}
c.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
if err := c.initEnterprise(); err != nil {
c.Shutdown()
return nil, err
@ -263,7 +265,7 @@ TRY:
// Enforce the RPC limit.
metrics.IncrCounter([]string{"client", "rpc"}, 1)
if !c.rpcLimiter.Allow() {
if !c.rpcLimiter.Load().(*rate.Limiter).Allow() {
metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
return structs.ErrRPCRateExceeded
}
@ -306,7 +308,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
// Enforce the RPC limit.
metrics.IncrCounter([]string{"client", "rpc"}, 1)
if !c.rpcLimiter.Allow() {
if !c.rpcLimiter.Load().(*rate.Limiter).Allow() {
metrics.IncrCounter([]string{"client", "rpc", "exceeded"}, 1)
return structs.ErrRPCRateExceeded
}
@ -381,3 +383,10 @@ func (c *Client) GetLANCoordinate() (lib.CoordinateSet, error) {
cs := lib.CoordinateSet{c.config.Segment: lan}
return cs, nil
}
// ReloadConfig is used to have the Client do an online reload of
// relevant configuration information
func (c *Client) ReloadConfig(config *Config) error {
c.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst))
return nil
}
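
For illustration, a minimal, self-contained sketch of the hot-swap pattern the change above relies on: an atomic.Value lets the RPC paths read the limiter without a mutex while ReloadConfig stores a replacement. The limiterHolder type and its names are invented for this sketch and are not part of Consul.

package main

import (
	"fmt"
	"sync/atomic"

	"golang.org/x/time/rate"
)

// limiterHolder mimics Client.rpcLimiter: many readers call Allow
// concurrently while a reload atomically swaps in a new limiter.
type limiterHolder struct {
	v atomic.Value // always holds a *rate.Limiter
}

func newLimiterHolder(r rate.Limit, burst int) *limiterHolder {
	h := &limiterHolder{}
	h.v.Store(rate.NewLimiter(r, burst))
	return h
}

// Allow loads the current limiter; the type assertion is safe because
// only *rate.Limiter values are ever stored.
func (h *limiterHolder) Allow() bool {
	return h.v.Load().(*rate.Limiter).Allow()
}

// Reload replaces the limiter without blocking in-flight readers.
func (h *limiterHolder) Reload(r rate.Limit, burst int) {
	h.v.Store(rate.NewLimiter(r, burst))
}

func main() {
	h := newLimiterHolder(500, 5000)
	fmt.Println(h.Allow()) // true
	h.Reload(1000, 10000)
	fmt.Println(h.Allow()) // still true, now under the new limits
}

Storing a fresh limiter rather than mutating the old one keeps the hot path to a single atomic load, which is why the reload above simply calls Store with a new rate.NewLimiter.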


@ -15,6 +15,8 @@ import (
"github.com/hashicorp/consul/testutil/retry"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func testClientConfig(t *testing.T) (string, *Config) {
@ -665,3 +667,25 @@ func TestClient_Encrypted(t *testing.T) {
t.Fatalf("should be encrypted")
}
}
func TestClient_Reload(t *testing.T) {
t.Parallel()
dir1, c := testClientWithConfig(t, func(c *Config) {
c.RPCRate = 500
c.RPCMaxBurst = 5000
})
defer os.RemoveAll(dir1)
defer c.Shutdown()
limiter := c.rpcLimiter.Load().(*rate.Limiter)
require.Equal(t, rate.Limit(500), limiter.Limit())
require.Equal(t, 5000, limiter.Burst())
c.config.RPCRate = 1000
c.config.RPCMaxBurst = 10000
require.NoError(t, c.ReloadConfig(c.config))
limiter = c.rpcLimiter.Load().(*rate.Limiter)
require.Equal(t, rate.Limit(1000), limiter.Limit())
require.Equal(t, 10000, limiter.Burst())
}


@ -8,6 +8,7 @@ import (
"time"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
@ -346,6 +347,13 @@ type Config struct {
// autopilot tasks, such as promoting eligible non-voters and removing
// dead servers.
AutopilotInterval time.Duration
// ConnectEnabled is whether to enable Connect features such as the CA.
ConnectEnabled bool
// CAConfig is used to apply the initial Connect CA configuration when
// bootstrapping.
CAConfig *structs.CAConfiguration
}
// CheckProtocolVersion validates the protocol version.
@ -425,6 +433,13 @@ func DefaultConfig() *Config {
ServerStabilizationTime: 10 * time.Second,
},
CAConfig: &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"RotationPeriod": "2160h",
},
},
ServerHealthInterval: 2 * time.Second,
AutopilotInterval: 10 * time.Second,
}


@ -0,0 +1,393 @@
package consul
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
)
var ErrConnectNotEnabled = errors.New("Connect must be enabled in order to use this endpoint")
// ConnectCA manages the Connect CA.
type ConnectCA struct {
// srv is a pointer back to the server.
srv *Server
}
// ConfigurationGet returns the configuration for the CA.
func (s *ConnectCA) ConfigurationGet(
args *structs.DCSpecificRequest,
reply *structs.CAConfiguration) error {
// Exit early if Connect hasn't been enabled.
if !s.srv.config.ConnectEnabled {
return ErrConnectNotEnabled
}
if done, err := s.srv.forward("ConnectCA.ConfigurationGet", args, args, reply); done {
return err
}
// This action requires operator read access.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.OperatorRead() {
return acl.ErrPermissionDenied
}
state := s.srv.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return err
}
*reply = *config
return nil
}
// ConfigurationSet updates the configuration for the CA.
func (s *ConnectCA) ConfigurationSet(
args *structs.CARequest,
reply *interface{}) error {
// Exit early if Connect hasn't been enabled.
if !s.srv.config.ConnectEnabled {
return ErrConnectNotEnabled
}
if done, err := s.srv.forward("ConnectCA.ConfigurationSet", args, args, reply); done {
return err
}
// This action requires operator write access.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.OperatorWrite() {
return acl.ErrPermissionDenied
}
// Exit early if it's a no-op change
state := s.srv.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return err
}
args.Config.ClusterID = config.ClusterID
if args.Config.Provider == config.Provider && reflect.DeepEqual(args.Config.Config, config.Config) {
return nil
}
// Create a new instance of the provider described by the config
// and get the current active root CA. This acts as a good validation
// of the config and makes sure the provider is functioning correctly
// before we commit any changes to Raft.
newProvider, err := s.srv.createCAProvider(args.Config)
if err != nil {
return fmt.Errorf("could not initialize provider: %v", err)
}
newRootPEM, err := newProvider.ActiveRoot()
if err != nil {
return err
}
newActiveRoot, err := parseCARoot(newRootPEM, args.Config.Provider)
if err != nil {
return err
}
// Compare the new provider's root CA ID to the current one. If they
// match, just update the existing provider with the new config.
// If they don't match, begin the root rotation process.
_, root, err := state.CARootActive(nil)
if err != nil {
return err
}
if root != nil && root.ID == newActiveRoot.ID {
args.Op = structs.CAOpSetConfig
resp, err := s.srv.raftApply(structs.ConnectCARequestType, args)
if err != nil {
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
// If the config has been committed, update the local provider instance
s.srv.setCAProvider(newProvider, newActiveRoot)
s.srv.logger.Printf("[INFO] connect: CA provider config updated")
return nil
}
// At this point, we know the config change has triggered a root rotation,
// either by swapping the provider type or changing the provider's config
// to use a different root certificate.
// If it's a config change that would trigger a rotation (different provider/root):
// 1. Get the root from the new provider.
// 2. Call CrossSignCA on the old provider to sign the new root with the old one to
// get a cross-signed certificate.
// 3. Take the active root for the new provider and append the intermediate from step 2
// to its list of intermediates.
// (A standalone sketch of this cross-signing flow follows this function.)
newRoot, err := connect.ParseCert(newRootPEM)
if err != nil {
return err
}
// Have the old provider cross-sign the new intermediate
oldProvider, _ := s.srv.getCAProvider()
if oldProvider == nil {
return fmt.Errorf("internal error: CA provider is nil")
}
xcCert, err := oldProvider.CrossSignCA(newRoot)
if err != nil {
return err
}
// Add the cross signed cert to the new root's intermediates.
newActiveRoot.IntermediateCerts = []string{xcCert}
intermediate, err := newProvider.GenerateIntermediate()
if err != nil {
return err
}
if intermediate != newRootPEM {
newActiveRoot.IntermediateCerts = append(newActiveRoot.IntermediateCerts, intermediate)
}
// Update the roots and CA config in the state store at the same time
idx, roots, err := state.CARoots(nil)
if err != nil {
return err
}
var newRoots structs.CARoots
for _, r := range roots {
newRoot := *r
if newRoot.Active {
newRoot.Active = false
}
newRoots = append(newRoots, &newRoot)
}
newRoots = append(newRoots, newActiveRoot)
args.Op = structs.CAOpSetRootsAndConfig
args.Index = idx
args.Roots = newRoots
resp, err := s.srv.raftApply(structs.ConnectCARequestType, args)
if err != nil {
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
// If the config has been committed, update the local provider instance
// and call teardown on the old provider
s.srv.setCAProvider(newProvider, newActiveRoot)
if err := oldProvider.Cleanup(); err != nil {
s.srv.logger.Printf("[WARN] connect: failed to clean up old provider %q", config.Provider)
}
s.srv.logger.Printf("[INFO] connect: CA rotated to new root under provider %q", args.Config.Provider)
return nil
}
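
To make the three rotation steps above concrete, here is a hedged, standalone sketch of cross-signing using only the standard library. Certificate names, serial numbers, and lifetimes are arbitrary, and real providers set more fields (key identifiers, URIs) than shown.

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"fmt"
	"math/big"
	"time"
)

// must parses a freshly created DER certificate, panicking on error;
// good enough for a sketch.
func must(der []byte, err error) *x509.Certificate {
	if err != nil {
		panic(err)
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		panic(err)
	}
	return cert
}

func caTemplate(cn string, serial int64) *x509.Certificate {
	return &x509.Certificate{
		SerialNumber:          big.NewInt(serial),
		Subject:               pkix.Name{CommonName: cn},
		NotBefore:             time.Now().Add(-time.Hour),
		NotAfter:              time.Now().Add(24 * time.Hour),
		IsCA:                  true,
		KeyUsage:              x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}
}

func main() {
	oldKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	newKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	leafKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)

	oldTmpl, newTmpl := caTemplate("old-root", 1), caTemplate("new-root", 2)
	oldRoot := must(x509.CreateCertificate(rand.Reader, oldTmpl, oldTmpl, &oldKey.PublicKey, oldKey))
	newRoot := must(x509.CreateCertificate(rand.Reader, newTmpl, newTmpl, &newKey.PublicKey, newKey))

	// Step 2: the old root signs a certificate carrying the new root's
	// subject and public key -- the cross-signed intermediate.
	xc := must(x509.CreateCertificate(rand.Reader, caTemplate("new-root", 3), oldRoot, &newKey.PublicKey, oldKey))

	// A leaf issued by the new root...
	leafTmpl := &x509.Certificate{
		SerialNumber: big.NewInt(4),
		Subject:      pkix.Name{CommonName: "leaf"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
	}
	leaf := must(x509.CreateCertificate(rand.Reader, leafTmpl, newRoot, &leafKey.PublicKey, newKey))

	// ...still chains to the OLD root when xc is presented as an
	// intermediate, which is why step 3 appends it to IntermediateCerts.
	roots, inters := x509.NewCertPool(), x509.NewCertPool()
	roots.AddCert(oldRoot)
	inters.AddCert(xc)
	_, err := leaf.Verify(x509.VerifyOptions{Roots: roots, Intermediates: inters})
	fmt.Println("chains via cross-signed intermediate:", err == nil)
}
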
// Roots returns the currently trusted root certificates.
func (s *ConnectCA) Roots(
args *structs.DCSpecificRequest,
reply *structs.IndexedCARoots) error {
// Forward if necessary
if done, err := s.srv.forward("ConnectCA.Roots", args, args, reply); done {
return err
}
// Load the ClusterID to generate TrustDomain. We do this outside the loop
// since by definition this value should be immutable once set for the lifetime of
// the cluster so we don't need to look it up more than once. We also don't
// have to worry about non-atomicity between the config fetch transaction and
// the CARoots transaction below since this field must remain immutable. Do
// not re-use this state/config for other logic that might care about changes
// of config during the blocking query below.
{
state := s.srv.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return err
}
// Check CA is actually bootstrapped...
if config != nil {
// Build TrustDomain based on the ClusterID stored.
signingID := connect.SpiffeIDSigningForCluster(config)
if signingID == nil {
// If the CA is bootstrapped at all then this should never happen, but be
// defensive.
return errors.New("no cluster trust domain setup")
}
reply.TrustDomain = signingID.Host()
}
}
return s.srv.blockingQuery(
&args.QueryOptions, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, roots, err := state.CARoots(ws)
if err != nil {
return err
}
reply.Index, reply.Roots = index, roots
if reply.Roots == nil {
reply.Roots = make(structs.CARoots, 0)
}
// The API response must NEVER contain the secret information
// such as keys and so on. We use a whitelist below to copy the
// specific fields we want to expose.
for i, r := range reply.Roots {
// IMPORTANT: r must NEVER be modified, since it is a pointer
// directly to the structure in the memdb store.
reply.Roots[i] = &structs.CARoot{
ID: r.ID,
Name: r.Name,
SerialNumber: r.SerialNumber,
SigningKeyID: r.SigningKeyID,
NotBefore: r.NotBefore,
NotAfter: r.NotAfter,
RootCert: r.RootCert,
IntermediateCerts: r.IntermediateCerts,
RaftIndex: r.RaftIndex,
Active: r.Active,
}
if r.Active {
reply.ActiveRootID = r.ID
}
}
return nil
},
)
}
// Sign signs a certificate for a service.
func (s *ConnectCA) Sign(
args *structs.CASignRequest,
reply *structs.IssuedCert) error {
// Exit early if Connect hasn't been enabled.
if !s.srv.config.ConnectEnabled {
return ErrConnectNotEnabled
}
if done, err := s.srv.forward("ConnectCA.Sign", args, args, reply); done {
return err
}
// Parse the CSR
csr, err := connect.ParseCSR(args.CSR)
if err != nil {
return err
}
// Parse the SPIFFE ID
spiffeID, err := connect.ParseCertURI(csr.URIs[0])
if err != nil {
return err
}
serviceID, ok := spiffeID.(*connect.SpiffeIDService)
if !ok {
return fmt.Errorf("SPIFFE ID in CSR must be a service ID")
}
provider, caRoot := s.srv.getCAProvider()
if provider == nil {
return fmt.Errorf("internal error: CA provider is nil")
}
// Verify that the CSR entity is in the cluster's trust domain
state := s.srv.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return err
}
signingID := connect.SpiffeIDSigningForCluster(config)
if !signingID.CanSign(serviceID) {
return fmt.Errorf("SPIFFE ID in CSR from a different trust domain: %s, "+
"we are %s", serviceID.Host, signingID.Host())
}
// Verify that the ACL token provided has permission to act as this service
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.ServiceWrite(serviceID.Service, nil) {
return acl.ErrPermissionDenied
}
// Verify that the DC in the service URI matches us. We might relax this
// requirement later but being restrictive for now is safer.
if serviceID.Datacenter != s.srv.config.Datacenter {
return fmt.Errorf("SPIFFE ID in CSR from a different datacenter: %s, "+
"we are %s", serviceID.Datacenter, s.srv.config.Datacenter)
}
// All seems to be in order, actually sign it.
pem, err := provider.Sign(csr)
if err != nil {
return err
}
// Append any intermediates needed by this root.
for _, p := range caRoot.IntermediateCerts {
pem = strings.TrimSpace(pem) + "\n" + p
}
// TODO(banks): when we implement IssuedCerts table we can use the insert to
// that as the raft index to return in response. Right now we can rely on only
// the built-in provider being supported and the implementation detail that we
// have to write a SerialIndex update to the provider config table for every
// cert issued so in all cases this index will be higher than any previous
// sign response. This has to be reloaded after the provider.Sign call to
// observe the index update.
state = s.srv.fsm.State()
modIdx, _, err := state.CAConfig()
if err != nil {
return err
}
cert, err := connect.ParseCert(pem)
if err != nil {
return err
}
// Set the response
*reply = structs.IssuedCert{
SerialNumber: connect.HexString(cert.SerialNumber.Bytes()),
CertPEM: pem,
Service: serviceID.Service,
ServiceURI: cert.URIs[0].String(),
ValidAfter: cert.NotBefore,
ValidBefore: cert.NotAfter,
RaftIndex: structs.RaftIndex{
ModifyIndex: modIdx,
CreateIndex: modIdx,
},
}
return nil
}
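
For reference, a small standalone sketch of the CSR shape Sign expects: a single SPIFFE service URI carried in the CSR's URI SANs. The URI below follows Consul's service SPIFFE ID layout; the trust-domain UUID is made up.

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"fmt"
	"net/url"
)

func main() {
	// A SPIFFE service URI of the kind ConnectCA.Sign reads out of
	// csr.URIs[0]; the trust-domain UUID here is illustrative.
	spiffe, err := url.Parse("spiffe://11111111-2222-3333-4444-555555555555.consul/ns/default/dc/dc1/svc/web")
	if err != nil {
		panic(err)
	}

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	der, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{
		URIs: []*url.URL{spiffe},
	}, key)
	if err != nil {
		panic(err)
	}

	// The endpoint rejects CSRs whose trust domain or datacenter does
	// not match the cluster's, and requires a service (not agent) ID.
	csr, err := x509.ParseCertificateRequest(der)
	if err != nil {
		panic(err)
	}
	fmt.Println(csr.URIs[0])
}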


@ -0,0 +1,434 @@
package consul
import (
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/connect"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/stretchr/testify/assert"
)
func testParseCert(t *testing.T, pemValue string) *x509.Certificate {
cert, err := connect.ParseCert(pemValue)
if err != nil {
t.Fatal(err)
}
return cert
}
// Test listing root CAs.
func TestConnectCARoots(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Insert some CAs
state := s1.fsm.State()
ca1 := connect.TestCA(t, nil)
ca2 := connect.TestCA(t, nil)
ca2.Active = false
idx, _, err := state.CARoots(nil)
require.NoError(err)
ok, err := state.CARootSetCAS(idx, idx, []*structs.CARoot{ca1, ca2})
assert.True(ok)
require.NoError(err)
_, caCfg, err := state.CAConfig()
require.NoError(err)
// Request
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.IndexedCARoots
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
// Verify
assert.Equal(ca1.ID, reply.ActiveRootID)
assert.Len(reply.Roots, 2)
for _, r := range reply.Roots {
// These must never be set, for security
assert.Equal("", r.SigningCert)
assert.Equal("", r.SigningKey)
}
assert.Equal(fmt.Sprintf("%s.consul", caCfg.ClusterID), reply.TrustDomain)
}
func TestConnectCAConfig_GetSet(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Get the starting config
{
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.CAConfiguration
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
assert.NoError(err)
expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config)
assert.NoError(err)
assert.Equal(reply.Provider, s1.config.CAConfig.Provider)
assert.Equal(actual, expected)
}
// Update a config value
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": "",
"RootCert": "",
"RotationPeriod": 180 * 24 * time.Hour,
},
}
{
args := &structs.CARequest{
Datacenter: "dc1",
Config: newConfig,
}
var reply interface{}
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Verify the new config was set
{
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.CAConfiguration
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
assert.NoError(err)
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
assert.NoError(err)
assert.Equal(reply.Provider, newConfig.Provider)
assert.Equal(actual, expected)
}
}
func TestConnectCAConfig_TriggerRotation(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Store the current root
rootReq := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
assert.Len(rootList.Roots, 1)
oldRoot := rootList.Roots[0]
// Update the provider config to use a new private key, which should
// cause a rotation.
_, newKey, err := connect.GeneratePrivateKey()
assert.NoError(err)
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": newKey,
"RootCert": "",
"RotationPeriod": 90 * 24 * time.Hour,
},
}
{
args := &structs.CARequest{
Datacenter: "dc1",
Config: newConfig,
}
var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Make sure the new root has been added along with an intermediate
// cross-signed by the old root.
var newRootPEM string
{
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
assert.Len(reply.Roots, 2)
for _, r := range reply.Roots {
if r.ID == oldRoot.ID {
// The old root should no longer be marked as the active root,
// and none of its other fields should have changed.
assert.False(r.Active)
assert.Equal(r.Name, oldRoot.Name)
assert.Equal(r.RootCert, oldRoot.RootCert)
assert.Equal(r.SigningCert, oldRoot.SigningCert)
assert.Equal(r.IntermediateCerts, oldRoot.IntermediateCerts)
} else {
newRootPEM = r.RootCert
// The new root should have a valid cross-signed cert from the old
// root as an intermediate.
assert.True(r.Active)
assert.Len(r.IntermediateCerts, 1)
xc := testParseCert(t, r.IntermediateCerts[0])
oldRootCert := testParseCert(t, oldRoot.RootCert)
newRootCert := testParseCert(t, r.RootCert)
// Should have the authority key ID and signature algo of the
// (old) signing CA.
assert.Equal(xc.AuthorityKeyId, oldRootCert.AuthorityKeyId)
assert.NotEqual(xc.SubjectKeyId, oldRootCert.SubjectKeyId)
assert.Equal(xc.SignatureAlgorithm, oldRootCert.SignatureAlgorithm)
// The common name and SAN should not have changed.
assert.Equal(xc.Subject.CommonName, newRootCert.Subject.CommonName)
assert.Equal(xc.URIs, newRootCert.URIs)
}
}
}
// Verify the new config was set.
{
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.CAConfiguration
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
require.NoError(err)
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
require.NoError(err)
assert.Equal(reply.Provider, newConfig.Provider)
assert.Equal(actual, expected)
}
// Verify that new leaf certs get the cross-signed intermediate bundled
{
// Generate a CSR and request signing
spiffeId := connect.TestSpiffeIDService(t, "web")
csr, _ := connect.TestCSR(t, spiffeId)
args := &structs.CASignRequest{
Datacenter: "dc1",
CSR: csr,
}
var reply structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
// Verify that the cert is signed by the new CA
{
roots := x509.NewCertPool()
require.True(roots.AppendCertsFromPEM([]byte(newRootPEM)))
leaf, err := connect.ParseCert(reply.CertPEM)
require.NoError(err)
_, err = leaf.Verify(x509.VerifyOptions{
Roots: roots,
})
require.NoError(err)
}
// And that it validates via the intermediate
{
roots := x509.NewCertPool()
assert.True(roots.AppendCertsFromPEM([]byte(oldRoot.RootCert)))
leaf, err := connect.ParseCert(reply.CertPEM)
require.NoError(err)
// Make sure the intermediate was returned as well as the leaf
_, rest := pem.Decode([]byte(reply.CertPEM))
require.NotEmpty(rest)
intermediates := x509.NewCertPool()
require.True(intermediates.AppendCertsFromPEM(rest))
_, err = leaf.Verify(x509.VerifyOptions{
Roots: roots,
Intermediates: intermediates,
})
require.NoError(err)
}
// Verify other fields
assert.Equal("web", reply.Service)
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
}
}
// Test CA signing
func TestConnectCASign(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Generate a CSR and request signing
spiffeId := connect.TestSpiffeIDService(t, "web")
csr, _ := connect.TestCSR(t, spiffeId)
args := &structs.CASignRequest{
Datacenter: "dc1",
CSR: csr,
}
var reply structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
// Get the current CA
state := s1.fsm.State()
_, ca, err := state.CARootActive(nil)
require.NoError(err)
// Verify that the cert is signed by the CA
roots := x509.NewCertPool()
assert.True(roots.AppendCertsFromPEM([]byte(ca.RootCert)))
leaf, err := connect.ParseCert(reply.CertPEM)
require.NoError(err)
_, err = leaf.Verify(x509.VerifyOptions{
Roots: roots,
})
require.NoError(err)
// Verify other fields
assert.Equal("web", reply.Service)
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
}
func TestConnectCASignValidation(t *testing.T) {
t.Parallel()
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLMasterToken = "root"
c.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create an ACL token with service:write for web*
var webToken string
{
arg := structs.ACLRequest{
Datacenter: "dc1",
Op: structs.ACLSet,
ACL: structs.ACL{
Name: "User token",
Type: structs.ACLTypeClient,
Rules: `
service "web" {
policy = "write"
}`,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &webToken))
}
testWebID := connect.TestSpiffeIDService(t, "web")
tests := []struct {
name string
id connect.CertURI
wantErr string
}{
{
name: "different cluster",
id: &connect.SpiffeIDService{
Host: "55555555-4444-3333-2222-111111111111.consul",
Namespace: testWebID.Namespace,
Datacenter: testWebID.Datacenter,
Service: testWebID.Service,
},
wantErr: "different trust domain",
},
{
name: "same cluster should validate",
id: testWebID,
wantErr: "",
},
{
name: "same cluster, CSR for a different DC should NOT validate",
id: &connect.SpiffeIDService{
Host: testWebID.Host,
Namespace: testWebID.Namespace,
Datacenter: "dc2",
Service: testWebID.Service,
},
wantErr: "different datacenter",
},
{
name: "same cluster and DC, different service should not have perms",
id: &connect.SpiffeIDService{
Host: testWebID.Host,
Namespace: testWebID.Namespace,
Datacenter: testWebID.Datacenter,
Service: "db",
},
wantErr: "Permission denied",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
csr, _ := connect.TestCSR(t, tt.id)
args := &structs.CASignRequest{
Datacenter: "dc1",
CSR: csr,
WriteRequest: structs.WriteRequest{Token: webToken},
}
var reply structs.IssuedCert
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply)
if tt.wantErr == "" {
require.NoError(t, err)
// No further validation here; that is handled in other tests
} else {
require.Error(t, err)
require.Contains(t, err.Error(), tt.wantErr)
}
})
}
}


@ -0,0 +1,28 @@
package consul
import (
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
)
// consulCADelegate provides callbacks for the Consul CA provider
// to use the state store for its operations.
type consulCADelegate struct {
srv *Server
}
func (c *consulCADelegate) State() *state.Store {
return c.srv.fsm.State()
}
func (c *consulCADelegate) ApplyCARequest(req *structs.CARequest) error {
resp, err := c.srv.raftApply(structs.ConnectCARequestType, req)
if err != nil {
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
return nil
}


@ -20,6 +20,8 @@ func init() {
registerCommand(structs.PreparedQueryRequestType, (*FSM).applyPreparedQueryOperation)
registerCommand(structs.TxnRequestType, (*FSM).applyTxn)
registerCommand(structs.AutopilotRequestType, (*FSM).applyAutopilotUpdate)
registerCommand(structs.IntentionRequestType, (*FSM).applyIntentionOperation)
registerCommand(structs.ConnectCARequestType, (*FSM).applyConnectCAOperation)
}
func (c *FSM) applyRegister(buf []byte, index uint64) interface{} {
@ -246,3 +248,85 @@ func (c *FSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} {
}
return c.state.AutopilotSetConfig(index, &req.Config)
}
// applyIntentionOperation applies the given intention operation to the state store.
func (c *FSM) applyIntentionOperation(buf []byte, index uint64) interface{} {
var req structs.IntentionRequest
if err := structs.Decode(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "intention"}, time.Now(),
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
defer metrics.MeasureSinceWithLabels([]string{"fsm", "intention"}, time.Now(),
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
switch req.Op {
case structs.IntentionOpCreate, structs.IntentionOpUpdate:
return c.state.IntentionSet(index, req.Intention)
case structs.IntentionOpDelete:
return c.state.IntentionDelete(index, req.Intention.ID)
default:
c.logger.Printf("[WARN] consul.fsm: Invalid Intention operation '%s'", req.Op)
return fmt.Errorf("Invalid Intention operation '%s'", req.Op)
}
}
// applyConnectCAOperation applies the given CA operation to the state store.
func (c *FSM) applyConnectCAOperation(buf []byte, index uint64) interface{} {
var req structs.CARequest
if err := structs.Decode(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "ca"}, time.Now(),
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
defer metrics.MeasureSinceWithLabels([]string{"fsm", "ca"}, time.Now(),
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
switch req.Op {
case structs.CAOpSetConfig:
if req.Config.ModifyIndex != 0 {
act, err := c.state.CACheckAndSetConfig(index, req.Config.ModifyIndex, req.Config)
if err != nil {
return err
}
return act
}
return c.state.CASetConfig(index, req.Config)
case structs.CAOpSetRoots:
act, err := c.state.CARootSetCAS(index, req.Index, req.Roots)
if err != nil {
return err
}
return act
case structs.CAOpSetProviderState:
act, err := c.state.CASetProviderState(index, req.ProviderState)
if err != nil {
return err
}
return act
case structs.CAOpDeleteProviderState:
if err := c.state.CADeleteProviderState(req.ProviderState.ID); err != nil {
return err
}
return true
case structs.CAOpSetRootsAndConfig:
act, err := c.state.CARootSetCAS(index, req.Index, req.Roots)
if err != nil {
return err
}
if err := c.state.CASetConfig(index+1, req.Config); err != nil {
return err
}
return act
default:
c.logger.Printf("[WARN] consul.fsm: Invalid CA operation '%s'", req.Op)
return fmt.Errorf("Invalid CA operation '%s'", req.Op)
}
}
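
A compact sketch of the check-and-set convention CAOpSetConfig uses above; the types are simplified stand-ins for structs.CAConfiguration and the state store's CACheckAndSetConfig, not Consul's own.

package main

import "fmt"

type config struct {
	Provider    string
	ModifyIndex uint64
}

type store struct{ current config }

// checkAndSet applies c only if the stored config's ModifyIndex still
// equals expected, mirroring the non-zero ModifyIndex branch above.
func (s *store) checkAndSet(raftIndex, expected uint64, c config) bool {
	if s.current.ModifyIndex != expected {
		return false // lost the race; the caller must re-read and retry
	}
	c.ModifyIndex = raftIndex
	s.current = c
	return true
}

func main() {
	s := &store{current: config{Provider: "consul", ModifyIndex: 5}}
	fmt.Println(s.checkAndSet(6, 4, config{Provider: "static"})) // false: stale index
	fmt.Println(s.checkAndSet(6, 5, config{Provider: "static"})) // true
}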


@ -8,13 +8,16 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/serf/coordinate"
"github.com/mitchellh/mapstructure"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func generateUUID() (ret string) {
@ -1148,3 +1151,209 @@ func TestFSM_Autopilot(t *testing.T) {
t.Fatalf("bad: %v", config.CleanupDeadServers)
}
}
func TestFSM_Intention_CRUD(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
assert.Nil(err)
// Create a new intention.
ixn := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
}
ixn.Intention.ID = generateUUID()
ixn.Intention.UpdatePrecedence()
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
}
// Verify it's in the state store.
{
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt
actual.UpdatedAt = ixn.Intention.UpdatedAt
assert.Equal(ixn.Intention, actual)
}
// Make an update
ixn.Op = structs.IntentionOpUpdate
ixn.Intention.SourceName = "api"
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
}
// Verify the update.
{
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt
actual.UpdatedAt = ixn.Intention.UpdatedAt
assert.Equal(ixn.Intention, actual)
}
// Delete
ixn.Op = structs.IntentionOpDelete
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
}
// Make sure it's gone.
{
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
assert.Nil(actual)
}
}
func TestFSM_CAConfig(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
assert.Nil(err)
// Set the CA config using a request.
req := structs.CARequest{
Op: structs.CAOpSetConfig,
Config: &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": "asdf",
"RootCert": "qwer",
"RotationPeriod": 90 * 24 * time.Hour,
},
},
}
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
resp := fsm.Apply(makeLog(buf))
if _, ok := resp.(error); ok {
t.Fatalf("bad: %v", resp)
}
// Verify the config is set directly in the state store.
_, config, err := fsm.state.CAConfig()
if err != nil {
t.Fatalf("err: %v", err)
}
var conf *structs.ConsulCAProviderConfig
if err := mapstructure.WeakDecode(config.Config, &conf); err != nil {
t.Fatalf("error decoding config: %s, %v", err, config.Config)
}
if got, want := config.Provider, req.Config.Provider; got != want {
t.Fatalf("got %v, want %v", got, want)
}
if got, want := conf.PrivateKey, "asdf"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
if got, want := conf.RootCert, "qwer"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
if got, want := conf.RotationPeriod, 90*24*time.Hour; got != want {
t.Fatalf("got %v, want %v", got, want)
}
// Now use CAS and provide an old index
req.Config.Provider = "static"
req.Config.ModifyIndex = config.ModifyIndex - 1
buf, err = structs.Encode(structs.ConnectCARequestType, req)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = fsm.Apply(makeLog(buf))
if _, ok := resp.(error); ok {
t.Fatalf("bad: %v", resp)
}
_, config, err = fsm.state.CAConfig()
assert.Nil(err)
if config.Provider != "static" {
t.Fatalf("bad: %v", config.Provider)
}
}
func TestFSM_CARoots(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
assert.Nil(err)
// Roots
ca1 := connect.TestCA(t, nil)
ca2 := connect.TestCA(t, nil)
ca2.Active = false
// Create a new request.
req := structs.CARequest{
Op: structs.CAOpSetRoots,
Roots: []*structs.CARoot{ca1, ca2},
}
{
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
assert.True(fsm.Apply(makeLog(buf)).(bool))
}
// Verify it's in the state store.
{
_, roots, err := fsm.state.CARoots(nil)
assert.Nil(err)
assert.Len(roots, 2)
}
}
func TestFSM_CABuiltinProvider(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
assert.Nil(err)
// Provider state.
expected := &structs.CAConsulProviderState{
ID: "foo",
PrivateKey: "a",
RootCert: "b",
RaftIndex: structs.RaftIndex{
CreateIndex: 1,
ModifyIndex: 1,
},
}
// Create a new request.
req := structs.CARequest{
Op: structs.CAOpSetProviderState,
ProviderState: expected,
}
{
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
assert.True(fsm.Apply(makeLog(buf)).(bool))
}
// Verify it's in the state store.
{
_, state, err := fsm.state.CAProviderState("foo")
assert.Nil(err)
assert.Equal(expected, state)
}
}


@ -20,6 +20,8 @@ func init() {
registerRestorer(structs.CoordinateBatchUpdateType, restoreCoordinates)
registerRestorer(structs.PreparedQueryRequestType, restorePreparedQuery)
registerRestorer(structs.AutopilotRequestType, restoreAutopilot)
registerRestorer(structs.IntentionRequestType, restoreIntention)
registerRestorer(structs.ConnectCARequestType, restoreConnectCA)
}
func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error {
@ -44,6 +46,12 @@ func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) err
if err := s.persistAutopilot(sink, encoder); err != nil {
return err
}
if err := s.persistIntentions(sink, encoder); err != nil {
return err
}
if err := s.persistConnectCA(sink, encoder); err != nil {
return err
}
return nil
}
@ -258,6 +266,42 @@ func (s *snapshot) persistAutopilot(sink raft.SnapshotSink,
return nil
}
func (s *snapshot) persistConnectCA(sink raft.SnapshotSink,
encoder *codec.Encoder) error {
roots, err := s.state.CARoots()
if err != nil {
return err
}
for _, r := range roots {
if _, err := sink.Write([]byte{byte(structs.ConnectCARequestType)}); err != nil {
return err
}
if err := encoder.Encode(r); err != nil {
return err
}
}
return nil
}
func (s *snapshot) persistIntentions(sink raft.SnapshotSink,
encoder *codec.Encoder) error {
ixns, err := s.state.Intentions()
if err != nil {
return err
}
for _, ixn := range ixns {
if _, err := sink.Write([]byte{byte(structs.IntentionRequestType)}); err != nil {
return err
}
if err := encoder.Encode(ixn); err != nil {
return err
}
}
return nil
}
func restoreRegistration(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req structs.RegisterRequest
if err := decoder.Decode(&req); err != nil {
@ -364,3 +408,25 @@ func restoreAutopilot(header *snapshotHeader, restore *state.Restore, decoder *c
}
return nil
}
func restoreIntention(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req structs.Intention
if err := decoder.Decode(&req); err != nil {
return err
}
if err := restore.Intention(&req); err != nil {
return err
}
return nil
}
func restoreConnectCA(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req structs.CARoot
if err := decoder.Decode(&req); err != nil {
return err
}
if err := restore.CARoot(&req); err != nil {
return err
}
return nil
}
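
The persist/restore pair above relies on a simple framing: one message-type byte, then the encoded record. Below is a standalone sketch of that framing, with gob standing in for the msgpack codec and an invented type-byte value.

package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

// caRoot is a cut-down stand-in for structs.CARoot.
type caRoot struct {
	ID     string
	Active bool
}

// Illustrative value only; the real message-type bytes live in structs.
const connectCARequestType = byte(14)

func main() {
	var sink bytes.Buffer

	// Persist: a type byte, then the record -- the same framing
	// persistConnectCA writes for each root.
	enc := gob.NewEncoder(&sink)
	sink.WriteByte(connectCARequestType)
	if err := enc.Encode(caRoot{ID: "root-1", Active: true}); err != nil {
		panic(err)
	}

	// Restore: read the type byte to pick a restorer, then decode.
	typ, err := sink.ReadByte()
	if err != nil {
		panic(err)
	}
	var r caRoot
	if err := gob.NewDecoder(&sink).Decode(&r); err != nil {
		panic(err)
	}
	fmt.Println(typ, r.ID, r.Active)
}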


@ -7,16 +7,20 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func TestFSM_SnapshotRestore_OSS(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
if err != nil {
t.Fatalf("err: %v", err)
@ -98,6 +102,27 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
t.Fatalf("err: %s", err)
}
// Intentions
ixn := structs.TestIntention(t)
ixn.ID = generateUUID()
ixn.RaftIndex = structs.RaftIndex{
CreateIndex: 14,
ModifyIndex: 14,
}
assert.Nil(fsm.state.IntentionSet(14, ixn))
// CA Roots
roots := []*structs.CARoot{
connect.TestCA(t, nil),
connect.TestCA(t, nil),
}
for _, r := range roots[1:] {
r.Active = false
}
ok, err := fsm.state.CARootSetCAS(15, 0, roots)
assert.Nil(err)
assert.True(ok)
// Snapshot
snap, err := fsm.Snapshot()
if err != nil {
@ -260,6 +285,17 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
t.Fatalf("bad: %#v, %#v", restoredConf, autopilotConf)
}
// Verify intentions are restored.
_, ixns, err := fsm2.state.Intentions(nil)
assert.Nil(err)
assert.Len(ixns, 1)
assert.Equal(ixn, ixns[0])
// Verify CA roots are restored.
_, roots, err = fsm2.state.CARoots(nil)
assert.Nil(err)
assert.Len(roots, 2)
// Snapshot
snap, err = fsm2.Snapshot()
if err != nil {


@ -111,18 +111,37 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
return fmt.Errorf("Must provide service name")
}
// Determine the function we'll call
var f func(memdb.WatchSet, *state.Store, *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error)
switch {
case args.Connect:
f = h.serviceNodesConnect
case args.TagFilter:
f = h.serviceNodesTagFilter
default:
f = h.serviceNodesDefault
}
// If we're doing a connect query, we need read access to the service
// we're trying to find proxies for, so check that.
if args.Connect {
// Fetch the ACL token, if any.
rule, err := h.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.ServiceRead(args.ServiceName) {
// Just return nil, which will return an empty response (tested)
return nil
}
}
err := h.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
var index uint64
var nodes structs.CheckServiceNodes
var err error
if args.TagFilter {
index, nodes, err = state.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
} else {
index, nodes, err = state.CheckServiceNodes(ws, args.ServiceName)
}
index, nodes, err := f(ws, state, args)
if err != nil {
return err
}
@ -139,16 +158,37 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
// Provide some metrics
if err == nil {
metrics.IncrCounterWithLabels([]string{"health", "service", "query"}, 1,
// For metrics, we separate Connect-based lookups from non-Connect
key := "service"
if args.Connect {
key = "connect"
}
metrics.IncrCounterWithLabels([]string{"health", key, "query"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
if args.ServiceTag != "" {
metrics.IncrCounterWithLabels([]string{"health", "service", "query-tag"}, 1,
metrics.IncrCounterWithLabels([]string{"health", key, "query-tag"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}})
}
if len(reply.Nodes) == 0 {
metrics.IncrCounterWithLabels([]string{"health", "service", "not-found"}, 1,
metrics.IncrCounterWithLabels([]string{"health", key, "not-found"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
}
}
return err
}
// The serviceNodes* functions below are the various lookup methods that
// can be used by the ServiceNodes endpoint.
func (h *Health) serviceNodesConnect(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
return s.CheckConnectServiceNodes(ws, args.ServiceName)
}
func (h *Health) serviceNodesTagFilter(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
return s.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
}
func (h *Health) serviceNodesDefault(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
return s.CheckServiceNodes(ws, args.ServiceName)
}
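
The function-value dispatch ServiceNodes uses above, reduced to a runnable sketch; request and the lookup bodies are simplified stand-ins, not Consul's types.

package main

import "fmt"

type request struct {
	Connect     bool
	TagFilter   bool
	ServiceName string
}

func main() {
	lookupConnect := func(r request) string { return "connect nodes for " + r.ServiceName }
	lookupTag := func(r request) string { return "tag-filtered nodes for " + r.ServiceName }
	lookupDefault := func(r request) string { return "all nodes for " + r.ServiceName }

	args := request{Connect: true, ServiceName: "web"}

	// Pick the lookup once, up front, so the blocking-query closure stays
	// trivial -- the same shape as the switch in ServiceNodes.
	var f func(request) string
	switch {
	case args.Connect:
		f = lookupConnect
	case args.TagFilter:
		f = lookupTag
	default:
		f = lookupDefault
	}
	fmt.Println(f(args)) // connect nodes for web
}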


@ -10,6 +10,7 @@ import (
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/stretchr/testify/assert"
)
func TestHealth_ChecksInState(t *testing.T) {
@ -821,6 +822,106 @@ func TestHealth_ServiceNodes_DistanceSort(t *testing.T) {
}
}
func TestHealth_ServiceNodes_ConnectProxy_ACL(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLMasterToken = "root"
c.ACLDefaultPolicy = "deny"
c.ACLEnforceVersion8 = false
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create the ACL.
arg := structs.ACLRequest{
Datacenter: "dc1",
Op: structs.ACLSet,
ACL: structs.ACL{
Name: "User token",
Type: structs.ACLTypeClient,
Rules: `
service "foo" {
policy = "write"
}
`,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var token string
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", arg, &token))
{
var out struct{}
// Register a proxy
args := structs.TestRegisterRequestProxy(t)
args.WriteRequest.Token = "root"
args.Service.ID = "foo-proxy-0"
args.Service.Service = "foo-proxy"
args.Service.ProxyDestination = "bar"
args.Check = &structs.HealthCheck{
Name: "proxy",
Status: api.HealthPassing,
ServiceID: args.Service.ID,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a proxy
args = structs.TestRegisterRequestProxy(t)
args.WriteRequest.Token = "root"
args.Service.Service = "foo-proxy"
args.Service.ProxyDestination = "foo"
args.Check = &structs.HealthCheck{
Name: "proxy",
Status: api.HealthPassing,
ServiceID: args.Service.Service,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a proxy
args = structs.TestRegisterRequestProxy(t)
args.WriteRequest.Token = "root"
args.Service.Service = "another-proxy"
args.Service.ProxyDestination = "foo"
args.Check = &structs.HealthCheck{
Name: "proxy",
Status: api.HealthPassing,
ServiceID: args.Service.Service,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
}
// List w/ token. This should return an empty response because we don't
// have permission to read "bar"
req := structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: "bar",
QueryOptions: structs.QueryOptions{Token: token},
}
var resp structs.IndexedCheckServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(resp.Nodes, 0)
// List w/ token. This should work since we're requesting "foo", but should
// also only contain the proxies with names that adhere to our ACL.
req = structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: "foo",
QueryOptions: structs.QueryOptions{Token: token},
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(resp.Nodes, 1)
}
func TestHealth_NodeChecks_FilterACL(t *testing.T) {
t.Parallel()
dir, token, srv, codec := testACLFilterServer(t)


@ -0,0 +1,358 @@
package consul
import (
"errors"
"fmt"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-uuid"
)
var (
// ErrIntentionNotFound is returned if the intention lookup failed.
ErrIntentionNotFound = errors.New("Intention not found")
)
// Intention manages the Connect intentions.
type Intention struct {
// srv is a pointer back to the server.
srv *Server
}
// Apply creates or updates an intention in the data store.
func (s *Intention) Apply(
args *structs.IntentionRequest,
reply *string) error {
if done, err := s.srv.forward("Intention.Apply", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"consul", "intention", "apply"}, time.Now())
defer metrics.MeasureSince([]string{"intention", "apply"}, time.Now())
// Always set a non-nil intention to avoid nil-access below
if args.Intention == nil {
args.Intention = &structs.Intention{}
}
// If no ID is provided, generate a new ID. This must be done prior to
// appending to the Raft log, because the ID is not deterministic. Once
// the entry is in the log, the state update MUST be deterministic or
// the followers will not converge.
if args.Op == structs.IntentionOpCreate {
if args.Intention.ID != "" {
return fmt.Errorf("ID must be empty when creating a new intention")
}
state := s.srv.fsm.State()
for {
var err error
args.Intention.ID, err = uuid.GenerateUUID()
if err != nil {
s.srv.logger.Printf("[ERR] consul.intention: UUID generation failed: %v", err)
return err
}
_, ixn, err := state.IntentionGet(nil, args.Intention.ID)
if err != nil {
s.srv.logger.Printf("[ERR] consul.intention: intention lookup failed: %v", err)
return err
}
if ixn == nil {
break
}
}
// Set the CreatedAt timestamp
args.Intention.CreatedAt = time.Now().UTC()
}
*reply = args.Intention.ID
// Get the ACL token for the request for the checks below.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
// Perform the ACL check
if prefix, ok := args.Intention.GetACLPrefix(); ok {
if rule != nil && !rule.IntentionWrite(prefix) {
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention '%s' denied due to ACLs", args.Intention.ID)
return acl.ErrPermissionDenied
}
}
// If this is not a create, then we have to verify the ID.
if args.Op != structs.IntentionOpCreate {
state := s.srv.fsm.State()
_, ixn, err := state.IntentionGet(nil, args.Intention.ID)
if err != nil {
return fmt.Errorf("Intention lookup failed: %v", err)
}
if ixn == nil {
return fmt.Errorf("Cannot modify non-existent intention: '%s'", args.Intention.ID)
}
// Perform the ACL check that we have write access to the old prefix too,
// which must be true to perform any rename.
if prefix, ok := ixn.GetACLPrefix(); ok {
if rule != nil && !rule.IntentionWrite(prefix) {
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention '%s' denied due to ACLs", args.Intention.ID)
return acl.ErrPermissionDenied
}
}
}
// We always update the UpdatedAt field. This has no effect for deletion.
args.Intention.UpdatedAt = time.Now().UTC()
// Default source type
if args.Intention.SourceType == "" {
args.Intention.SourceType = structs.IntentionSourceConsul
}
// Until we support namespaces, we force all namespaces to be default
if args.Intention.SourceNS == "" {
args.Intention.SourceNS = structs.IntentionDefaultNamespace
}
if args.Intention.DestinationNS == "" {
args.Intention.DestinationNS = structs.IntentionDefaultNamespace
}
// Validate. We do not validate on delete since it is valid to only
// send an ID in that case.
if args.Op != structs.IntentionOpDelete {
// Set the precedence
args.Intention.UpdatePrecedence()
if err := args.Intention.Validate(); err != nil {
return err
}
}
// Commit
resp, err := s.srv.raftApply(structs.IntentionRequestType, args)
if err != nil {
s.srv.logger.Printf("[ERR] consul.intention: Apply failed %v", err)
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
return nil
}
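
A minimal sketch of the "generate the ID before the Raft log" rule from Apply above: the leader picks a UUID up front, retrying on the vanishingly rare collision, so that replaying the log entry is deterministic on every follower. The map stands in for the state-store lookup.

package main

import (
	"fmt"

	"github.com/hashicorp/go-uuid"
)

// newIntentionID retries until it finds an unused ID; existing stands in
// for the IntentionGet lookup Apply performs.
func newIntentionID(existing map[string]bool) (string, error) {
	for {
		id, err := uuid.GenerateUUID()
		if err != nil {
			return "", err
		}
		if !existing[id] {
			return id, nil
		}
	}
}

func main() {
	// The ID is fixed before the request is appended to the Raft log.
	id, err := newIntentionID(map[string]bool{})
	if err != nil {
		panic(err)
	}
	fmt.Println(id)
}
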
// Get returns a single intention by ID.
func (s *Intention) Get(
args *structs.IntentionQueryRequest,
reply *structs.IndexedIntentions) error {
// Forward if necessary
if done, err := s.srv.forward("Intention.Get", args, args, reply); done {
return err
}
return s.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, ixn, err := state.IntentionGet(ws, args.IntentionID)
if err != nil {
return err
}
if ixn == nil {
return ErrIntentionNotFound
}
reply.Index = index
reply.Intentions = structs.Intentions{ixn}
// Filter
if err := s.srv.filterACL(args.Token, reply); err != nil {
return err
}
// If ACLs prevented any responses, error
if len(reply.Intentions) == 0 {
s.srv.logger.Printf("[WARN] consul.intention: Request to get intention '%s' denied due to ACLs", args.IntentionID)
return acl.ErrPermissionDenied
}
return nil
},
)
}
// List returns all the intentions.
func (s *Intention) List(
args *structs.DCSpecificRequest,
reply *structs.IndexedIntentions) error {
// Forward if necessary
if done, err := s.srv.forward("Intention.List", args, args, reply); done {
return err
}
return s.srv.blockingQuery(
&args.QueryOptions, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, ixns, err := state.Intentions(ws)
if err != nil {
return err
}
reply.Index, reply.Intentions = index, ixns
if reply.Intentions == nil {
reply.Intentions = make(structs.Intentions, 0)
}
return s.srv.filterACL(args.Token, reply)
},
)
}
// Match returns the set of intentions that match the given source/destination.
func (s *Intention) Match(
args *structs.IntentionQueryRequest,
reply *structs.IndexedIntentionMatches) error {
// Forward if necessary
if done, err := s.srv.forward("Intention.Match", args, args, reply); done {
return err
}
// Get the ACL token for the request for the checks below.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil {
// We go through each entry and test the destination to check if it
// matches.
for _, entry := range args.Match.Entries {
if prefix := entry.Name; prefix != "" && !rule.IntentionRead(prefix) {
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention prefix '%s' denied due to ACLs", prefix)
return acl.ErrPermissionDenied
}
}
}
return s.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, matches, err := state.IntentionMatch(ws, args.Match)
if err != nil {
return err
}
reply.Index = index
reply.Matches = matches
return nil
},
)
}
// Check tests a source/destination and returns whether it would be allowed
// or denied based on the current ACL configuration.
//
// Note: Whenever the logic for this method is changed, you should take
// a look at the agent authorize endpoint (agent/agent_endpoint.go) since
// the logic there is similar.
func (s *Intention) Check(
args *structs.IntentionQueryRequest,
reply *structs.IntentionQueryCheckResponse) error {
// Forward if necessary
if done, err := s.srv.forward("Intention.Check", args, args, reply); done {
return err
}
// Get the test args, and defensively guard against nil
query := args.Check
if query == nil {
return errors.New("Check must be specified on args")
}
// Build the URI
var uri connect.CertURI
switch query.SourceType {
case structs.IntentionSourceConsul:
uri = &connect.SpiffeIDService{
Namespace: query.SourceNS,
Service: query.SourceName,
}
default:
return fmt.Errorf("unsupported SourceType: %q", query.SourceType)
}
// Get the ACL token for the request for the checks below.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
// Perform the ACL check. For Check we only require ServiceRead and
// NOT IntentionRead because the Check API only returns pass/fail and
// returns no other information about the intentions used.
if prefix, ok := query.GetACLPrefix(); ok {
if rule != nil && !rule.ServiceRead(prefix) {
s.srv.logger.Printf("[WARN] consul.intention: test on intention '%s' denied due to ACLs", prefix)
return acl.ErrPermissionDenied
}
}
// Get the matches for this destination
state := s.srv.fsm.State()
_, matches, err := state.IntentionMatch(nil, &structs.IntentionQueryMatch{
Type: structs.IntentionMatchDestination,
Entries: []structs.IntentionMatchEntry{
structs.IntentionMatchEntry{
Namespace: query.DestinationNS,
Name: query.DestinationName,
},
},
})
if err != nil {
return err
}
if len(matches) != 1 {
// This should never happen since the documented behavior of the
// Match call is that it'll always return exactly as many results
// as entries were passed in. But we guard against misbehavior.
return errors.New("internal error loading matches")
}
// Check the authorization for each match
for _, ixn := range matches[0] {
if auth, ok := uri.Authorize(ixn); ok {
reply.Allowed = auth
return nil
}
}
// No match, we need to determine the default behavior. We do this by
// specifying the anonymous token, which will get that behavior.
// The default behavior if ACLs are disabled is to allow connections
// to mimic the behavior of Consul itself: everything is allowed if
// ACLs are disabled.
//
// NOTE(mitchellh): This is the same behavior as the agent authorize
// endpoint. If this behavior is incorrect, we should also change it there
// which is much more important.
rule, err = s.srv.resolveToken("")
if err != nil {
return err
}
reply.Allowed = true
if rule != nil {
reply.Allowed = rule.IntentionDefaultAllow()
}
return nil
}
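The tail of Check therefore encodes a three-step decision: the first matching intention wins (matches arrive precedence-ordered), otherwise the result falls back to the ACL default, and with ACLs disabled everything is allowed. A condensed sketch of that ordering with simplified stand-in types:

// Simplified decision order for an intention check; each entry of
// matches mirrors an Authorize result, defaultAllow mirrors
// IntentionDefaultAllow. Sketch only, not the real signatures.
func checkAllowedSketch(matches []bool, aclsEnabled, defaultAllow bool) bool {
	if len(matches) > 0 {
		return matches[0] // first match wins; matches are precedence-ordered
	}
	if !aclsEnabled {
		return true // mimic Consul: everything is allowed without ACLs
	}
	return defaultAllow
}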

File diff suppressed because it is too large


@ -4,16 +4,20 @@ import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/connect"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/types"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-version"
"github.com/hashicorp/raft"
"github.com/hashicorp/serf/serf"
@ -210,6 +214,12 @@ func (s *Server) establishLeadership() error {
s.getOrCreateAutopilotConfig()
s.autopilot.Start()
// todo(kyhavlov): start a goroutine here for handling periodic CA rotation
if err := s.initializeCA(); err != nil {
return err
}
s.setConsistentReadReady()
return nil
}
@ -226,6 +236,8 @@ func (s *Server) revokeLeadership() error {
return err
}
s.setCAProvider(nil, nil)
s.resetConsistentReadReady()
s.autopilot.Stop()
return nil
@ -359,6 +371,185 @@ func (s *Server) getOrCreateAutopilotConfig() *autopilot.Config {
return config
}
// initializeCAConfig is used to initialize the CA config if necessary
// when setting up the CA during establishLeadership
func (s *Server) initializeCAConfig() (*structs.CAConfiguration, error) {
state := s.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return nil, err
}
if config != nil {
return config, nil
}
config = s.config.CAConfig
if config.ClusterID == "" {
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
config.ClusterID = id
}
req := structs.CARequest{
Op: structs.CAOpSetConfig,
Config: config,
}
if _, err = s.raftApply(structs.ConnectCARequestType, req); err != nil {
return nil, err
}
return config, nil
}
// initializeCA sets up the CA provider when gaining leadership, bootstrapping
// the root in the state store if necessary.
func (s *Server) initializeCA() error {
// Bail if connect isn't enabled.
if !s.config.ConnectEnabled {
return nil
}
conf, err := s.initializeCAConfig()
if err != nil {
return err
}
// Initialize the right provider based on the config
provider, err := s.createCAProvider(conf)
if err != nil {
return err
}
// Get the active root cert from the CA
rootPEM, err := provider.ActiveRoot()
if err != nil {
return fmt.Errorf("error getting root cert: %v", err)
}
rootCA, err := parseCARoot(rootPEM, conf.Provider)
if err != nil {
return err
}
// TODO(banks): in the case that we've just gained leadership in an already
// configured cluster, we really need to fetch RootCA from state to provide it
// in setCAProvider. This matters because if the current active root has
// intermediates, parsing the rootCA from only the root cert PEM above will
// not include them and so leafs we sign will not bundle the intermediates.
s.setCAProvider(provider, rootCA)
// Check if the CA root is already initialized and exit if it is.
// Every change to the CA after this initial bootstrapping should
// be done through the rotation process.
state := s.fsm.State()
_, activeRoot, err := state.CARootActive(nil)
if err != nil {
return err
}
if activeRoot != nil {
if activeRoot.ID != rootCA.ID {
// TODO(banks): this seems like a pretty catastrophic state to get into.
// Shouldn't we do something stronger than warn and continue signing with
// a key that's not the active CA according to the state?
s.logger.Printf("[WARN] connect: CA root %q is not the active root (%q)", rootCA.ID, activeRoot.ID)
}
return nil
}
// Get the highest index
idx, _, err := state.CARoots(nil)
if err != nil {
return err
}
// Store the root cert in raft
resp, err := s.raftApply(structs.ConnectCARequestType, &structs.CARequest{
Op: structs.CAOpSetRoots,
Index: idx,
Roots: []*structs.CARoot{rootCA},
})
if err != nil {
s.logger.Printf("[ERR] connect: Apply failed %v", err)
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
s.logger.Printf("[INFO] connect: initialized CA with provider %q", conf.Provider)
return nil
}
// parseCARoot returns a filled-in structs.CARoot from a raw PEM value.
func parseCARoot(pemValue, provider string) (*structs.CARoot, error) {
id, err := connect.CalculateCertFingerprint(pemValue)
if err != nil {
return nil, fmt.Errorf("error parsing root fingerprint: %v", err)
}
rootCert, err := connect.ParseCert(pemValue)
if err != nil {
return nil, fmt.Errorf("error parsing root cert: %v", err)
}
return &structs.CARoot{
ID: id,
Name: fmt.Sprintf("%s CA Root Cert", strings.Title(provider)),
SerialNumber: rootCert.SerialNumber.Uint64(),
SigningKeyID: connect.HexString(rootCert.AuthorityKeyId),
NotBefore: rootCert.NotBefore,
NotAfter: rootCert.NotAfter,
RootCert: pemValue,
Active: true,
}, nil
}
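parseCARoot relies on the connect package helpers for parsing and fingerprinting. A rough standard-library equivalent is sketched below; note the SHA-256-over-DER fingerprint is an assumption and may not match connect.CalculateCertFingerprint's exact format:

import (
	"crypto/sha256"
	"crypto/x509"
	"encoding/hex"
	"encoding/pem"
	"fmt"
)

// parseRootPEM decodes a PEM CA certificate and derives an ID from the
// DER bytes. Sketch only; not Consul's exact fingerprint algorithm.
func parseRootPEM(pemValue string) (string, *x509.Certificate, error) {
	block, _ := pem.Decode([]byte(pemValue))
	if block == nil {
		return "", nil, fmt.Errorf("no PEM block found")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return "", nil, fmt.Errorf("error parsing root cert: %v", err)
	}
	sum := sha256.Sum256(cert.Raw) // fingerprint over the DER encoding
	return hex.EncodeToString(sum[:]), cert, nil
}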
// createProvider returns a connect CA provider from the given config.
func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, error) {
switch conf.Provider {
case structs.ConsulCAProvider:
return ca.NewConsulProvider(conf.Config, &consulCADelegate{s})
case structs.VaultCAProvider:
return ca.NewVaultProvider(conf.Config, conf.ClusterID)
default:
return nil, fmt.Errorf("unknown CA provider %q", conf.Provider)
}
}
func (s *Server) getCAProvider() (ca.Provider, *structs.CARoot) {
retries := 0
var result ca.Provider
var resultRoot *structs.CARoot
for result == nil {
s.caProviderLock.RLock()
result = s.caProvider
resultRoot = s.caProviderRoot
s.caProviderLock.RUnlock()
// In cases where an agent is started with managed proxies, we may ask
// for the provider before establishLeadership completes. If we're the
// leader, then wait and get the provider again
if result == nil && s.IsLeader() && retries < 10 {
retries++
time.Sleep(50 * time.Millisecond)
continue
}
break
}
return result, resultRoot
}
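getCAProvider papers over a startup race (the provider may be requested before establishLeadership finishes) with a small bounded poll. The same shape extracted as a generic helper, for illustration only:

import "time"

// pollFor retries fetch with a short sleep while shouldWait holds, up
// to maxRetries attempts; it mirrors getCAProvider's bounded loop.
func pollFor(fetch func() interface{}, shouldWait func() bool, maxRetries int, delay time.Duration) interface{} {
	for i := 0; ; i++ {
		if v := fetch(); v != nil {
			return v
		}
		if !shouldWait() || i >= maxRetries {
			return nil
		}
		time.Sleep(delay)
	}
}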
func (s *Server) setCAProvider(newProvider ca.Provider, root *structs.CARoot) {
s.caProviderLock.Lock()
defer s.caProviderLock.Unlock()
s.caProvider = newProvider
s.caProviderRoot = root
}
// reconcileReaped is used to reconcile nodes that have failed and been reaped
// from Serf but remain in the catalog. This is done by looking for unknown
// nodes with serfHealth checks registered. We generate a "reap" event to
// cause the node to be cleaned up.


@ -354,7 +354,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
}
// Execute the query for the local DC.
if err := p.execute(query, reply); err != nil {
if err := p.execute(query, reply, args.Connect); err != nil {
return err
}
@ -450,7 +450,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
// by the query setup.
if len(reply.Nodes) == 0 {
wrapper := &queryServerWrapper{p.srv}
if err := queryFailover(wrapper, query, args.Limit, args.QueryOptions, reply); err != nil {
if err := queryFailover(wrapper, query, args, reply); err != nil {
return err
}
}
@ -479,7 +479,7 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe
}
// Run the query locally to see what we can find.
if err := p.execute(&args.Query, reply); err != nil {
if err := p.execute(&args.Query, reply, args.Connect); err != nil {
return err
}
@ -509,9 +509,18 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe
// execute runs a prepared query in the local DC without any failover. We don't
// apply any sorting options or ACL checks at this level - it should be done up above.
func (p *PreparedQuery) execute(query *structs.PreparedQuery,
reply *structs.PreparedQueryExecuteResponse) error {
reply *structs.PreparedQueryExecuteResponse,
forceConnect bool) error {
state := p.srv.fsm.State()
_, nodes, err := state.CheckServiceNodes(nil, query.Service.Service)
// If we're requesting Connect-capable services, then switch the
// lookup to be the Connect function.
f := state.CheckServiceNodes
if query.Service.Connect || forceConnect {
f = state.CheckConnectServiceNodes
}
_, nodes, err := f(nil, query.Service.Service)
if err != nil {
return err
}
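Note the dispatch style: execute now selects the state-store lookup once through a function value, so the Connect and non-Connect paths share the remainder of the body. A minimal sketch of the same idiom:

// Selecting the lookup once, then sharing the common path afterwards.
func lookupNodesSketch(connect bool, name string,
	byService, byConnect func(string) ([]string, error)) ([]string, error) {
	f := byService
	if connect {
		f = byConnect // swap only the lookup; everything else is shared
	}
	return f(name)
}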
@ -651,7 +660,7 @@ func (q *queryServerWrapper) ForwardDC(method, dc string, args interface{}, repl
// queryFailover runs an algorithm to determine which DCs to try and then calls
// them to try to locate alternative services.
func queryFailover(q queryServer, query *structs.PreparedQuery,
limit int, options structs.QueryOptions,
args *structs.PreparedQueryExecuteRequest,
reply *structs.PreparedQueryExecuteResponse) error {
// Pull the list of other DCs. This is sorted by RTT in case the user
@ -719,8 +728,9 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
remote := &structs.PreparedQueryExecuteRemoteRequest{
Datacenter: dc,
Query: *query,
Limit: limit,
QueryOptions: options,
Limit: args.Limit,
QueryOptions: args.QueryOptions,
Connect: args.Connect,
}
if err := q.ForwardDC("PreparedQuery.ExecuteRemote", dc, remote, reply); err != nil {
q.GetLogger().Printf("[WARN] consul.prepared_query: Failed querying for service '%s' in datacenter '%s': %s", query.Service.Service, dc, err)


@ -20,6 +20,7 @@ import (
"github.com/hashicorp/consul/types"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/coordinate"
"github.com/stretchr/testify/require"
)
func TestPreparedQuery_Apply(t *testing.T) {
@ -2617,6 +2618,159 @@ func TestPreparedQuery_Execute_ForwardLeader(t *testing.T) {
}
}
func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
// Setup 3 services on 3 nodes: one is non-Connect, one is Connect native,
// and one is a proxy to the non-Connect one.
for i := 0; i < 3; i++ {
req := structs.RegisterRequest{
Datacenter: "dc1",
Node: fmt.Sprintf("node%d", i+1),
Address: fmt.Sprintf("127.0.0.%d", i+1),
Service: &structs.NodeService{
Service: "foo",
Port: 8000,
},
}
switch i {
case 0:
// Default do nothing
case 1:
// Connect native
req.Service.Connect.Native = true
case 2:
// Connect proxy
req.Service.Kind = structs.ServiceKindConnectProxy
req.Service.ProxyDestination = req.Service.Service
req.Service.Service = "proxy"
}
var reply struct{}
require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply))
}
// The query, start with connect disabled
query := structs.PreparedQueryRequest{
Datacenter: "dc1",
Op: structs.PreparedQueryCreate,
Query: &structs.PreparedQuery{
Name: "test",
Service: structs.ServiceQuery{
Service: "foo",
},
DNS: structs.QueryDNSOptions{
TTL: "10s",
},
},
}
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
// In the future we'll run updates
query.Op = structs.PreparedQueryUpdate
// Run the registered query.
{
req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1",
QueryIDOrName: query.Query.ID,
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because it omits the proxy whose name
// doesn't match the query.
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
}
// Run with the Connect setting specified on the request
{
req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1",
QueryIDOrName: query.Query.ID,
Connect: true,
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because we should get the native AND
// the proxy (since the destination matches our service name).
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
// Make sure the native is the first one
if !reply.Nodes[0].Service.Connect.Native {
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
}
require.True(reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination)
}
// Update the query
query.Query.Service.Connect = true
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
// Run the registered query.
{
req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1",
QueryIDOrName: query.Query.ID,
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because we should get the native AND
// the proxy (since the destination matches our service name).
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
// Make sure the native is the first one
if !reply.Nodes[0].Service.Connect.Native {
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
}
require.True(reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination)
}
// Unset the query
query.Query.Service.Connect = false
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
}
func TestPreparedQuery_tagFilter(t *testing.T) {
t.Parallel()
testNodes := func() structs.CheckServiceNodes {
@ -2820,7 +2974,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 {
@ -2836,7 +2990,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply)
err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply)
if err == nil || !strings.Contains(err.Error(), "XXX") {
t.Fatalf("bad: %v", err)
}
@ -2853,7 +3007,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 {
@ -2876,7 +3030,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -2904,7 +3058,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -2925,7 +3079,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 0 ||
@ -2954,7 +3108,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -2983,7 +3137,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -3012,7 +3166,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -3047,7 +3201,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -3079,7 +3233,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -3115,7 +3269,10 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 5, structs.QueryOptions{RequireConsistent: true}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{
Limit: 5,
QueryOptions: structs.QueryOptions{RequireConsistent: true},
}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||


@ -18,6 +18,7 @@ import (
"time"
"github.com/hashicorp/consul/acl"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/consul/fsm"
"github.com/hashicorp/consul/agent/consul/state"
@ -96,6 +97,16 @@ type Server struct {
// autopilotWaitGroup is used to block until Autopilot shuts down.
autopilotWaitGroup sync.WaitGroup
// caProvider is the current CA provider in use for Connect. This is
// only non-nil when we are the leader.
caProvider ca.Provider
// caProviderRoot is the CARoot that was stored along with the ca.Provider
// active. It's only updated in lock-step with the caProvider. This prevents
// races between state updates to active roots and the fetch of the provider
// instance.
caProviderRoot *structs.CARoot
caProviderLock sync.RWMutex
// Consul configuration
config *Config
@ -1066,6 +1077,12 @@ func (s *Server) GetLANCoordinate() (lib.CoordinateSet, error) {
return cs, nil
}
// ReloadConfig is used to have the Server do an online reload of
// relevant configuration information
func (s *Server) ReloadConfig(config *Config) error {
return nil
}
// Atomically sets a readiness state flag when leadership is obtained, to indicate that server is past its barrier write
func (s *Server) setConsistentReadReady() {
atomic.StoreInt32(&s.readyForConsistentReads, 1)


@ -4,7 +4,9 @@ func init() {
registerEndpoint(func(s *Server) interface{} { return &ACL{s} })
registerEndpoint(func(s *Server) interface{} { return &Catalog{s} })
registerEndpoint(func(s *Server) interface{} { return NewCoordinate(s) })
registerEndpoint(func(s *Server) interface{} { return &ConnectCA{s} })
registerEndpoint(func(s *Server) interface{} { return &Health{s} })
registerEndpoint(func(s *Server) interface{} { return &Intention{s} })
registerEndpoint(func(s *Server) interface{} { return &Internal{s} })
registerEndpoint(func(s *Server) interface{} { return &KVS{s} })
registerEndpoint(func(s *Server) interface{} { return &Operator{s} })


@ -10,7 +10,9 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/lib/freeport"
"github.com/hashicorp/consul/testrpc"
@ -91,6 +93,17 @@ func testServerConfig(t *testing.T) (string, *Config) {
// looks like several depend on it.
config.RPCHoldTimeout = 5 * time.Second
config.ConnectEnabled = true
config.CAConfig = &structs.CAConfiguration{
ClusterID: connect.TestClusterID,
Provider: structs.ConsulCAProvider,
Config: map[string]interface{}{
"PrivateKey": "",
"RootCert": "",
"RotationPeriod": 90 * 24 * time.Hour,
},
}
return dir, config
}


@ -10,6 +10,10 @@ import (
"github.com/hashicorp/go-memdb"
)
const (
servicesTableName = "services"
)
// nodesTableSchema returns a new table schema used for storing node
// information.
func nodesTableSchema() *memdb.TableSchema {
@ -87,6 +91,12 @@ func servicesTableSchema() *memdb.TableSchema {
Lowercase: true,
},
},
"connect": &memdb.IndexSchema{
Name: "connect",
AllowMissing: true,
Unique: false,
Indexer: &IndexConnectService{},
},
},
}
}
@ -779,15 +789,39 @@ func maxIndexForService(tx *memdb.Txn, serviceName string, checks bool) uint64 {
return maxIndexTxn(tx, "nodes", "services")
}
// ConnectServiceNodes returns the nodes associated with a Connect
// compatible destination for the given service name. This will include
// both proxies and native integrations.
func (s *Store) ConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) {
return s.serviceNodes(ws, serviceName, true)
}
// ServiceNodes returns the nodes associated with a given service name.
func (s *Store) ServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) {
return s.serviceNodes(ws, serviceName, false)
}
func (s *Store) serviceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.ServiceNodes, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexForService(tx, serviceName, false)
// Function for lookup
var f func() (memdb.ResultIterator, error)
if !connect {
f = func() (memdb.ResultIterator, error) {
return tx.Get("services", "service", serviceName)
}
} else {
f = func() (memdb.ResultIterator, error) {
return tx.Get("services", "connect", serviceName)
}
}
// List all the services.
services, err := tx.Get("services", "service", serviceName)
services, err := f()
if err != nil {
return 0, nil, fmt.Errorf("failed service lookup: %s", err)
}
@ -1479,14 +1513,36 @@ func (s *Store) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID t
// CheckServiceNodes is used to query all nodes and checks for a given service.
func (s *Store) CheckServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) {
return s.checkServiceNodes(ws, serviceName, false)
}
// CheckConnectServiceNodes is used to query all nodes and checks for Connect
// compatible endpoints for a given service.
func (s *Store) CheckConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) {
return s.checkServiceNodes(ws, serviceName, true)
}
func (s *Store) checkServiceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.CheckServiceNodes, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexForService(tx, serviceName, true)
// Function for lookup
var f func() (memdb.ResultIterator, error)
if !connect {
f = func() (memdb.ResultIterator, error) {
return tx.Get("services", "service", serviceName)
}
} else {
f = func() (memdb.ResultIterator, error) {
return tx.Get("services", "connect", serviceName)
}
}
// Query the state store for the service.
iter, err := tx.Get("services", "service", serviceName)
iter, err := f()
if err != nil {
return 0, nil, fmt.Errorf("failed service lookup: %s", err)
}


@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-memdb"
uuid "github.com/hashicorp/go-uuid"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func makeRandomNodeID(t *testing.T) types.NodeID {
@ -981,6 +982,35 @@ func TestStateStore_EnsureService(t *testing.T) {
}
}
func TestStateStore_EnsureService_connectProxy(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create the service registration.
ns1 := &structs.NodeService{
Kind: structs.ServiceKindConnectProxy,
ID: "connect-proxy",
Service: "connect-proxy",
Address: "1.1.1.1",
Port: 1111,
ProxyDestination: "foo",
}
// Service successfully registers into the state store.
testRegisterNode(t, s, 0, "node1")
assert.Nil(s.EnsureService(10, "node1", ns1))
// Retrieve and verify
_, out, err := s.NodeServices(nil, "node1")
assert.Nil(err)
assert.NotNil(out)
assert.Len(out.Services, 1)
expect1 := *ns1
expect1.CreateIndex, expect1.ModifyIndex = 10, 10
assert.Equal(&expect1, out.Services["connect-proxy"])
}
func TestStateStore_Services(t *testing.T) {
s := testStateStore(t)
@ -1542,6 +1572,51 @@ func TestStateStore_DeleteService(t *testing.T) {
}
}
func TestStateStore_ConnectServiceNodes(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.ConnectServiceNodes(ws, "db")
assert.Nil(err)
assert.Equal(idx, uint64(0))
assert.Len(nodes, 0)
// Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "native-db", Service: "db", Connect: structs.ServiceConnect{Native: true}}))
assert.Nil(s.EnsureService(17, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}))
assert.True(watchFired(ws))
// Read everything back.
ws = memdb.NewWatchSet()
idx, nodes, err = s.ConnectServiceNodes(ws, "db")
assert.Nil(err)
assert.Equal(idx, uint64(17))
assert.Len(nodes, 3)
for _, n := range nodes {
assert.True(
n.ServiceKind == structs.ServiceKindConnectProxy ||
n.ServiceConnect.Native,
"either proxy or connect native")
}
// Registering some unrelated node should not fire the watch.
testRegisterNode(t, s, 17, "nope")
assert.False(watchFired(ws))
// But removing a node with the "db" service should fire the watch.
assert.Nil(s.DeleteNode(18, "bar"))
assert.True(watchFired(ws))
}
func TestStateStore_Service_Snapshot(t *testing.T) {
s := testStateStore(t)
@ -2457,6 +2532,48 @@ func TestStateStore_CheckServiceNodes(t *testing.T) {
}
}
func TestStateStore_CheckConnectServiceNodes(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.CheckConnectServiceNodes(ws, "db")
assert.Nil(err)
assert.Equal(idx, uint64(0))
assert.Len(nodes, 0)
// Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}))
assert.True(watchFired(ws))
// Register node checks
testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing)
testRegisterCheck(t, s, 18, "bar", "", "check2", api.HealthPassing)
// Register checks against the services.
testRegisterCheck(t, s, 19, "foo", "db", "check3", api.HealthPassing)
testRegisterCheck(t, s, 20, "bar", "proxy", "check4", api.HealthPassing)
// Read everything back.
ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db")
assert.Nil(err)
assert.Equal(idx, uint64(20))
assert.Len(nodes, 2)
for _, n := range nodes {
assert.Equal(structs.ServiceKindConnectProxy, n.Service.Kind)
assert.Equal("db", n.Service.ProxyDestination)
}
}
func BenchmarkCheckServiceNodes(b *testing.B) {
s, err := NewStateStore(nil)
if err != nil {


@ -0,0 +1,435 @@
package state
import (
"fmt"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
)
const (
caBuiltinProviderTableName = "connect-ca-builtin"
caConfigTableName = "connect-ca-config"
caRootTableName = "connect-ca-roots"
)
// caBuiltinProviderTableSchema returns a new table schema used for storing
// the built-in CA provider's state for connect. This is only used by
// the internal Consul CA provider.
func caBuiltinProviderTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: caBuiltinProviderTableName,
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
},
}
}
// caConfigTableSchema returns a new table schema used for storing
// the CA config for Connect.
func caConfigTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: caConfigTableName,
Indexes: map[string]*memdb.IndexSchema{
// This table only stores one row, so this just ignores the ID field
// and always overwrites the same config object.
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: true,
Unique: true,
Indexer: &memdb.ConditionalIndex{
Conditional: func(obj interface{}) (bool, error) { return true, nil },
},
},
},
}
}
// caRootTableSchema returns a new table schema used for storing
// CA roots for Connect.
func caRootTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: caRootTableName,
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
},
}
}
func init() {
registerSchema(caBuiltinProviderTableSchema)
registerSchema(caConfigTableSchema)
registerSchema(caRootTableSchema)
}
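Each of these schema functions feeds registerSchema, and the registered tables are merged into the one go-memdb database backing the state store. A self-contained sketch of defining and querying a comparable table directly with go-memdb (table and field names here are illustrative, not Consul's):

import (
	"fmt"

	"github.com/hashicorp/go-memdb"
)

type exampleRoot struct{ ID string }

func memdbSketch() error {
	schema := &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"roots": &memdb.TableSchema{
				Name: "roots",
				Indexes: map[string]*memdb.IndexSchema{
					"id": &memdb.IndexSchema{
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
				},
			},
		},
	}
	db, err := memdb.NewMemDB(schema)
	if err != nil {
		return err
	}
	txn := db.Txn(true) // write transaction
	if err := txn.Insert("roots", &exampleRoot{ID: "abc"}); err != nil {
		txn.Abort()
		return err
	}
	txn.Commit()
	read := db.Txn(false) // read transaction
	defer read.Abort()
	raw, err := read.First("roots", "id", "abc")
	if err != nil {
		return err
	}
	fmt.Println(raw.(*exampleRoot).ID)
	return nil
}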
// CAConfig is used to pull the CA config from the snapshot.
func (s *Snapshot) CAConfig() (*structs.CAConfiguration, error) {
c, err := s.tx.First(caConfigTableName, "id")
if err != nil {
return nil, err
}
config, ok := c.(*structs.CAConfiguration)
if !ok {
return nil, nil
}
return config, nil
}
// CAConfig is used when restoring from a snapshot.
func (s *Restore) CAConfig(config *structs.CAConfiguration) error {
if err := s.tx.Insert(caConfigTableName, config); err != nil {
return fmt.Errorf("failed restoring CA config: %s", err)
}
return nil
}
// CAConfig is used to get the current CA configuration.
func (s *Store) CAConfig() (uint64, *structs.CAConfiguration, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the CA config
c, err := tx.First(caConfigTableName, "id")
if err != nil {
return 0, nil, fmt.Errorf("failed CA config lookup: %s", err)
}
config, ok := c.(*structs.CAConfiguration)
if !ok {
return 0, nil, nil
}
return config.ModifyIndex, config, nil
}
// CASetConfig is used to set the current CA configuration.
func (s *Store) CASetConfig(idx uint64, config *structs.CAConfiguration) error {
tx := s.db.Txn(true)
defer tx.Abort()
if err := s.caSetConfigTxn(idx, tx, config); err != nil {
return err
}
tx.Commit()
return nil
}
// CACheckAndSetConfig is used to try updating the CA configuration with a
// given Raft index. If the CAS index specified is not equal to the last
// observed index for the config, then the call is a no-op.
func (s *Store) CACheckAndSetConfig(idx, cidx uint64, config *structs.CAConfiguration) (bool, error) {
tx := s.db.Txn(true)
defer tx.Abort()
// Check for an existing config
existing, err := tx.First(caConfigTableName, "id")
if err != nil {
return false, fmt.Errorf("failed CA config lookup: %s", err)
}
// If the existing index does not match the provided CAS
// index arg, then we shouldn't update anything and can safely
// return early here.
e, ok := existing.(*structs.CAConfiguration)
if !ok || e.ModifyIndex != cidx {
return false, nil
}
if err := s.caSetConfigTxn(idx, tx, config); err != nil {
return false, err
}
tx.Commit()
return true, nil
}
func (s *Store) caSetConfigTxn(idx uint64, tx *memdb.Txn, config *structs.CAConfiguration) error {
// Check for an existing config
prev, err := tx.First(caConfigTableName, "id")
if err != nil {
return fmt.Errorf("failed CA config lookup: %s", err)
}
// Set the indexes, prevent the cluster ID from changing.
if prev != nil {
existing := prev.(*structs.CAConfiguration)
config.CreateIndex = existing.CreateIndex
config.ClusterID = existing.ClusterID
} else {
config.CreateIndex = idx
}
config.ModifyIndex = idx
if err := tx.Insert(caConfigTableName, config); err != nil {
return fmt.Errorf("failed updating CA config: %s", err)
}
return nil
}
// CARoots is used to pull all the CA roots for the snapshot.
func (s *Snapshot) CARoots() (structs.CARoots, error) {
ixns, err := s.tx.Get(caRootTableName, "id")
if err != nil {
return nil, err
}
var ret structs.CARoots
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
ret = append(ret, wrapped.(*structs.CARoot))
}
return ret, nil
}
// CARoot is used when restoring from a snapshot.
func (s *Restore) CARoot(r *structs.CARoot) error {
// Insert
if err := s.tx.Insert(caRootTableName, r); err != nil {
return fmt.Errorf("failed restoring CA root: %s", err)
}
if err := indexUpdateMaxTxn(s.tx, r.ModifyIndex, caRootTableName); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// CARoots returns the list of all CA roots.
func (s *Store) CARoots(ws memdb.WatchSet) (uint64, structs.CARoots, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the index
idx := maxIndexTxn(tx, caRootTableName)
// Get all
iter, err := tx.Get(caRootTableName, "id")
if err != nil {
return 0, nil, fmt.Errorf("failed CA root lookup: %s", err)
}
ws.Add(iter.WatchCh())
var results structs.CARoots
for v := iter.Next(); v != nil; v = iter.Next() {
results = append(results, v.(*structs.CARoot))
}
return idx, results, nil
}
// CARootActive returns the currently active CARoot.
func (s *Store) CARootActive(ws memdb.WatchSet) (uint64, *structs.CARoot, error) {
// Get all the roots; there should never be many of them, so just
// do the filtering in this method.
var result *structs.CARoot
idx, roots, err := s.CARoots(ws)
if err == nil {
for _, r := range roots {
if r.Active {
result = r
break
}
}
}
return idx, result, err
}
// CARootSetCAS sets the current CA root state using a check-and-set operation.
// On success, this will replace the previous set of CARoots completely with
// the given set of roots.
//
// The first boolean result returns whether the transaction succeeded or not.
func (s *Store) CARootSetCAS(idx, cidx uint64, rs []*structs.CARoot) (bool, error) {
tx := s.db.Txn(true)
defer tx.Abort()
// There must be exactly one active CA root.
activeCount := 0
for _, r := range rs {
if r.Active {
activeCount++
}
}
if activeCount != 1 {
return false, fmt.Errorf("there must be exactly one active CA")
}
// Get the current max index
if midx := maxIndexTxn(tx, caRootTableName); midx != cidx {
return false, nil
}
// Go through and find any existing matching CAs so we can preserve and
// update their Create/ModifyIndex values.
for _, r := range rs {
if r.ID == "" {
return false, ErrMissingCARootID
}
existing, err := tx.First(caRootTableName, "id", r.ID)
if err != nil {
return false, fmt.Errorf("failed CA root lookup: %s", err)
}
if existing != nil {
r.CreateIndex = existing.(*structs.CARoot).CreateIndex
} else {
r.CreateIndex = idx
}
r.ModifyIndex = idx
}
// Delete all
_, err := tx.DeleteAll(caRootTableName, "id")
if err != nil {
return false, err
}
// Insert all
for _, r := range rs {
if err := tx.Insert(caRootTableName, r); err != nil {
return false, err
}
}
// Update the index
if err := tx.Insert("index", &IndexEntry{caRootTableName, idx}); err != nil {
return false, fmt.Errorf("failed updating index: %s", err)
}
tx.Commit()
return true, nil
}
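Callers of CARootSetCAS are expected to follow the usual optimistic-concurrency shape: observe the current index, attempt the conditioned write, and retry on a clean miss. A hedged usage sketch against the methods above (buildRoots is a hypothetical helper producing the desired new root set):

// Optimistic retry loop around CARootSetCAS; buildRoots is hypothetical.
func setRootsWithRetry(s *Store, buildRoots func() []*structs.CARoot) error {
	for attempt := 0; attempt < 3; attempt++ {
		idx, _, err := s.CARoots(nil) // observe the current index
		if err != nil {
			return err
		}
		ok, err := s.CARootSetCAS(idx+1, idx, buildRoots())
		if err != nil {
			return err // validation error, not an index conflict
		}
		if ok {
			return nil // write landed against the index we observed
		}
		// ok == false: another writer advanced the index; retry.
	}
	return fmt.Errorf("CA root CAS failed after retries")
}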
// CAProviderState is used to pull the built-in provider states from the snapshot.
func (s *Snapshot) CAProviderState() ([]*structs.CAConsulProviderState, error) {
ixns, err := s.tx.Get(caBuiltinProviderTableName, "id")
if err != nil {
return nil, err
}
var ret []*structs.CAConsulProviderState
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
ret = append(ret, wrapped.(*structs.CAConsulProviderState))
}
return ret, nil
}
// CAProviderState is used when restoring from a snapshot.
func (s *Restore) CAProviderState(state *structs.CAConsulProviderState) error {
if err := s.tx.Insert(caBuiltinProviderTableName, state); err != nil {
return fmt.Errorf("failed restoring built-in CA state: %s", err)
}
if err := indexUpdateMaxTxn(s.tx, state.ModifyIndex, caBuiltinProviderTableName); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// CAProviderState is used to get the Consul CA provider state for the given ID.
func (s *Store) CAProviderState(id string) (uint64, *structs.CAConsulProviderState, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the index
idx := maxIndexTxn(tx, caBuiltinProviderTableName)
// Get the provider config
c, err := tx.First(caBuiltinProviderTableName, "id", id)
if err != nil {
return 0, nil, fmt.Errorf("failed built-in CA state lookup: %s", err)
}
state, ok := c.(*structs.CAConsulProviderState)
if !ok {
return 0, nil, nil
}
return idx, state, nil
}
// CASetProviderState is used to set the current built-in CA provider state.
func (s *Store) CASetProviderState(idx uint64, state *structs.CAConsulProviderState) (bool, error) {
tx := s.db.Txn(true)
defer tx.Abort()
// Check for an existing config
existing, err := tx.First(caBuiltinProviderTableName, "id", state.ID)
if err != nil {
return false, fmt.Errorf("failed built-in CA state lookup: %s", err)
}
// Set the indexes.
if existing != nil {
state.CreateIndex = existing.(*structs.CAConsulProviderState).CreateIndex
} else {
state.CreateIndex = idx
}
state.ModifyIndex = idx
if err := tx.Insert(caBuiltinProviderTableName, state); err != nil {
return false, fmt.Errorf("failed updating built-in CA state: %s", err)
}
// Update the index
if err := tx.Insert("index", &IndexEntry{caBuiltinProviderTableName, idx}); err != nil {
return false, fmt.Errorf("failed updating index: %s", err)
}
tx.Commit()
return true, nil
}
// CADeleteProviderState is used to remove the built-in Consul CA provider
// state for the given ID.
func (s *Store) CADeleteProviderState(id string) error {
tx := s.db.Txn(true)
defer tx.Abort()
// Get the index
idx := maxIndexTxn(tx, caBuiltinProviderTableName)
// Check for an existing config
existing, err := tx.First(caBuiltinProviderTableName, "id", id)
if err != nil {
return fmt.Errorf("failed built-in CA state lookup: %s", err)
}
if existing == nil {
return nil
}
providerState := existing.(*structs.CAConsulProviderState)
// Do the delete and update the index
if err := tx.Delete(caBuiltinProviderTableName, providerState); err != nil {
return err
}
if err := tx.Insert("index", &IndexEntry{caBuiltinProviderTableName, idx}); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
tx.Commit()
return nil
}


@ -0,0 +1,449 @@
package state
import (
"reflect"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func TestStore_CAConfig(t *testing.T) {
s := testStateStore(t)
expected := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": "asdf",
"RootCert": "qwer",
"RotationPeriod": 90 * 24 * time.Hour,
},
}
if err := s.CASetConfig(0, expected); err != nil {
t.Fatal(err)
}
idx, config, err := s.CAConfig()
if err != nil {
t.Fatal(err)
}
if idx != 0 {
t.Fatalf("bad: %d", idx)
}
if !reflect.DeepEqual(expected, config) {
t.Fatalf("bad: %#v, %#v", expected, config)
}
}
func TestStore_CAConfigCAS(t *testing.T) {
s := testStateStore(t)
expected := &structs.CAConfiguration{
Provider: "consul",
}
if err := s.CASetConfig(0, expected); err != nil {
t.Fatal(err)
}
// Do an extra operation to move the index up by 1 for the
// check-and-set operation after this
if err := s.CASetConfig(1, expected); err != nil {
t.Fatal(err)
}
// Do a CAS with an index lower than the entry
ok, err := s.CACheckAndSetConfig(2, 0, &structs.CAConfiguration{
Provider: "static",
})
if ok || err != nil {
t.Fatalf("expected (false, nil), got: (%v, %#v)", ok, err)
}
// Check that the index is untouched and the entry
// has not been updated.
idx, config, err := s.CAConfig()
if err != nil {
t.Fatal(err)
}
if idx != 1 {
t.Fatalf("bad: %d", idx)
}
if config.Provider != "consul" {
t.Fatalf("bad: %#v", config)
}
// Do another CAS, this time with the correct index
ok, err = s.CACheckAndSetConfig(2, 1, &structs.CAConfiguration{
Provider: "static",
})
if !ok || err != nil {
t.Fatalf("expected (true, nil), got: (%v, %#v)", ok, err)
}
// Make sure the config was updated
idx, config, err = s.CAConfig()
if err != nil {
t.Fatal(err)
}
if idx != 2 {
t.Fatalf("bad: %d", idx)
}
if config.Provider != "static" {
t.Fatalf("bad: %#v", config)
}
}
func TestStore_CAConfig_Snapshot_Restore(t *testing.T) {
s := testStateStore(t)
before := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": "asdf",
"RootCert": "qwer",
"RotationPeriod": 90 * 24 * time.Hour,
},
}
if err := s.CASetConfig(99, before); err != nil {
t.Fatal(err)
}
snap := s.Snapshot()
defer snap.Close()
after := &structs.CAConfiguration{
Provider: "static",
Config: map[string]interface{}{},
}
if err := s.CASetConfig(100, after); err != nil {
t.Fatal(err)
}
snapped, err := snap.CAConfig()
if err != nil {
t.Fatalf("err: %s", err)
}
verify.Values(t, "", before, snapped)
s2 := testStateStore(t)
restore := s2.Restore()
if err := restore.CAConfig(snapped); err != nil {
t.Fatalf("err: %s", err)
}
restore.Commit()
idx, res, err := s2.CAConfig()
if err != nil {
t.Fatalf("err: %s", err)
}
if idx != 99 {
t.Fatalf("bad index: %d", idx)
}
verify.Values(t, "", before, res)
}
func TestStore_CARootSetList(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
assert.Nil(err)
assert.True(ok)
// Make sure the index got updated.
assert.Equal(s.maxIndex(caRootTableName), uint64(1))
assert.True(watchFired(ws), "watch fired")
// Read it back out and verify it.
expected := *ca1
expected.RaftIndex = structs.RaftIndex{
CreateIndex: 1,
ModifyIndex: 1,
}
ws = memdb.NewWatchSet()
_, roots, err := s.CARoots(ws)
assert.Nil(err)
assert.Len(roots, 1)
actual := roots[0]
assert.Equal(&expected, actual)
}
func TestStore_CARootSet_emptyID(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
ca1.ID = ""
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
assert.NotNil(err)
assert.Contains(err.Error(), ErrMissingCARootID.Error())
assert.False(ok)
// Make sure the index got updated.
assert.Equal(s.maxIndex(caRootTableName), uint64(0))
assert.False(watchFired(ws), "watch fired")
// Read it back out and verify it.
ws = memdb.NewWatchSet()
_, roots, err := s.CARoots(ws)
assert.Nil(err)
assert.Len(roots, 0)
}
func TestStore_CARootSet_noActive(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
ca1.Active = false
ca2 := connect.TestCA(t, nil)
ca2.Active = false
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
assert.NotNil(err)
assert.Contains(err.Error(), "exactly one active")
assert.False(ok)
}
func TestStore_CARootSet_multipleActive(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
ca2 := connect.TestCA(t, nil)
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
assert.NotNil(err)
assert.Contains(err.Error(), "exactly one active")
assert.False(ok)
}
func TestStore_CARootActive_valid(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid value
ca1 := connect.TestCA(t, nil)
ca1.Active = false
ca2 := connect.TestCA(t, nil)
ca3 := connect.TestCA(t, nil)
ca3.Active = false
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2, ca3})
assert.Nil(err)
assert.True(ok)
// Query
ws := memdb.NewWatchSet()
idx, res, err := s.CARootActive(ws)
assert.Equal(idx, uint64(1))
assert.Nil(err)
assert.NotNil(res)
assert.Equal(ca2.ID, res.ID)
}
// Test that querying the active CA returns the correct value.
func TestStore_CARootActive_none(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Querying with no results returns nil.
ws := memdb.NewWatchSet()
idx, res, err := s.CARootActive(ws)
assert.Equal(idx, uint64(0))
assert.Nil(res)
assert.Nil(err)
}
func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create some CA roots.
roots := structs.CARoots{
connect.TestCA(t, nil),
connect.TestCA(t, nil),
connect.TestCA(t, nil),
}
for _, r := range roots[1:] {
r.Active = false
}
// Force the sort order of the UUIDs before we create them so the
// order is deterministic.
id := testUUID()
roots[0].ID = "a" + id[1:]
roots[1].ID = "b" + id[1:]
roots[2].ID = "c" + id[1:]
// Now create
ok, err := s.CARootSetCAS(1, 0, roots)
assert.Nil(err)
assert.True(ok)
// Snapshot the queries.
snap := s.Snapshot()
defer snap.Close()
// Alter the real state store.
ok, err = s.CARootSetCAS(2, 1, roots[:1])
assert.Nil(err)
assert.True(ok)
// Verify the snapshot.
assert.Equal(snap.LastIndex(), uint64(1))
dump, err := snap.CARoots()
assert.Nil(err)
assert.Equal(roots, dump)
// Restore the values into a new state store.
func() {
s := testStateStore(t)
restore := s.Restore()
for _, r := range dump {
assert.Nil(restore.CARoot(r))
}
restore.Commit()
// Read the restored values back out and verify that they match.
idx, actual, err := s.CARoots(nil)
assert.Nil(err)
assert.Equal(idx, uint64(2))
assert.Equal(roots, actual)
}()
}
func TestStore_CABuiltinProvider(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
{
expected := &structs.CAConsulProviderState{
ID: "foo",
PrivateKey: "a",
RootCert: "b",
}
ok, err := s.CASetProviderState(0, expected)
assert.NoError(err)
assert.True(ok)
idx, state, err := s.CAProviderState(expected.ID)
assert.NoError(err)
assert.Equal(idx, uint64(0))
assert.Equal(expected, state)
}
{
expected := &structs.CAConsulProviderState{
ID: "bar",
PrivateKey: "c",
RootCert: "d",
}
ok, err := s.CASetProviderState(1, expected)
assert.NoError(err)
assert.True(ok)
idx, state, err := s.CAProviderState(expected.ID)
assert.NoError(err)
assert.Equal(idx, uint64(1))
assert.Equal(expected, state)
}
}
func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create multiple state entries.
before := []*structs.CAConsulProviderState{
{
ID: "bar",
PrivateKey: "y",
RootCert: "z",
},
{
ID: "foo",
PrivateKey: "a",
RootCert: "b",
},
}
for i, state := range before {
ok, err := s.CASetProviderState(uint64(98+i), state)
assert.NoError(err)
assert.True(ok)
}
// Take a snapshot.
snap := s.Snapshot()
defer snap.Close()
// Modify the state store.
after := &structs.CAConsulProviderState{
ID: "foo",
PrivateKey: "c",
RootCert: "d",
}
ok, err := s.CASetProviderState(100, after)
assert.NoError(err)
assert.True(ok)
snapped, err := snap.CAProviderState()
assert.NoError(err)
assert.Equal(before, snapped)
// Restore onto a new state store.
s2 := testStateStore(t)
restore := s2.Restore()
for _, entry := range snapped {
assert.NoError(restore.CAProviderState(entry))
}
restore.Commit()
// Verify the restored values match those from before the snapshot.
for _, state := range before {
idx, res, err := s2.CAProviderState(state.ID)
assert.NoError(err)
assert.Equal(idx, uint64(99))
assert.Equal(state, res)
}
}


@ -0,0 +1,54 @@
package state
import (
"fmt"
"strings"
"github.com/hashicorp/consul/agent/structs"
)
// IndexConnectService indexes a *structs.ServiceNode for querying by
// services that support Connect to some target service. This will
// properly index the proxy destination for proxies and the service name
// for native services.
type IndexConnectService struct{}
func (idx *IndexConnectService) FromObject(obj interface{}) (bool, []byte, error) {
sn, ok := obj.(*structs.ServiceNode)
if !ok {
return false, nil, fmt.Errorf("Object must be ServiceNode, got %T", obj)
}
var result []byte
switch {
case sn.ServiceKind == structs.ServiceKindConnectProxy:
// For proxies, this service supports Connect for the destination
result = []byte(strings.ToLower(sn.ServiceProxyDestination))
case sn.ServiceConnect.Native:
// For native, this service supports Connect directly
result = []byte(strings.ToLower(sn.ServiceName))
default:
// Doesn't support Connect at all
return false, nil, nil
}
// Return the result with the null terminator appended so we can
// differentiate prefix vs. non-prefix matches.
return true, append(result, '\x00'), nil
}
func (idx *IndexConnectService) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
arg, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
}
// Add the null character as a terminator
return append([]byte(strings.ToLower(arg)), '\x00'), nil
}
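The NUL terminator is what keeps an exact lookup exact: go-memdb serves exact matches as prefix scans over its radix tree, so without a terminator a query for "db" would also sweep up "db2". A tiny demonstration of the difference:

import (
	"bytes"
	"fmt"
)

func terminatorDemo() {
	// Without terminators, "db" is a prefix of "db2" and would match.
	fmt.Println(bytes.HasPrefix([]byte("db2"), []byte("db"))) // true
	// With the NUL appended, the keys diverge and "db" matches only "db".
	keyDB := append([]byte("db"), '\x00')
	keyDB2 := append([]byte("db2"), '\x00')
	fmt.Println(bytes.HasPrefix(keyDB2, keyDB)) // false
}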


@ -0,0 +1,122 @@
package state
import (
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/require"
)
func TestIndexConnectService_FromObject(t *testing.T) {
cases := []struct {
Name string
Input interface{}
ExpectMatch bool
ExpectVal []byte
ExpectErr string
}{
{
"not a ServiceNode",
42,
false,
nil,
"ServiceNode",
},
{
"typical service, not native",
&structs.ServiceNode{
ServiceName: "db",
},
false,
nil,
"",
},
{
"typical service, is native",
&structs.ServiceNode{
ServiceName: "dB",
ServiceConnect: structs.ServiceConnect{Native: true},
},
true,
[]byte("db\x00"),
"",
},
{
"proxy service",
&structs.ServiceNode{
ServiceKind: structs.ServiceKindConnectProxy,
ServiceName: "db",
ServiceProxyDestination: "fOo",
},
true,
[]byte("foo\x00"),
"",
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
var idx IndexConnectService
match, val, err := idx.FromObject(tc.Input)
if tc.ExpectErr != "" {
require.Error(err)
require.Contains(err.Error(), tc.ExpectErr)
return
}
require.NoError(err)
require.Equal(tc.ExpectMatch, match)
require.Equal(tc.ExpectVal, val)
})
}
}
func TestIndexConnectService_FromArgs(t *testing.T) {
cases := []struct {
Name string
Args []interface{}
ExpectVal []byte
ExpectErr string
}{
{
"multiple arguments",
[]interface{}{"foo", "bar"},
nil,
"single",
},
{
"not a string",
[]interface{}{42},
nil,
"must be a string",
},
{
"string",
[]interface{}{"fOO"},
[]byte("foo\x00"),
"",
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
var idx IndexConnectService
val, err := idx.FromArgs(tc.Args...)
if tc.ExpectErr != "" {
require.Error(err)
require.Contains(err.Error(), tc.ExpectErr)
return
}
require.NoError(err)
require.Equal(tc.ExpectVal, val)
})
}
}


@ -0,0 +1,366 @@
package state
import (
"fmt"
"sort"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
)
const (
intentionsTableName = "connect-intentions"
)
// intentionsTableSchema returns a new table schema used for storing
// intentions for Connect.
func intentionsTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: intentionsTableName,
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: false,
Unique: true,
Indexer: &memdb.UUIDFieldIndex{
Field: "ID",
},
},
"destination": &memdb.IndexSchema{
Name: "destination",
AllowMissing: true,
// This index is not unique since we need uniqueness across the whole
// 4-tuple.
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "DestinationNS",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "DestinationName",
Lowercase: true,
},
},
},
},
"source": &memdb.IndexSchema{
Name: "source",
AllowMissing: true,
// This index is not unique since we need uniqueness across the whole
// 4-tuple.
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "SourceNS",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "SourceName",
Lowercase: true,
},
},
},
},
"source_destination": &memdb.IndexSchema{
Name: "source_destination",
AllowMissing: true,
Unique: true,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "SourceNS",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "SourceName",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "DestinationNS",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "DestinationName",
Lowercase: true,
},
},
},
},
},
}
}
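// Illustrative sketch (assumption, not part of the original change): the
// unique source_destination compound index keys on the lowercased 4-tuple
// only, so a lookup like
//
//	tx.First(intentionsTableName, "source_destination",
//		"default", "web", "default", "db")
//
// finds any existing intention for default/web -> default/db regardless of
// its other fields. This is the duplicate check performed in
// intentionSetTxn below.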
func init() {
registerSchema(intentionsTableSchema)
}
// Intentions is used to pull all the intentions from the snapshot.
func (s *Snapshot) Intentions() (structs.Intentions, error) {
ixns, err := s.tx.Get(intentionsTableName, "id")
if err != nil {
return nil, err
}
var ret structs.Intentions
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
ret = append(ret, wrapped.(*structs.Intention))
}
return ret, nil
}
// Intention is used when restoring from a snapshot.
func (s *Restore) Intention(ixn *structs.Intention) error {
// Insert the intention
if err := s.tx.Insert(intentionsTableName, ixn); err != nil {
return fmt.Errorf("failed restoring intention: %s", err)
}
if err := indexUpdateMaxTxn(s.tx, ixn.ModifyIndex, intentionsTableName); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// Intentions returns the list of all intentions.
func (s *Store) Intentions(ws memdb.WatchSet) (uint64, structs.Intentions, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the index
idx := maxIndexTxn(tx, intentionsTableName)
if idx < 1 {
idx = 1
}
// Get all intentions
iter, err := tx.Get(intentionsTableName, "id")
if err != nil {
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
}
ws.Add(iter.WatchCh())
var results structs.Intentions
for ixn := iter.Next(); ixn != nil; ixn = iter.Next() {
results = append(results, ixn.(*structs.Intention))
}
// Sort by precedence since that's the order most clients want for
// presentation.
sort.Sort(structs.IntentionPrecedenceSorter(results))
return idx, results, nil
}
// IntentionSet creates or updates an intention.
func (s *Store) IntentionSet(idx uint64, ixn *structs.Intention) error {
tx := s.db.Txn(true)
defer tx.Abort()
if err := s.intentionSetTxn(tx, idx, ixn); err != nil {
return err
}
tx.Commit()
return nil
}
// intentionSetTxn is the inner method used to insert an intention with
// the proper indexes into the state store.
func (s *Store) intentionSetTxn(tx *memdb.Txn, idx uint64, ixn *structs.Intention) error {
// ID is required
if ixn.ID == "" {
return ErrMissingIntentionID
}
// Ensure Precedence is populated correctly on "write"
ixn.UpdatePrecedence()
// Check for an existing intention
existing, err := tx.First(intentionsTableName, "id", ixn.ID)
if err != nil {
return fmt.Errorf("failed intention lookup: %s", err)
}
if existing != nil {
oldIxn := existing.(*structs.Intention)
ixn.CreateIndex = oldIxn.CreateIndex
ixn.CreatedAt = oldIxn.CreatedAt
} else {
ixn.CreateIndex = idx
}
ixn.ModifyIndex = idx
// Check for duplicates on the 4-tuple.
duplicate, err := tx.First(intentionsTableName, "source_destination",
ixn.SourceNS, ixn.SourceName, ixn.DestinationNS, ixn.DestinationName)
if err != nil {
return fmt.Errorf("failed intention lookup: %s", err)
}
if duplicate != nil {
dupIxn := duplicate.(*structs.Intention)
// Same ID is OK - this is an update
if dupIxn.ID != ixn.ID {
return fmt.Errorf("duplicate intention found: %s", dupIxn.String())
}
}
// We always force Meta to be non-nil so an unset value becomes an empty map.
// This makes it easy for API responses to avoid nil-checks everywhere.
if ixn.Meta == nil {
ixn.Meta = make(map[string]string)
}
// Insert
if err := tx.Insert(intentionsTableName, ixn); err != nil {
return err
}
if err := tx.Insert("index", &IndexEntry{intentionsTableName, idx}); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// IntentionGet returns the given intention by ID.
func (s *Store) IntentionGet(ws memdb.WatchSet, id string) (uint64, *structs.Intention, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, intentionsTableName)
if idx < 1 {
idx = 1
}
// Look up by its ID.
watchCh, intention, err := tx.FirstWatch(intentionsTableName, "id", id)
if err != nil {
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
}
ws.Add(watchCh)
// Convert the interface{} if it is non-nil
var result *structs.Intention
if intention != nil {
result = intention.(*structs.Intention)
}
return idx, result, nil
}
// IntentionDelete deletes the given intention by ID.
func (s *Store) IntentionDelete(idx uint64, id string) error {
tx := s.db.Txn(true)
defer tx.Abort()
if err := s.intentionDeleteTxn(tx, idx, id); err != nil {
return fmt.Errorf("failed intention delete: %s", err)
}
tx.Commit()
return nil
}
// intentionDeleteTxn is the inner method used to delete an intention
// with the proper indexes into the state store.
func (s *Store) intentionDeleteTxn(tx *memdb.Txn, idx uint64, queryID string) error {
// Pull the query.
wrapped, err := tx.First(intentionsTableName, "id", queryID)
if err != nil {
return fmt.Errorf("failed intention lookup: %s", err)
}
if wrapped == nil {
return nil
}
// Delete the query and update the index.
if err := tx.Delete(intentionsTableName, wrapped); err != nil {
return fmt.Errorf("failed intention delete: %s", err)
}
if err := tx.Insert("index", &IndexEntry{intentionsTableName, idx}); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// IntentionMatch returns the list of intentions that match the namespace and
// name for either a source or destination. This applies the resolution rules
// so wildcards will match any value.
//
// The returned value is the list of intentions in the same order as the
// entries in args. The intentions themselves are sorted based on the
// intention precedence rules, i.e. result[0][0] is the highest-precedence
// rule matching the first entry.
func (s *Store) IntentionMatch(ws memdb.WatchSet, args *structs.IntentionQueryMatch) (uint64, []structs.Intentions, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, intentionsTableName)
if idx < 1 {
idx = 1
}
// Make all the calls and accumulate the results
results := make([]structs.Intentions, len(args.Entries))
for i, entry := range args.Entries {
// Each search entry may require multiple queries to memdb, so this
// returns the arguments for each necessary Get. Note on performance:
// this is not the most optimal set of queries since we repeat some
// (such as */*) many times. We can work on improving that in the
// future; the test cases shouldn't have to change for that.
getParams, err := s.intentionMatchGetParams(entry)
if err != nil {
return 0, nil, err
}
// Perform each call and accumulate the result.
var ixns structs.Intentions
for _, params := range getParams {
iter, err := tx.Get(intentionsTableName, string(args.Type), params...)
if err != nil {
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
}
ws.Add(iter.WatchCh())
for ixn := iter.Next(); ixn != nil; ixn = iter.Next() {
ixns = append(ixns, ixn.(*structs.Intention))
}
}
// Sort the results by precedence
sort.Sort(structs.IntentionPrecedenceSorter(ixns))
// Store the result
results[i] = ixns
}
return idx, results, nil
}
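// Illustrative sketch (not part of the original change): matching two
// entries returns two precedence-sorted slices in the same order as the
// query entries:
//
//	_, matches, _ := s.IntentionMatch(nil, &structs.IntentionQueryMatch{
//		Type: structs.IntentionMatchDestination,
//		Entries: []structs.IntentionMatchEntry{
//			{Namespace: "default", Name: "web"},
//			{Namespace: "default", Name: "db"},
//		},
//	})
//	// matches[0]: intentions whose destination matches default/web
//	// matches[1]: intentions whose destination matches default/db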
// intentionMatchGetParams returns the tx.Get parameters to find all the
// intentions for a certain entry.
func (s *Store) intentionMatchGetParams(entry structs.IntentionMatchEntry) ([][]interface{}, error) {
// We always query for "*/*" so include that. If the namespace is a
// wildcard, then we're actually done.
result := make([][]interface{}, 0, 3)
result = append(result, []interface{}{"*", "*"})
if entry.Namespace == structs.IntentionWildcard {
return result, nil
}
// Search for NS/* intentions. If we have a wildcard name, then we're done.
result = append(result, []interface{}{entry.Namespace, "*"})
if entry.Name == structs.IntentionWildcard {
return result, nil
}
// Search for the exact NS/N value.
result = append(result, []interface{}{entry.Namespace, entry.Name})
return result, nil
}
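// Illustrative sketch (not part of the original change): for an entry with
// Namespace "foo" and Name "bar", the expansion above yields three Get
// parameter sets, from least to most specific:
//
//	[][]interface{}{
//		{"*", "*"},     // wildcard namespace and name
//		{"foo", "*"},   // namespace match, wildcard name
//		{"foo", "bar"}, // exact match
//	}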

View File

@ -0,0 +1,559 @@
package state
import (
"testing"
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStore_IntentionGet_none(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Querying with no results returns nil.
ws := memdb.NewWatchSet()
idx, res, err := s.IntentionGet(ws, testUUID())
assert.Equal(uint64(1), idx)
assert.Nil(res)
assert.Nil(err)
}
func TestStore_IntentionSetGet_basic(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call Get to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.IntentionGet(ws, testUUID())
assert.Nil(err)
// Build a valid intention
ixn := &structs.Intention{
ID: testUUID(),
SourceNS: "default",
SourceName: "*",
DestinationNS: "default",
DestinationName: "web",
Meta: map[string]string{},
}
// Insert the intention and make sure it succeeds.
assert.NoError(s.IntentionSet(1, ixn))
// Make sure the index got updated.
assert.Equal(uint64(1), s.maxIndex(intentionsTableName))
assert.True(watchFired(ws), "watch fired")
// Read it back out and verify it.
expected := &structs.Intention{
ID: ixn.ID,
SourceNS: "default",
SourceName: "*",
DestinationNS: "default",
DestinationName: "web",
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 1,
ModifyIndex: 1,
},
}
expected.UpdatePrecedence()
ws = memdb.NewWatchSet()
idx, actual, err := s.IntentionGet(ws, ixn.ID)
assert.NoError(err)
assert.Equal(expected.CreateIndex, idx)
assert.Equal(expected, actual)
// Change a value and test updating
ixn.SourceNS = "foo"
assert.NoError(s.IntentionSet(2, ixn))
// Change a value that isn't in the unique 4-tuple and check we don't
// incorrectly consider this a duplicate when updating.
ixn.Action = structs.IntentionActionDeny
assert.NoError(s.IntentionSet(2, ixn))
// Make sure the index got updated.
assert.Equal(uint64(2), s.maxIndex(intentionsTableName))
assert.True(watchFired(ws), "watch fired")
// Read it back and verify the data was updated
expected.SourceNS = ixn.SourceNS
expected.Action = structs.IntentionActionDeny
expected.ModifyIndex = 2
ws = memdb.NewWatchSet()
idx, actual, err = s.IntentionGet(ws, ixn.ID)
assert.NoError(err)
assert.Equal(expected.ModifyIndex, idx)
assert.Equal(expected, actual)
// Attempt to insert another intention with duplicate 4-tuple
ixn = &structs.Intention{
ID: testUUID(),
SourceNS: "default",
SourceName: "*",
DestinationNS: "default",
DestinationName: "web",
Meta: map[string]string{},
}
// Duplicate 4-tuple should cause an error
ws = memdb.NewWatchSet()
assert.Error(s.IntentionSet(3, ixn))
// Make sure the index did NOT get updated.
assert.Equal(uint64(2), s.maxIndex(intentionsTableName))
assert.False(watchFired(ws), "watch not fired")
}
func TestStore_IntentionSet_emptyId(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
ws := memdb.NewWatchSet()
_, _, err := s.IntentionGet(ws, testUUID())
assert.NoError(err)
// Inserting with an empty ID is disallowed.
err = s.IntentionSet(1, &structs.Intention{})
assert.Error(err)
assert.Contains(err.Error(), ErrMissingIntentionID.Error())
// Index is not updated if nothing is saved.
assert.Equal(s.maxIndex(intentionsTableName), uint64(0))
assert.False(watchFired(ws), "watch fired")
}
func TestStore_IntentionSet_updateCreatedAt(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid intention
now := time.Now()
ixn := structs.Intention{
ID: testUUID(),
CreatedAt: now,
}
// Insert
assert.NoError(s.IntentionSet(1, &ixn))
// Change a value and test updating
ixnUpdate := ixn
ixnUpdate.CreatedAt = now.Add(10 * time.Second)
assert.NoError(s.IntentionSet(2, &ixnUpdate))
// Read it back and verify
_, actual, err := s.IntentionGet(nil, ixn.ID)
assert.NoError(err)
assert.Equal(now, actual.CreatedAt)
}
func TestStore_IntentionSet_metaNil(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid intention
ixn := structs.Intention{
ID: testUUID(),
}
// Insert
assert.NoError(s.IntentionSet(1, &ixn))
// Read it back and verify
_, actual, err := s.IntentionGet(nil, ixn.ID)
assert.NoError(err)
assert.NotNil(actual.Meta)
}
func TestStore_IntentionSet_metaSet(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid intention
ixn := structs.Intention{
ID: testUUID(),
Meta: map[string]string{"foo": "bar"},
}
// Insert
assert.NoError(s.IntentionSet(1, &ixn))
// Read it back and verify
_, actual, err := s.IntentionGet(nil, ixn.ID)
assert.NoError(err)
assert.Equal(ixn.Meta, actual.Meta)
}
func TestStore_IntentionDelete(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call Get to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.IntentionGet(ws, testUUID())
assert.NoError(err)
// Create
ixn := &structs.Intention{ID: testUUID()}
assert.NoError(s.IntentionSet(1, ixn))
// Make sure the index got updated.
assert.Equal(s.maxIndex(intentionsTableName), uint64(1))
assert.True(watchFired(ws), "watch fired")
// Delete
assert.NoError(s.IntentionDelete(2, ixn.ID))
// Make sure the index got updated.
assert.Equal(s.maxIndex(intentionsTableName), uint64(2))
assert.True(watchFired(ws), "watch fired")
// Sanity check to make sure it's not there.
idx, actual, err := s.IntentionGet(nil, ixn.ID)
assert.NoError(err)
assert.Equal(idx, uint64(2))
assert.Nil(actual)
}
func TestStore_IntentionsList(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Querying with no results returns nil.
ws := memdb.NewWatchSet()
idx, res, err := s.Intentions(ws)
assert.NoError(err)
assert.Nil(res)
assert.Equal(uint64(1), idx)
// Create some intentions
ixns := structs.Intentions{
&structs.Intention{
ID: testUUID(),
Meta: map[string]string{},
},
&structs.Intention{
ID: testUUID(),
Meta: map[string]string{},
},
}
// Force deterministic sort order
ixns[0].ID = "a" + ixns[0].ID[1:]
ixns[1].ID = "b" + ixns[1].ID[1:]
// Create
for i, ixn := range ixns {
assert.NoError(s.IntentionSet(uint64(1+i), ixn))
}
assert.True(watchFired(ws), "watch fired")
// Read it back and verify.
expected := structs.Intentions{
&structs.Intention{
ID: ixns[0].ID,
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 1,
ModifyIndex: 1,
},
},
&structs.Intention{
ID: ixns[1].ID,
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 2,
ModifyIndex: 2,
},
},
}
for i := range expected {
expected[i].UpdatePrecedence() // to match what is returned...
}
idx, actual, err := s.Intentions(nil)
assert.NoError(err)
assert.Equal(idx, uint64(2))
assert.Equal(expected, actual)
}
// Test the matrix of match logic.
//
// Note that this doesn't need to test the intention sort logic exhaustively
// since this is tested in their sort implementation in the structs.
func TestStore_IntentionMatch_table(t *testing.T) {
type testCase struct {
Name string
Insert [][]string // List of intentions to insert
Query [][]string // List of intentions to match
Expected [][][]string // List of matches, where each match is a list of intentions
}
cases := []testCase{
{
"single exact namespace/name",
[][]string{
{"foo", "*"},
{"foo", "bar"},
{"foo", "baz"}, // shouldn't match
{"bar", "bar"}, // shouldn't match
{"bar", "*"}, // shouldn't match
{"*", "*"},
},
[][]string{
{"foo", "bar"},
},
[][][]string{
{
{"foo", "bar"},
{"foo", "*"},
{"*", "*"},
},
},
},
{
"multiple exact namespace/name",
[][]string{
{"foo", "*"},
{"foo", "bar"},
{"foo", "baz"}, // shouldn't match
{"bar", "bar"},
{"bar", "*"},
},
[][]string{
{"foo", "bar"},
{"bar", "bar"},
},
[][][]string{
{
{"foo", "bar"},
{"foo", "*"},
},
{
{"bar", "bar"},
{"bar", "*"},
},
},
},
{
"single exact namespace/name with duplicate destinations",
[][]string{
// 4-tuple specifies src and destination to test duplicate destinations
// with different sources. We flip them around to test in both
// directions. The first pair are the ones searched on in both cases so
// the duplicates need to be there.
{"foo", "bar", "foo", "*"},
{"foo", "bar", "bar", "*"},
{"*", "*", "*", "*"},
},
[][]string{
{"foo", "bar"},
},
[][][]string{
{
// Note the first two have the same precedence so we rely on arbitrary
// lexicographical tie-break behaviour.
{"foo", "bar", "bar", "*"},
{"foo", "bar", "foo", "*"},
{"*", "*", "*", "*"},
},
},
},
}
// testRunner implements the test for a single case and is parameterized
// so it can run for both source and destination matches.
testRunner := func(t *testing.T, tc testCase, typ structs.IntentionMatchType) {
// Insert the set
assert := assert.New(t)
s := testStateStore(t)
var idx uint64 = 1
for _, v := range tc.Insert {
ixn := &structs.Intention{ID: testUUID()}
switch typ {
case structs.IntentionMatchDestination:
ixn.DestinationNS = v[0]
ixn.DestinationName = v[1]
if len(v) == 4 {
ixn.SourceNS = v[2]
ixn.SourceName = v[3]
}
case structs.IntentionMatchSource:
ixn.SourceNS = v[0]
ixn.SourceName = v[1]
if len(v) == 4 {
ixn.DestinationNS = v[2]
ixn.DestinationName = v[3]
}
}
assert.NoError(s.IntentionSet(idx, ixn))
idx++
}
// Build the arguments
args := &structs.IntentionQueryMatch{Type: typ}
for _, q := range tc.Query {
args.Entries = append(args.Entries, structs.IntentionMatchEntry{
Namespace: q[0],
Name: q[1],
})
}
// Match
_, matches, err := s.IntentionMatch(nil, args)
assert.NoError(err)
// Should have equal lengths
require.Len(t, matches, len(tc.Expected))
// Verify matches
for i, expected := range tc.Expected {
var actual [][]string
for _, ixn := range matches[i] {
switch typ {
case structs.IntentionMatchDestination:
if len(expected) > 1 && len(expected[0]) == 4 {
actual = append(actual, []string{
ixn.DestinationNS,
ixn.DestinationName,
ixn.SourceNS,
ixn.SourceName,
})
} else {
actual = append(actual, []string{ixn.DestinationNS, ixn.DestinationName})
}
case structs.IntentionMatchSource:
if len(expected) > 1 && len(expected[0]) == 4 {
actual = append(actual, []string{
ixn.SourceNS,
ixn.SourceName,
ixn.DestinationNS,
ixn.DestinationName,
})
} else {
actual = append(actual, []string{ixn.SourceNS, ixn.SourceName})
}
}
}
assert.Equal(expected, actual)
}
}
for _, tc := range cases {
t.Run(tc.Name+" (destination)", func(t *testing.T) {
testRunner(t, tc, structs.IntentionMatchDestination)
})
t.Run(tc.Name+" (source)", func(t *testing.T) {
testRunner(t, tc, structs.IntentionMatchSource)
})
}
}
func TestStore_Intention_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create some intentions.
ixns := structs.Intentions{
&structs.Intention{
DestinationName: "foo",
},
&structs.Intention{
DestinationName: "bar",
},
&structs.Intention{
DestinationName: "baz",
},
}
// Force the sort order of the UUIDs before we create them so the
// order is deterministic.
id := testUUID()
ixns[0].ID = "a" + id[1:]
ixns[1].ID = "b" + id[1:]
ixns[2].ID = "c" + id[1:]
// Now create
for i, ixn := range ixns {
assert.NoError(s.IntentionSet(uint64(4+i), ixn))
}
// Snapshot the intentions.
snap := s.Snapshot()
defer snap.Close()
// Alter the real state store.
assert.NoError(s.IntentionDelete(7, ixns[0].ID))
// Verify the snapshot.
assert.Equal(snap.LastIndex(), uint64(6))
// Expect them sorted in insertion order
expected := structs.Intentions{
&structs.Intention{
ID: ixns[0].ID,
DestinationName: "foo",
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 4,
ModifyIndex: 4,
},
},
&structs.Intention{
ID: ixns[1].ID,
DestinationName: "bar",
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 5,
ModifyIndex: 5,
},
},
&structs.Intention{
ID: ixns[2].ID,
DestinationName: "baz",
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 6,
ModifyIndex: 6,
},
},
}
for i := range expected {
expected[i].UpdatePrecedence() // to match what is returned...
}
dump, err := snap.Intentions()
assert.NoError(err)
assert.Equal(expected, dump)
// Restore the values into a new state store.
func() {
s := testStateStore(t)
restore := s.Restore()
for _, ixn := range dump {
assert.NoError(restore.Intention(ixn))
}
restore.Commit()
// Read the restored values back out and verify that they match. Note that
// Intentions are returned precedence-sorted, unlike the snapshot, so we
// need to rearrange the expected slice accordingly.
expected[0], expected[1], expected[2] = expected[1], expected[2], expected[0]
idx, actual, err := s.Intentions(nil)
assert.NoError(err)
assert.Equal(idx, uint64(6))
assert.Equal(expected, actual)
}()
}

View File

@ -28,6 +28,14 @@ var (
// ErrMissingQueryID is returned when a Query set is called on
// a Query with an empty ID.
ErrMissingQueryID = errors.New("Missing Query ID")
// ErrMissingCARootID is returned when a CARoot set is called
// with a CARoot with an empty ID.
ErrMissingCARootID = errors.New("Missing CA Root ID")
// ErrMissingIntentionID is returned when an Intention set is called
// with an Intention with an empty ID.
ErrMissingIntentionID = errors.New("Missing Intention ID")
)
const (

View File

@ -51,6 +51,7 @@ type dnsConfig struct {
ServiceTTL map[string]time.Duration
UDPAnswerLimit int
ARecordLimit int
NodeMetaTXT bool
}
// DNSServer is used to wrap an Agent and expose various
@ -109,6 +110,7 @@ func GetDNSConfig(conf *config.RuntimeConfig) *dnsConfig {
SegmentName: conf.SegmentName,
ServiceTTL: conf.DNSServiceTTL,
UDPAnswerLimit: conf.DNSUDPAnswerLimit,
NodeMetaTXT: conf.DNSNodeMetaTXT,
}
}
@ -337,7 +339,7 @@ func (d *DNSServer) addSOA(msg *dns.Msg) {
// nameservers returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "")
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false)
if err != nil {
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
return nil, nil
@ -374,7 +376,7 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
}
ns = append(ns, nsrr)
glue := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns)
glue := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, false)
extra = append(extra, glue...)
// don't provide more than 3 servers
@ -415,7 +417,7 @@ PARSE:
n = n + 1
}
switch labels[n-1] {
switch kind := labels[n-1]; kind {
case "service":
if n == 1 {
goto INVALID
@ -433,7 +435,7 @@ PARSE:
}
// _name._tag.service.consul
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, req, resp)
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp)
// Consul 0.3 and prior format for SRV queries
} else {
@ -445,9 +447,17 @@ PARSE:
}
// tag[.tag].name.service.consul
d.serviceLookup(network, datacenter, labels[n-2], tag, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp)
}
case "connect":
if n == 1 {
goto INVALID
}
// name.connect.consul
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp)
case "node":
if n == 1 {
goto INVALID
@ -582,7 +592,7 @@ RPC:
n := out.NodeServices.Node
edns := req.IsEdns0() != nil
addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses)
records := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns)
records := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, true)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
@ -610,7 +620,7 @@ func encodeKVasRFC1464(key, value string) (txt string) {
}
// formatNodeRecord takes a Node and returns an A, AAAA, TXT or CNAME record
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) {
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns, answer bool) (records []dns.RR) {
// Parse the IP
ip := net.ParseIP(addr)
var ipv4 net.IP
@ -671,7 +681,20 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
}
}
if node != nil && (qType == dns.TypeANY || qType == dns.TypeTXT) {
// Decide whether to include node meta TXT records: they are always
// included when they are the answer to the question; when they would
// only end up in the Additional section of the DNS response, defer to
// configuration.
nodeMetaTXT := false
if node != nil {
if answer {
nodeMetaTXT = true
} else {
nodeMetaTXT = d.config.NodeMetaTXT
}
}
if nodeMetaTXT {
for key, value := range node.Meta {
txt := value
if !strings.HasPrefix(strings.ToLower(key), "rfc1035-") {
@ -782,8 +805,8 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
originalNumRecords := len(resp.Answer)
// It is not possible to return more than 4k records even with compression
// Since we are performing binary search it is not a big deal, but it
// improves performance a bit, even with binary search
truncateAt := 4096
if req.Question[0].Qtype == dns.TypeSRV {
// More than 1024 SRV records do not fit in 64k
@ -898,8 +921,9 @@ func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed
}
// lookupServiceNodes returns nodes with a given service.
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs.IndexedCheckServiceNodes, error) {
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool) (structs.IndexedCheckServiceNodes, error) {
args := structs.ServiceSpecificRequest{
Connect: connect,
Datacenter: datacenter,
ServiceName: service,
ServiceTag: tag,
@ -935,8 +959,8 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs
}
// serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) {
out, err := d.lookupServiceNodes(datacenter, service, tag)
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect)
if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
@ -1143,7 +1167,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
handled[addr] = struct{}{}
// Add the node record
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns)
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, true)
if records != nil {
resp.Answer = append(resp.Answer, records...)
count++
@ -1192,7 +1216,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}
// Add the extra record
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns)
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, false)
if len(records) > 0 {
// Use the node address if it doesn't differ from the service address
if addr == node.Node.Address {

View File

@ -17,6 +17,7 @@ import (
"github.com/hashicorp/serf/coordinate"
"github.com/miekg/dns"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -472,6 +473,51 @@ func TestDNS_NodeLookup_TXT(t *testing.T) {
}
}
func TestDNS_NodeLookup_TXT_DontSuppress(t *testing.T) {
a := NewTestAgent(t.Name(), `dns_config = { enable_additional_node_meta_txt = false }`)
defer a.Shutdown()
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "google",
Address: "127.0.0.1",
NodeMeta: map[string]string{
"rfc1035-00": "value0",
"key0": "value1",
},
}
var out struct{}
if err := a.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("google.node.consul.", dns.TypeTXT)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
// Should have the 2 TXT record replies
if len(in.Answer) != 2 {
t.Fatalf("Bad: %#v", in)
}
txtRec, ok := in.Answer[0].(*dns.TXT)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if len(txtRec.Txt) != 1 {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if txtRec.Txt[0] != "value0" && txtRec.Txt[0] != "key0=value1" {
t.Fatalf("Bad: %#v", in.Answer[0])
}
}
func TestDNS_NodeLookup_ANY(t *testing.T) {
a := NewTestAgent(t.Name(), ``)
defer a.Shutdown()
@ -510,7 +556,46 @@ func TestDNS_NodeLookup_ANY(t *testing.T) {
},
}
verify.Values(t, "answer", in.Answer, wantAnswer)
}
func TestDNS_NodeLookup_ANY_DontSuppressTXT(t *testing.T) {
a := NewTestAgent(t.Name(), `dns_config = { enable_additional_node_meta_txt = false }`)
defer a.Shutdown()
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "127.0.0.1",
NodeMeta: map[string]string{
"key": "value",
},
}
var out struct{}
if err := a.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("bar.node.consul.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
wantAnswer := []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "bar.node.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4},
A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1
},
&dns.TXT{
Hdr: dns.RR_Header{Name: "bar.node.consul.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Rdlength: 0xa},
Txt: []string{"key=value"},
},
}
verify.Values(t, "answer", in.Answer, wantAnswer)
}
func TestDNS_EDNS0(t *testing.T) {
@ -1041,6 +1126,51 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) {
verify.Values(t, "extra", in.Extra, wantExtra)
}
func TestDNS_ConnectServiceLookup(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
{
args := structs.TestRegisterRequestProxy(t)
args.Address = "127.0.0.55"
args.Service.ProxyDestination = "db"
args.Service.Address = ""
args.Service.Port = 12345
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
}
// Look up the service
questions := []string{
"db.connect.consul.",
}
for _, question := range questions {
m := new(dns.Msg)
m.SetQuestion(question, dns.TypeSRV)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
assert.Nil(err)
assert.Len(in.Answer, 1)
srvRec, ok := in.Answer[0].(*dns.SRV)
assert.True(ok)
assert.Equal(uint16(12345), srvRec.Port)
assert.Equal("foo.node.dc1.consul.", srvRec.Target)
assert.Equal(uint32(0), srvRec.Hdr.Ttl)
aRec, ok := in.Extra[0].(*dns.A)
assert.True(ok)
assert.Equal("foo.node.dc1.consul.", aRec.Hdr.Name)
assert.Equal(uint32(0), aRec.Hdr.Ttl)
assert.Equal("127.0.0.55", aRec.A.String())
}
}
func TestDNS_ExternalServiceLookup(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
@ -4613,6 +4743,93 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) {
}
}
func TestDNS_ServiceLookup_MetaTXT(t *testing.T) {
a := NewTestAgent(t.Name(), `dns_config = { enable_additional_node_meta_txt = true }`)
defer a.Shutdown()
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "127.0.0.1",
NodeMeta: map[string]string{
"key": "value",
},
Service: &structs.NodeService{
Service: "db",
Tags: []string{"master"},
Port: 12345,
},
}
var out struct{}
if err := a.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("db.service.consul.", dns.TypeSRV)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
wantAdditional := []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4},
A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1
},
&dns.TXT{
Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Rdlength: 0xa},
Txt: []string{"key=value"},
},
}
verify.Values(t, "additional", in.Extra, wantAdditional)
}
func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) {
a := NewTestAgent(t.Name(), `dns_config = { enable_additional_node_meta_txt = false }`)
defer a.Shutdown()
// Register a node with a service.
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "bar",
Address: "127.0.0.1",
NodeMeta: map[string]string{
"key": "value",
},
Service: &structs.NodeService{
Service: "db",
Tags: []string{"master"},
Port: 12345,
},
}
var out struct{}
if err := a.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("db.service.consul.", dns.TypeSRV)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
wantAdditional := []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4},
A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1
},
}
verify.Values(t, "additional", in.Extra, wantAdditional)
}
func TestDNS_AddressLookup(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")

View File

@ -143,9 +143,17 @@ RETRY_ONCE:
return out.HealthChecks, nil
}
func (s *HTTPServer) HealthConnectServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.healthServiceNodes(resp, req, true)
}
func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.healthServiceNodes(resp, req, false)
}
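// Illustrative sketch (not part of the original change): both wrappers
// funnel into healthServiceNodes; only the URL prefix and the Connect
// flag differ:
//
//	GET /v1/health/service/web -> ServiceName "web", Connect=false
//	GET /v1/health/connect/web -> ServiceName "web", Connect=true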
func (s *HTTPServer) healthServiceNodes(resp http.ResponseWriter, req *http.Request, connect bool) (interface{}, error) {
// Set default DC
args := structs.ServiceSpecificRequest{}
args := structs.ServiceSpecificRequest{Connect: connect}
s.parseSource(req, &args.Source)
args.NodeMetaFilters = s.parseMetaFilter(req)
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
@ -159,8 +167,14 @@ func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Requ
args.TagFilter = true
}
// Determine the prefix
prefix := "/v1/health/service/"
if connect {
prefix = "/v1/health/connect/"
}
// Pull out the service name
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/health/service/")
args.ServiceName = strings.TrimPrefix(req.URL.Path, prefix)
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")

View File

@ -13,6 +13,7 @@ import (
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/testutil/retry"
"github.com/hashicorp/serf/coordinate"
"github.com/stretchr/testify/assert"
)
func TestHealthChecksInState(t *testing.T) {
@ -770,6 +771,105 @@ func TestHealthServiceNodes_WanTranslation(t *testing.T) {
}
}
func TestHealthConnectServiceNodes(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
// Request
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?dc=dc1", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
// Should be a non-nil empty list for checks
nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 1)
assert.Len(nodes[0].Checks, 0)
}
func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
args.Check = &structs.HealthCheck{
Node: args.Node,
Name: "check",
ServiceID: args.Service.Service,
Status: api.HealthCritical,
}
var out struct{}
assert.Nil(t, a.RPC("Catalog.Register", args, &out))
t.Run("bc_no_query_value", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
// Should be 0 nodes since the only check is critical
nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 0)
})
t.Run("passing_true", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=true", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
// Should be 0 nodes since the only check is critical
nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 0)
})
t.Run("passing_false", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=false", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
// Should be 1 node since passing=false disables the filter
nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 1)
})
t.Run("passing_bad", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=nope-nope", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
a.srv.HealthConnectServiceNodes(resp, req)
assert.Equal(400, resp.Code)
body, err := ioutil.ReadAll(resp.Body)
assert.Nil(err)
assert.True(bytes.Contains(body, []byte("Invalid value for ?passing")))
})
}
func TestFilterNonPassing(t *testing.T) {
t.Parallel()
nodes := structs.CheckServiceNodes{

View File

@ -3,6 +3,7 @@ package agent
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/pprof"
@ -16,6 +17,7 @@ import (
"github.com/NYTimes/gziphandler"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-cleanhttp"
"github.com/mitchellh/mapstructure"
@ -157,9 +159,9 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler {
}
if s.IsUIEnabled() {
new_ui, err := strconv.ParseBool(os.Getenv("CONSUL_UI_BETA"))
legacy_ui, err := strconv.ParseBool(os.Getenv("CONSUL_UI_LEGACY"))
if err != nil {
new_ui = false
legacy_ui = false
}
var uifs http.FileSystem
@ -169,15 +171,15 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler {
} else {
fs := assetFS()
if new_ui {
fs.Prefix += "/v2/"
} else {
if legacy_ui {
fs.Prefix += "/v1/"
} else {
fs.Prefix += "/v2/"
}
uifs = fs
}
if new_ui {
if !legacy_ui {
uifs = &redirectFS{fs: uifs}
}
@ -384,6 +386,13 @@ func (s *HTTPServer) Index(resp http.ResponseWriter, req *http.Request) {
// decodeBody is used to decode a JSON request body
func decodeBody(req *http.Request, out interface{}, cb func(interface{}) error) error {
// This generally only happens in tests since real HTTP requests set
// a non-nil body with no content. We guard against it anyway to prevent
// a panic. The EOF response is the same behavior as an empty reader.
if req.Body == nil {
return io.EOF
}
var raw interface{}
dec := json.NewDecoder(req.Body)
if err := dec.Decode(&raw); err != nil {
@ -409,6 +418,14 @@ func setTranslateAddr(resp http.ResponseWriter, active bool) {
// setIndex is used to set the index response header
func setIndex(resp http.ResponseWriter, index uint64) {
// If we ever return an X-Consul-Index of 0, blocking clients will go into a
// busy loop and hammer us since ?index=0 will never block. It's always safe
// to return index=1 since the very first Raft write is always an internal
// one writing the raft config for the cluster, so no user-facing blocking
// query will ever legitimately have an X-Consul-Index of 1.
if index == 0 {
index = 1
}
resp.Header().Set("X-Consul-Index", strconv.FormatUint(index, 10))
}
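// Illustrative sketch (not part of this change): a client-side blocking
// query loop. Because setIndex never emits 0, the second iteration blocks
// instead of busy-looping with ?index=0:
//
//	index := "1"
//	for {
//		resp, err := http.Get("http://127.0.0.1:8500/v1/health/service/web?index=" + index)
//		if err != nil {
//			break
//		}
//		index = resp.Header.Get("X-Consul-Index")
//		resp.Body.Close()
//	}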
@ -444,6 +461,15 @@ func setMeta(resp http.ResponseWriter, m *structs.QueryMeta) {
setConsistency(resp, m.ConsistencyLevel)
}
// setCacheMeta sets http response headers to indicate cache status.
func setCacheMeta(resp http.ResponseWriter, m *cache.ResultMeta) {
str := "MISS"
if m != nil && m.Hit {
str = "HIT"
}
resp.Header().Set("X-Cache", str)
}
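// Illustrative sketch (not part of this change): surfacing a cache result,
// producing an "X-Cache: HIT" header on repeated reads:
//
//	meta := cache.ResultMeta{Hit: true}
//	setCacheMeta(resp, &meta)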
// setHeaders is used to set canonical response header fields
func setHeaders(resp http.ResponseWriter, headers map[string]string) {
for field, value := range headers {

View File

@ -29,16 +29,27 @@ func init() {
registerEndpoint("/v1/agent/check/warn/", []string{"PUT"}, (*HTTPServer).AgentCheckWarn)
registerEndpoint("/v1/agent/check/fail/", []string{"PUT"}, (*HTTPServer).AgentCheckFail)
registerEndpoint("/v1/agent/check/update/", []string{"PUT"}, (*HTTPServer).AgentCheckUpdate)
registerEndpoint("/v1/agent/connect/authorize", []string{"POST"}, (*HTTPServer).AgentConnectAuthorize)
registerEndpoint("/v1/agent/connect/ca/roots", []string{"GET"}, (*HTTPServer).AgentConnectCARoots)
registerEndpoint("/v1/agent/connect/ca/leaf/", []string{"GET"}, (*HTTPServer).AgentConnectCALeafCert)
registerEndpoint("/v1/agent/connect/proxy/", []string{"GET"}, (*HTTPServer).AgentConnectProxyConfig)
registerEndpoint("/v1/agent/service/register", []string{"PUT"}, (*HTTPServer).AgentRegisterService)
registerEndpoint("/v1/agent/service/deregister/", []string{"PUT"}, (*HTTPServer).AgentDeregisterService)
registerEndpoint("/v1/agent/service/maintenance/", []string{"PUT"}, (*HTTPServer).AgentServiceMaintenance)
registerEndpoint("/v1/catalog/register", []string{"PUT"}, (*HTTPServer).CatalogRegister)
registerEndpoint("/v1/catalog/connect/", []string{"GET"}, (*HTTPServer).CatalogConnectServiceNodes)
registerEndpoint("/v1/catalog/deregister", []string{"PUT"}, (*HTTPServer).CatalogDeregister)
registerEndpoint("/v1/catalog/datacenters", []string{"GET"}, (*HTTPServer).CatalogDatacenters)
registerEndpoint("/v1/catalog/nodes", []string{"GET"}, (*HTTPServer).CatalogNodes)
registerEndpoint("/v1/catalog/services", []string{"GET"}, (*HTTPServer).CatalogServices)
registerEndpoint("/v1/catalog/service/", []string{"GET"}, (*HTTPServer).CatalogServiceNodes)
registerEndpoint("/v1/catalog/node/", []string{"GET"}, (*HTTPServer).CatalogNodeServices)
registerEndpoint("/v1/connect/ca/configuration", []string{"GET", "PUT"}, (*HTTPServer).ConnectCAConfiguration)
registerEndpoint("/v1/connect/ca/roots", []string{"GET"}, (*HTTPServer).ConnectCARoots)
registerEndpoint("/v1/connect/intentions", []string{"GET", "POST"}, (*HTTPServer).IntentionEndpoint)
registerEndpoint("/v1/connect/intentions/match", []string{"GET"}, (*HTTPServer).IntentionMatch)
registerEndpoint("/v1/connect/intentions/check", []string{"GET"}, (*HTTPServer).IntentionCheck)
registerEndpoint("/v1/connect/intentions/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).IntentionSpecific)
registerEndpoint("/v1/coordinate/datacenters", []string{"GET"}, (*HTTPServer).CoordinateDatacenters)
registerEndpoint("/v1/coordinate/nodes", []string{"GET"}, (*HTTPServer).CoordinateNodes)
registerEndpoint("/v1/coordinate/node/", []string{"GET"}, (*HTTPServer).CoordinateNode)
@ -49,6 +60,7 @@ func init() {
registerEndpoint("/v1/health/checks/", []string{"GET"}, (*HTTPServer).HealthServiceChecks)
registerEndpoint("/v1/health/state/", []string{"GET"}, (*HTTPServer).HealthChecksInState)
registerEndpoint("/v1/health/service/", []string{"GET"}, (*HTTPServer).HealthServiceNodes)
registerEndpoint("/v1/health/connect/", []string{"GET"}, (*HTTPServer).HealthConnectServiceNodes)
registerEndpoint("/v1/internal/ui/nodes", []string{"GET"}, (*HTTPServer).UINodes)
registerEndpoint("/v1/internal/ui/node/", []string{"GET"}, (*HTTPServer).UINodeInfo)
registerEndpoint("/v1/internal/ui/services", []string{"GET"}, (*HTTPServer).UIServices)

View File

@ -0,0 +1,302 @@
package agent
import (
"fmt"
"net/http"
"strings"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/structs"
)
// /v1/connect/intentions
func (s *HTTPServer) IntentionEndpoint(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
switch req.Method {
case "GET":
return s.IntentionList(resp, req)
case "POST":
return s.IntentionCreate(resp, req)
default:
return nil, MethodNotAllowedError{req.Method, []string{"GET", "POST"}}
}
}
// GET /v1/connect/intentions
func (s *HTTPServer) IntentionList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
var args structs.DCSpecificRequest
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
var reply structs.IndexedIntentions
if err := s.agent.RPC("Intention.List", &args, &reply); err != nil {
return nil, err
}
return reply.Intentions, nil
}
// POST /v1/connect/intentions
func (s *HTTPServer) IntentionCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
args := structs.IntentionRequest{
Op: structs.IntentionOpCreate,
}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req, &args.Intention, nil); err != nil {
return nil, fmt.Errorf("Failed to decode request body: %s", err)
}
var reply string
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
return nil, err
}
return intentionCreateResponse{reply}, nil
}
// GET /v1/connect/intentions/match
func (s *HTTPServer) IntentionMatch(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Prepare args
args := &structs.IntentionQueryRequest{Match: &structs.IntentionQueryMatch{}}
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
q := req.URL.Query()
// Extract the "by" query parameter
if by, ok := q["by"]; !ok || len(by) != 1 {
return nil, fmt.Errorf("required query parameter 'by' not set")
} else {
switch v := structs.IntentionMatchType(by[0]); v {
case structs.IntentionMatchSource, structs.IntentionMatchDestination:
args.Match.Type = v
default:
return nil, fmt.Errorf("'by' parameter must be one of 'source' or 'destination'")
}
}
// Extract all the match names
names, ok := q["name"]
if !ok || len(names) == 0 {
return nil, fmt.Errorf("required query parameter 'name' not set")
}
// Build the entries in order. The order matters since that is the
// order of the returned responses.
args.Match.Entries = make([]structs.IntentionMatchEntry, len(names))
for i, n := range names {
entry, err := parseIntentionMatchEntry(n)
if err != nil {
return nil, fmt.Errorf("name %q is invalid: %s", n, err)
}
args.Match.Entries[i] = entry
}
var reply structs.IndexedIntentionMatches
if err := s.agent.RPC("Intention.Match", args, &reply); err != nil {
return nil, err
}
// We must have an identical count of matches
if len(reply.Matches) != len(names) {
return nil, fmt.Errorf("internal error: match response count didn't match input count")
}
// Use empty list instead of nil.
response := make(map[string]structs.Intentions)
for i, ixns := range reply.Matches {
response[names[i]] = ixns
}
return response, nil
}
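// Illustrative sketch (not part of the original change): a request such as
//
//	GET /v1/connect/intentions/match?by=destination&name=foo/bar&name=web
//
// returns a JSON object keyed by each name parameter, e.g.
//
//	{"foo/bar": [...], "web": [...]}
//
// where each list is precedence-sorted by the server.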
// GET /v1/connect/intentions/check
func (s *HTTPServer) IntentionCheck(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Prepare args
args := &structs.IntentionQueryRequest{Check: &structs.IntentionQueryCheck{}}
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
q := req.URL.Query()
// Set the source type if set
args.Check.SourceType = structs.IntentionSourceConsul
if sourceType, ok := q["source-type"]; ok && len(sourceType) > 0 {
args.Check.SourceType = structs.IntentionSourceType(sourceType[0])
}
// Extract the source/destination
source, ok := q["source"]
if !ok || len(source) != 1 {
return nil, fmt.Errorf("required query parameter 'source' not set")
}
destination, ok := q["destination"]
if !ok || len(destination) != 1 {
return nil, fmt.Errorf("required query parameter 'destination' not set")
}
// We parse them the same way as matches to extract namespace/name
args.Check.SourceName = source[0]
if args.Check.SourceType == structs.IntentionSourceConsul {
entry, err := parseIntentionMatchEntry(source[0])
if err != nil {
return nil, fmt.Errorf("source %q is invalid: %s", source[0], err)
}
args.Check.SourceNS = entry.Namespace
args.Check.SourceName = entry.Name
}
// The destination is always in the Consul format
entry, err := parseIntentionMatchEntry(destination[0])
if err != nil {
return nil, fmt.Errorf("destination %q is invalid: %s", destination[0], err)
}
args.Check.DestinationNS = entry.Namespace
args.Check.DestinationName = entry.Name
var reply structs.IntentionQueryCheckResponse
if err := s.agent.RPC("Intention.Check", args, &reply); err != nil {
return nil, err
}
return &reply, nil
}
// IntentionSpecific handles the endpoint for /v1/connect/intentions/:id
func (s *HTTPServer) IntentionSpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
id := strings.TrimPrefix(req.URL.Path, "/v1/connect/intentions/")
switch req.Method {
case "GET":
return s.IntentionSpecificGet(id, resp, req)
case "PUT":
return s.IntentionSpecificUpdate(id, resp, req)
case "DELETE":
return s.IntentionSpecificDelete(id, resp, req)
default:
return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}}
}
}
// GET /v1/connect/intentions/:id
func (s *HTTPServer) IntentionSpecificGet(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
args := structs.IntentionQueryRequest{
IntentionID: id,
}
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
var reply structs.IndexedIntentions
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds the error type
if err.Error() == consul.ErrIntentionNotFound.Error() {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
}
return nil, err
}
// This shouldn't happen since the RPC API documents that it won't,
// but we check anyway since the alternative is a panic.
if len(reply.Intentions) == 0 {
resp.WriteHeader(http.StatusNotFound)
return nil, nil
}
return reply.Intentions[0], nil
}
// PUT /v1/connect/intentions/:id
func (s *HTTPServer) IntentionSpecificUpdate(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
args := structs.IntentionRequest{
Op: structs.IntentionOpUpdate,
}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req, &args.Intention, nil); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
}
// Use the ID from the URL
args.Intention.ID = id
var reply string
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
return nil, err
}
// Update uses the same create response
return intentionCreateResponse{reply}, nil
}
// DELETE /v1/connect/intentions/:id
func (s *HTTPServer) IntentionSpecificDelete(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
args := structs.IntentionRequest{
Op: structs.IntentionOpDelete,
Intention: &structs.Intention{ID: id},
}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
var reply string
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
return nil, err
}
return true, nil
}
// intentionCreateResponse is the response structure for creating an intention.
type intentionCreateResponse struct{ ID string }
// parseIntentionMatchEntry parses the query parameter for an intention
// match query entry.
func parseIntentionMatchEntry(input string) (structs.IntentionMatchEntry, error) {
var result structs.IntentionMatchEntry
result.Namespace = structs.IntentionDefaultNamespace
// TODO(mitchellh): when namespaces are introduced, set the default
// namespace to be the namespace of the requestor.
// Get the index to the '/'. If it doesn't exist, we have just a name
// so just set that and return.
idx := strings.IndexByte(input, '/')
if idx == -1 {
result.Name = input
return result, nil
}
result.Namespace = input[:idx]
result.Name = input[idx+1:]
if strings.IndexByte(result.Name, '/') != -1 {
return result, fmt.Errorf("input can contain at most one '/'")
}
return result, nil
}
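// Illustrative sketch (not part of the original change):
//
//	parseIntentionMatchEntry("web")     // {IntentionDefaultNamespace, "web"}
//	parseIntentionMatchEntry("foo/bar") // {"foo", "bar"}
//	parseIntentionMatchEntry("a/b/c")   // error: at most one '/'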

View File

@ -0,0 +1,502 @@
package agent
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIntentionsList_empty(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Make sure an empty list is non-nil.
req, _ := http.NewRequest("GET", "/v1/connect/intentions", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionList(resp, req)
assert.Nil(err)
value := obj.(structs.Intentions)
assert.NotNil(value)
assert.Len(value, 0)
}
func TestIntentionsList_values(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Create some intentions, note we create the lowest precedence first to test
// sorting.
for _, v := range []string{"*", "foo", "bar"} {
req := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
}
req.Intention.SourceName = v
var reply string
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
}
// Request
req, _ := http.NewRequest("GET", "/v1/connect/intentions", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionList(resp, req)
assert.NoError(err)
value := obj.(structs.Intentions)
assert.Len(value, 3)
expected := []string{"bar", "foo", "*"}
actual := []string{
value[0].SourceName,
value[1].SourceName,
value[2].SourceName,
}
assert.Equal(expected, actual)
}
func TestIntentionsMatch_basic(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Create some intentions
{
insert := [][]string{
{"foo", "*", "foo", "*"},
{"foo", "*", "foo", "bar"},
{"foo", "*", "foo", "baz"}, // shouldn't match
{"foo", "*", "bar", "bar"}, // shouldn't match
{"foo", "*", "bar", "*"}, // shouldn't match
{"foo", "*", "*", "*"},
{"bar", "*", "foo", "bar"}, // duplicate destination different source
}
for _, v := range insert {
ixn := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
}
ixn.Intention.SourceNS = v[0]
ixn.Intention.SourceName = v[1]
ixn.Intention.DestinationNS = v[2]
ixn.Intention.DestinationName = v[3]
// Create
var reply string
assert.Nil(a.RPC("Intention.Apply", &ixn, &reply))
}
}
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/match?by=destination&name=foo/bar", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionMatch(resp, req)
assert.Nil(err)
value := obj.(map[string]structs.Intentions)
assert.Len(value, 1)
var actual [][]string
expected := [][]string{
{"bar", "*", "foo", "bar"},
{"foo", "*", "foo", "bar"},
{"foo", "*", "foo", "*"},
{"foo", "*", "*", "*"},
}
for _, ixn := range value["foo/bar"] {
actual = append(actual, []string{
ixn.SourceNS,
ixn.SourceName,
ixn.DestinationNS,
ixn.DestinationName,
})
}
assert.Equal(expected, actual)
}
func TestIntentionsMatch_noBy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/match?name=foo/bar", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionMatch(resp, req)
assert.NotNil(err)
assert.Contains(err.Error(), "by")
assert.Nil(obj)
}
func TestIntentionsMatch_byInvalid(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/match?by=datacenter", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionMatch(resp, req)
assert.NotNil(err)
assert.Contains(err.Error(), "'by' parameter")
assert.Nil(obj)
}
func TestIntentionsMatch_noName(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/match?by=source", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionMatch(resp, req)
assert.NotNil(err)
assert.Contains(err.Error(), "'name' not set")
assert.Nil(obj)
}

func TestIntentionsCheck_basic(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Create some intentions
{
insert := [][]string{
{"foo", "*", "foo", "*"},
{"foo", "*", "foo", "bar"},
{"bar", "*", "foo", "bar"},
}
for _, v := range insert {
ixn := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
}
ixn.Intention.SourceNS = v[0]
ixn.Intention.SourceName = v[1]
ixn.Intention.DestinationNS = v[2]
ixn.Intention.DestinationName = v[3]
ixn.Intention.Action = structs.IntentionActionDeny
// Create
var reply string
require.Nil(a.RPC("Intention.Apply", &ixn, &reply))
}
}
	// Request a pair covered by a deny intention; it must not be allowed
{
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/test?source=foo/bar&destination=foo/baz", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCheck(resp, req)
require.Nil(err)
value := obj.(*structs.IntentionQueryCheckResponse)
require.False(value.Allowed)
}
	// Request a pair with no covering intention; the default policy allows it
{
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/test?source=foo/bar&destination=bar/qux", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCheck(resp, req)
require.Nil(err)
value := obj.(*structs.IntentionQueryCheckResponse)
require.True(value.Allowed)
}
}
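
// The /test (check) endpoint evaluates a single source/destination pair.
// With the deny intentions created above, a covered pair reports
// Allowed=false, while an uncovered pair falls through to the default
// policy (allow for this test agent). A hypothetical equivalent against a
// running agent (default address assumed):
//
//	curl 'http://127.0.0.1:8500/v1/connect/intentions/test?source=foo/bar&destination=foo/baz'
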
func TestIntentionsCheck_noSource(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/test?destination=B", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCheck(resp, req)
require.NotNil(err)
require.Contains(err.Error(), "'source' not set")
require.Nil(obj)
}

func TestIntentionsCheck_noDestination(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/test?source=B", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCheck(resp, req)
require.NotNil(err)
require.Contains(err.Error(), "'destination' not set")
require.Nil(obj)
}

func TestIntentionsCreate_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
	// Create a new intention through the HTTP endpoint.
args := structs.TestIntention(t)
args.SourceName = "foo"
req, _ := http.NewRequest("POST", "/v1/connect/intentions", jsonReader(args))
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCreate(resp, req)
assert.Nil(err)
value := obj.(intentionCreateResponse)
assert.NotEqual("", value.ID)
// Read the value
{
req := &structs.IntentionQueryRequest{
Datacenter: "dc1",
IntentionID: value.ID,
}
var resp structs.IndexedIntentions
assert.Nil(a.RPC("Intention.Get", req, &resp))
assert.Len(resp.Intentions, 1)
actual := resp.Intentions[0]
assert.Equal("foo", actual.SourceName)
}
}
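
// Creation is a POST of the intention body; the response carries the new
// intention's ID. A hypothetical sketch against a running agent (default
// address assumed; body abbreviated):
//
//	curl -X POST -d '{"SourceName": "foo", ...}' \
//	    http://127.0.0.1:8500/v1/connect/intentions
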
func TestIntentionsCreate_noBody(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Create with no body
req, _ := http.NewRequest("POST", "/v1/connect/intentions", nil)
resp := httptest.NewRecorder()
_, err := a.srv.IntentionCreate(resp, req)
require.Error(t, err)
}

func TestIntentionsSpecificGet_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// The intention
ixn := structs.TestIntention(t)
// Create an intention directly
var reply string
{
req := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: ixn,
}
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
}
// Get the value
req, _ := http.NewRequest("GET", fmt.Sprintf("/v1/connect/intentions/%s", reply), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionSpecific(resp, req)
assert.Nil(err)
value := obj.(*structs.Intention)
assert.Equal(reply, value.ID)
ixn.ID = value.ID
ixn.RaftIndex = value.RaftIndex
ixn.CreatedAt, ixn.UpdatedAt = value.CreatedAt, value.UpdatedAt
assert.Equal(ixn, value)
}

func TestIntentionsSpecificUpdate_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// The intention
ixn := structs.TestIntention(t)
// Create an intention directly
var reply string
{
req := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: ixn,
}
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
}
// Update the intention
ixn.ID = "bogus"
ixn.SourceName = "bar"
req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/connect/intentions/%s", reply), jsonReader(ixn))
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionSpecific(resp, req)
assert.Nil(err)
value := obj.(intentionCreateResponse)
assert.Equal(reply, value.ID)
// Read the value
{
req := &structs.IntentionQueryRequest{
Datacenter: "dc1",
IntentionID: reply,
}
var resp structs.IndexedIntentions
assert.Nil(a.RPC("Intention.Get", req, &resp))
assert.Len(resp.Intentions, 1)
actual := resp.Intentions[0]
assert.Equal("bar", actual.SourceName)
}
}

func TestIntentionsSpecificDelete_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// The intention
ixn := structs.TestIntention(t)
ixn.SourceName = "foo"
// Create an intention directly
var reply string
{
req := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: ixn,
}
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
}
// Sanity check that the intention exists
{
req := &structs.IntentionQueryRequest{
Datacenter: "dc1",
IntentionID: reply,
}
var resp structs.IndexedIntentions
assert.Nil(a.RPC("Intention.Get", req, &resp))
assert.Len(resp.Intentions, 1)
actual := resp.Intentions[0]
assert.Equal("foo", actual.SourceName)
}
// Delete the intention
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/v1/connect/intentions/%s", reply), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionSpecific(resp, req)
assert.Nil(err)
assert.Equal(true, obj)
// Verify the intention is gone
{
req := &structs.IntentionQueryRequest{
Datacenter: "dc1",
IntentionID: reply,
}
var resp structs.IndexedIntentions
err := a.RPC("Intention.Get", req, &resp)
assert.NotNil(err)
assert.Contains(err.Error(), "not found")
}
}
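
// The /v1/connect/intentions/<id> endpoint multiplexes on HTTP method: GET
// reads, PUT updates, and DELETE removes, as the three tests above
// exercise. Hypothetical sketches against a running agent (default address
// assumed; <id> stands in for a real intention ID):
//
//	curl http://127.0.0.1:8500/v1/connect/intentions/<id>
//	curl -X PUT -d '{"SourceName": "bar", ...}' http://127.0.0.1:8500/v1/connect/intentions/<id>
//	curl -X DELETE http://127.0.0.1:8500/v1/connect/intentions/<id>
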
func TestParseIntentionMatchEntry(t *testing.T) {
cases := []struct {
Input string
Expected structs.IntentionMatchEntry
Err bool
}{
{
"foo",
structs.IntentionMatchEntry{
Namespace: structs.IntentionDefaultNamespace,
Name: "foo",
},
false,
},
{
"foo/bar",
structs.IntentionMatchEntry{
Namespace: "foo",
Name: "bar",
},
false,
},
{
"foo/bar/baz",
structs.IntentionMatchEntry{},
true,
},
}
for _, tc := range cases {
t.Run(tc.Input, func(t *testing.T) {
assert := assert.New(t)
actual, err := parseIntentionMatchEntry(tc.Input)
assert.Equal(err != nil, tc.Err, err)
if err != nil {
return
}
assert.Equal(tc.Expected, actual)
})
}
}
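
// A minimal, runnable usage sketch for the parser exercised above. The
// behavior shown comes from the cases in TestParseIntentionMatchEntry; the
// example itself is illustrative and not part of the original suite.
func Example_parseIntentionMatchEntry() {
	// A "namespace/name" string splits into its two parts.
	entry, err := parseIntentionMatchEntry("foo/bar")
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s %s\n", entry.Namespace, entry.Name)
	// Output: foo bar
}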
