Merge remote-tracking branch 'connect/f-connect'

This commit is contained in:
mkeeler 2018-06-25 19:42:51 +00:00
commit 1da3c42867
1149 changed files with 283446 additions and 6659 deletions

View File

@ -2,12 +2,15 @@
FEATURES:
* **Connect Feature Beta**: This version includes a major new feature for Consul named Connect. Connect enables secure service-to-service communication with automatic TLS encryption and identity-based authorization. For more details and links to demos and getting started guides, see the [announcement blog post](https://www.hashicorp.com/blog/consul-1-2-service-mesh).
* Connect must be enabled explicitly in configuration so upgrading a cluster will not affect any existing functionality until it's enabled.
* This is a Beta release, we don't recommend enabling this in production yet. Please see the documentation for more information.
* dns: Enable PTR record lookups for services with IPs that have no registered node [[PR-4083](https://github.com/hashicorp/consul/pull/4083)]
* ui: Default to serving the new UI. Setting the `CONSUL_UI_LEGACY` environment variable to `1` of `true` will revert to serving the old UI
IMPROVEMENTS:
* agent: A Consul user-agent string is now sent to providers when making retry-join requests [GH-4013](https://github.com/hashicorp/consul/pull/4013)
* agent: A Consul user-agent string is now sent to providers when making retry-join requests [[GH-4013](https://github.com/hashicorp/consul/issues/4013)](https://github.com/hashicorp/consul/pull/4013)
* client: Add metrics for failed RPCs [PR-4220](https://github.com/hashicorp/consul/pull/4220)
* agent: Add configuration entry to control including TXT records for node meta in DNS responses [PR-4215](https://github.com/hashicorp/consul/pull/4215)
* client: Make RPC rate limit configuration reloadable [GH-4012](https://github.com/hashicorp/consul/issues/4012)

View File

@ -145,10 +145,16 @@ test: other-consul dev-build vet
@# _something_ to stop them terminating us due to inactivity...
{ go test $(GOTEST_FLAGS) -tags '$(GOTAGS)' -timeout 5m $(GOTEST_PKGS) 2>&1 ; echo $$? > exit-code ; } | tee test.log | egrep '^(ok|FAIL)\s*github.com/hashicorp/consul'
@echo "Exit code: $$(cat exit-code)" >> test.log
@grep -A5 'DATA RACE' test.log || true
@grep -A10 'panic: test timed out' test.log || true
@grep -A1 -- '--- SKIP:' test.log || true
@grep -A1 -- '--- FAIL:' test.log || true
@# This prints all the race report between ====== lines
@awk '/^WARNING: DATA RACE/ {do_print=1; print "=================="} do_print==1 {print} /^={10,}/ {do_print=0}' test.log || true
@grep -A10 'panic: ' test.log || true
@# Prints all the failure output until the next non-indented line - testify
@# helpers often output multiple lines for readability but useless if we can't
@# see them. Un-intuitive order of matches is necessary. No || true because
@# awk always returns true even if there is no match and it breaks non-bash
@# shells locally.
@awk '/^[^[:space:]]/ {do_print=0} /--- SKIP/ {do_print=1} do_print==1 {print}' test.log
@awk '/^[^[:space:]]/ {do_print=0} /--- FAIL/ {do_print=1} do_print==1 {print}' test.log
@grep '^FAIL' test.log || true
@if [ "$$(cat exit-code)" == "0" ] ; then echo "PASS" ; exit 0 ; else exit 1 ; fi

105
README.md
View File

@ -1,75 +1,66 @@
# Consul [![Build Status](https://travis-ci.org/hashicorp/consul.svg?branch=master)](https://travis-ci.org/hashicorp/consul) [![Join the chat at https://gitter.im/hashicorp-consul/Lobby](https://badges.gitter.im/hashicorp-consul/Lobby.svg)](https://gitter.im/hashicorp-consul/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
**This is a temporary README. We'll restore the old README prior to PR upstream.**
* Website: https://www.consul.io
* Chat: [Gitter](https://gitter.im/hashicorp-consul/Lobby)
* Mailing list: [Google Groups](https://groups.google.com/group/consul-tool/)
# Consul Connect
Consul is a tool for service discovery and configuration. Consul is
distributed, highly available, and extremely scalable.
This repository is the forked repository for Consul Connect work to happen
in private prior to public release. This README will explain how to safely
use this fork, how to bring in upstream changes, etc.
Consul provides several key features:
## Cloning
* **Service Discovery** - Consul makes it simple for services to register
themselves and to discover other services via a DNS or HTTP interface.
External services such as SaaS providers can be registered as well.
To use this repository, clone it into your GOPATH as usual but you must
**rename `consul-connect` to `consul`** so that Go imports continue working
as usual.
* **Health Checking** - Health Checking enables Consul to quickly alert
operators about any issues in a cluster. The integration with service
discovery prevents routing traffic to unhealthy hosts and enables service
level circuit breakers.
## Important: Never Modify Master
* **Key/Value Storage** - A flexible key/value store enables storing
dynamic configuration, feature flagging, coordination, leader election and
more. The simple HTTP API makes it easy to use anywhere.
**NEVER MODIFY MASTER! NEVER MODIFY MASTER!**
* **Multi-Datacenter** - Consul is built to be datacenter aware, and can
support any number of regions without complex configuration.
We want to keep the "master" branch equivalent to OSS master. This will make
rebasing easy for master. Instead, we'll use the branch `f-connect`. All
feature branches should branch from `f-connect` and make PRs against
`f-connect`.
Consul runs on Linux, Mac OS X, FreeBSD, Solaris, and Windows. A commercial
version called [Consul Enterprise](https://www.hashicorp.com/products/consul)
is also available.
When we're ready to merge back to upstream, we can make a single mega PR
merging `f-connect` into OSS master. This way we don't have a sudden mega
push to master on OSS.
## Quick Start
## Creating a Feature Branch
An extensive quick start is viewable on the Consul website:
To create a feature branch, branch from `f-connect`:
https://www.consul.io/intro/getting-started/install.html
## Documentation
Full, comprehensive documentation is viewable on the Consul website:
https://www.consul.io/docs
## Developing Consul
If you wish to work on Consul itself, you'll first need [Go](https://golang.org)
installed (version 1.9+ is _required_). Make sure you have Go properly installed,
including setting up your [GOPATH](https://golang.org/doc/code.html#GOPATH).
Next, clone this repository into `$GOPATH/src/github.com/hashicorp/consul` and
then just type `make`. In a few moments, you'll have a working `consul` executable:
```
$ make
...
$ bin/consul
...
```sh
git checkout f-connect
git checkout -b my-new-branch
```
*Note: `make` will build all os/architecture combinations. Set the environment variable `CONSUL_DEV=1` to build it just for your local machine's os/architecture, or use `make dev`.*
All merged Connect features will be in `f-connect`, so you want to work
from that branch. When making a PR for your feature branch, target the
`f-connect` branch as the merge target. You can do this by using the dropdowns
in the GitHub UI when creating a PR.
*Note: `make` will also place a copy of the binary in the first part of your `$GOPATH`.*
## Syncing Upstream
You can run tests by typing `make test`. The test suite may fail if
over-parallelized, so if you are seeing stochastic failures try
`GOTEST_FLAGS="-p 2 -parallel 2" make test`.
First update our local master:
If you make any changes to the code, run `make format` in order to automatically
format the code according to Go standards.
```sh
# This has to happen on forked master
git checkout master
## Vendoring
# Add upstream to OSS Consul
git remote add upstream https://github.com/hashicorp/consul.git
Consul currently uses [govendor](https://github.com/kardianos/govendor) for
vendoring and [vendorfmt](https://github.com/magiconair/vendorfmt) for formatting
`vendor.json` to a more merge-friendly "one line per package" format.
# Fetch it
git fetch upstream
# Rebase forked master onto upstream. This should have no changes since
# we're never modifying master.
git rebase upstream master
```
Next, update the `f-connect` branch:
```sh
git checkout f-connect
git rebase origin master
```

View File

@ -60,6 +60,17 @@ type ACL interface {
// EventWrite determines if a specific event may be fired.
EventWrite(string) bool
// IntentionDefaultAllow determines the default authorized behavior
// when no intentions match a Connect request.
IntentionDefaultAllow() bool
// IntentionRead determines if a specific intention can be read.
IntentionRead(string) bool
// IntentionWrite determines if a specific intention can be
// created, modified, or deleted.
IntentionWrite(string) bool
// KeyList checks for permission to list keys under a prefix
KeyList(string) bool
@ -154,6 +165,18 @@ func (s *StaticACL) EventWrite(string) bool {
return s.defaultAllow
}
func (s *StaticACL) IntentionDefaultAllow() bool {
return s.defaultAllow
}
func (s *StaticACL) IntentionRead(string) bool {
return s.defaultAllow
}
func (s *StaticACL) IntentionWrite(string) bool {
return s.defaultAllow
}
func (s *StaticACL) KeyRead(string) bool {
return s.defaultAllow
}
@ -275,6 +298,9 @@ type PolicyACL struct {
// agentRules contains the agent policies
agentRules *radix.Tree
// intentionRules contains the service intention policies
intentionRules *radix.Tree
// keyRules contains the key policies
keyRules *radix.Tree
@ -308,6 +334,7 @@ func New(parent ACL, policy *Policy, sentinel sentinel.Evaluator) (*PolicyACL, e
p := &PolicyACL{
parent: parent,
agentRules: radix.New(),
intentionRules: radix.New(),
keyRules: radix.New(),
nodeRules: radix.New(),
serviceRules: radix.New(),
@ -347,6 +374,25 @@ func New(parent ACL, policy *Policy, sentinel sentinel.Evaluator) (*PolicyACL, e
sentinelPolicy: sp.Sentinel,
}
p.serviceRules.Insert(sp.Name, policyRule)
// Determine the intention. The intention could be blank (not set).
// If the intention is not set, the value depends on the value of
// the service policy.
intention := sp.Intentions
if intention == "" {
switch sp.Policy {
case PolicyRead, PolicyWrite:
intention = PolicyRead
default:
intention = PolicyDeny
}
}
policyRule = PolicyRule{
aclPolicy: intention,
sentinelPolicy: sp.Sentinel,
}
p.intentionRules.Insert(sp.Name, policyRule)
}
// Load the session policy
@ -455,6 +501,51 @@ func (p *PolicyACL) EventWrite(name string) bool {
return p.parent.EventWrite(name)
}
// IntentionDefaultAllow returns whether the default behavior when there are
// no matching intentions is to allow or deny.
func (p *PolicyACL) IntentionDefaultAllow() bool {
// We always go up, this can't be determined by a policy.
return p.parent.IntentionDefaultAllow()
}
// IntentionRead checks if writing (creating, updating, or deleting) of an
// intention is allowed.
func (p *PolicyACL) IntentionRead(prefix string) bool {
// Check for an exact rule or catch-all
_, rule, ok := p.intentionRules.LongestPrefix(prefix)
if ok {
pr := rule.(PolicyRule)
switch pr.aclPolicy {
case PolicyRead, PolicyWrite:
return true
default:
return false
}
}
// No matching rule, use the parent.
return p.parent.IntentionRead(prefix)
}
// IntentionWrite checks if writing (creating, updating, or deleting) of an
// intention is allowed.
func (p *PolicyACL) IntentionWrite(prefix string) bool {
// Check for an exact rule or catch-all
_, rule, ok := p.intentionRules.LongestPrefix(prefix)
if ok {
pr := rule.(PolicyRule)
switch pr.aclPolicy {
case PolicyWrite:
return true
default:
return false
}
}
// No matching rule, use the parent.
return p.parent.IntentionWrite(prefix)
}
// KeyRead returns if a key is allowed to be read
func (p *PolicyACL) KeyRead(key string) bool {
// Look for a matching rule

View File

@ -53,6 +53,12 @@ func TestStaticACL(t *testing.T) {
if !all.EventWrite("foobar") {
t.Fatalf("should allow")
}
if !all.IntentionDefaultAllow() {
t.Fatalf("should allow")
}
if !all.IntentionWrite("foobar") {
t.Fatalf("should allow")
}
if !all.KeyRead("foobar") {
t.Fatalf("should allow")
}
@ -123,6 +129,12 @@ func TestStaticACL(t *testing.T) {
if none.EventWrite("") {
t.Fatalf("should not allow")
}
if none.IntentionDefaultAllow() {
t.Fatalf("should not allow")
}
if none.IntentionWrite("foo") {
t.Fatalf("should not allow")
}
if none.KeyRead("foobar") {
t.Fatalf("should not allow")
}
@ -187,6 +199,12 @@ func TestStaticACL(t *testing.T) {
if !manage.EventWrite("foobar") {
t.Fatalf("should allow")
}
if !manage.IntentionDefaultAllow() {
t.Fatalf("should allow")
}
if !manage.IntentionWrite("foobar") {
t.Fatalf("should allow")
}
if !manage.KeyRead("foobar") {
t.Fatalf("should allow")
}
@ -307,6 +325,12 @@ func TestPolicyACL(t *testing.T) {
&ServicePolicy{
Name: "barfoo",
Policy: PolicyWrite,
Intentions: PolicyWrite,
},
&ServicePolicy{
Name: "intbaz",
Policy: PolicyWrite,
Intentions: PolicyDeny,
},
},
}
@ -344,6 +368,31 @@ func TestPolicyACL(t *testing.T) {
}
}
// Test the intentions
type intentioncase struct {
inp string
read bool
write bool
}
icases := []intentioncase{
{"other", true, false},
{"foo", true, false},
{"bar", false, false},
{"foobar", true, false},
{"barfo", false, false},
{"barfoo", true, true},
{"barfoo2", true, true},
{"intbaz", false, false},
}
for _, c := range icases {
if c.read != acl.IntentionRead(c.inp) {
t.Fatalf("Read fail: %#v", c)
}
if c.write != acl.IntentionWrite(c.inp) {
t.Fatalf("Write fail: %#v", c)
}
}
// Test the services
type servicecase struct {
inp string
@ -414,6 +463,11 @@ func TestPolicyACL(t *testing.T) {
t.Fatalf("Prepared query fail: %#v", c)
}
}
// Check default intentions bubble up
if !acl.IntentionDefaultAllow() {
t.Fatal("should allow")
}
}
func TestPolicyACL_Parent(t *testing.T) {
@ -567,6 +621,11 @@ func TestPolicyACL_Parent(t *testing.T) {
if acl.Snapshot() {
t.Fatalf("should not allow")
}
// Check default intentions
if acl.IntentionDefaultAllow() {
t.Fatal("should not allow")
}
}
func TestPolicyACL_Agent(t *testing.T) {

View File

@ -73,6 +73,11 @@ type ServicePolicy struct {
Name string `hcl:",key"`
Policy string
Sentinel Sentinel
// Intentions is the policy for intentions where this service is the
// destination. This may be empty, in which case the Policy determines
// the intentions policy.
Intentions string
}
func (s *ServicePolicy) GoString() string {
@ -197,6 +202,9 @@ func Parse(rules string, sentinel sentinel.Evaluator) (*Policy, error) {
if !isPolicyValid(sp.Policy) {
return nil, fmt.Errorf("Invalid service policy: %#v", sp)
}
if sp.Intentions != "" && !isPolicyValid(sp.Intentions) {
return nil, fmt.Errorf("Invalid service intentions policy: %#v", sp)
}
if err := isSentinelValid(sentinel, sp.Policy, sp.Sentinel); err != nil {
return nil, fmt.Errorf("Invalid service Sentinel policy: %#v, got error:%v", sp, err)
}

View File

@ -4,8 +4,85 @@ import (
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParse_table(t *testing.T) {
// Note that the table tests are newer than other tests. Many of the
// other aspects of policy parsing are tested in older tests below. New
// parsing tests should be added to this table as its easier to maintain.
cases := []struct {
Name string
Input string
Expected *Policy
Err string
}{
{
"service no intentions",
`
service "foo" {
policy = "write"
}
`,
&Policy{
Services: []*ServicePolicy{
{
Name: "foo",
Policy: "write",
},
},
},
"",
},
{
"service intentions",
`
service "foo" {
policy = "write"
intentions = "read"
}
`,
&Policy{
Services: []*ServicePolicy{
{
Name: "foo",
Policy: "write",
Intentions: "read",
},
},
},
"",
},
{
"service intention: invalid value",
`
service "foo" {
policy = "write"
intentions = "foo"
}
`,
nil,
"service intentions",
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
actual, err := Parse(tc.Input, nil)
assert.Equal(tc.Err != "", err != nil, err)
if err != nil {
assert.Contains(err.Error(), tc.Err)
return
}
assert.Equal(tc.Expected, actual)
})
}
}
func TestACLPolicy_Parse_HCL(t *testing.T) {
inp := `
agent "foo" {

View File

@ -8,6 +8,7 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/local"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/golang-lru"
@ -239,6 +240,18 @@ func (a *Agent) resolveToken(id string) (acl.ACL, error) {
return a.acls.lookupACL(a, id)
}
// resolveProxyToken attempts to resolve an ACL ID to a local proxy token.
// If a local proxy isn't found with that token, nil is returned.
func (a *Agent) resolveProxyToken(id string) *local.ManagedProxy {
for _, p := range a.State.Proxies() {
if p.ProxyToken == id {
return p
}
}
return nil
}
// vetServiceRegister makes sure the service registration action is allowed by
// the given token.
func (a *Agent) vetServiceRegister(token string, service *structs.NodeService) error {

View File

@ -21,16 +21,20 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/ae"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/checks"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/local"
"github.com/hashicorp/consul/agent/proxy"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/systemd"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/lib/file"
"github.com/hashicorp/consul/logger"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/consul/watch"
@ -46,6 +50,9 @@ const (
// Path to save agent service definitions
servicesDir = "services"
// Path to save agent proxy definitions
proxyDir = "proxies"
// Path to save local agent checks
checksDir = "checks"
checkStateDir = "checks/state"
@ -119,6 +126,9 @@ type Agent struct {
// and the remote state.
sync *ae.StateSyncer
// cache is the in-memory cache for data the Agent requests.
cache *cache.Cache
// checkReapAfter maps the check ID to a timeout after which we should
// reap its associated service
checkReapAfter map[types.CheckID]time.Duration
@ -195,6 +205,9 @@ type Agent struct {
// be updated at runtime, so should always be used instead of going to
// the configuration directly.
tokens *token.Store
// proxyManager is the proxy process manager for managed Connect proxies.
proxyManager *proxy.Manager
}
func New(c *config.RuntimeConfig) (*Agent, error) {
@ -247,6 +260,8 @@ func LocalConfig(cfg *config.RuntimeConfig) local.Config {
NodeID: cfg.NodeID,
NodeName: cfg.NodeName,
TaggedAddresses: map[string]string{},
ProxyBindMinPort: cfg.ConnectProxyBindMinPort,
ProxyBindMaxPort: cfg.ConnectProxyBindMaxPort,
}
for k, v := range cfg.TaggedAddresses {
lc.TaggedAddresses[k] = v
@ -289,6 +304,9 @@ func (a *Agent) Start() error {
// regular and on-demand state synchronizations (anti-entropy).
a.sync = ae.NewStateSyncer(a.State, c.AEInterval, a.shutdownCh, a.logger)
// create the cache
a.cache = cache.New(nil)
// create the config for the rpc server/client
consulCfg, err := a.consulConfig()
if err != nil {
@ -325,10 +343,17 @@ func (a *Agent) Start() error {
a.State.Delegate = a.delegate
a.State.TriggerSyncChanges = a.sync.SyncChanges.Trigger
// Register the cache. We do this much later so the delegate is
// populated from above.
a.registerCache()
// Load checks/services/metadata.
if err := a.loadServices(c); err != nil {
return err
}
if err := a.loadProxies(c); err != nil {
return err
}
if err := a.loadChecks(c); err != nil {
return err
}
@ -336,6 +361,28 @@ func (a *Agent) Start() error {
return err
}
// create the proxy process manager and start it. This is purposely
// done here after the local state above is loaded in so we can have
// a more accurate initial state view.
if !c.ConnectTestDisableManagedProxies {
a.proxyManager = proxy.NewManager()
a.proxyManager.AllowRoot = a.config.ConnectProxyAllowManagedRoot
a.proxyManager.State = a.State
a.proxyManager.Logger = a.logger
if a.config.DataDir != "" {
// DataDir is required for all non-dev mode agents, but we want
// to allow setting the data dir for demos and so on for the agent,
// so do the check above instead.
a.proxyManager.DataDir = filepath.Join(a.config.DataDir, "proxy")
// Restore from our snapshot (if it exists)
if err := a.proxyManager.Restore(a.proxyManager.SnapshotPath()); err != nil {
a.logger.Printf("[WARN] agent: error restoring proxy state: %s", err)
}
}
go a.proxyManager.Run()
}
// Start watching for critical services to deregister, based on their
// checks.
go a.reapServices()
@ -605,6 +652,16 @@ func (a *Agent) reloadWatches(cfg *config.RuntimeConfig) error {
return fmt.Errorf("Handler type '%s' not recognized", params["handler_type"])
}
// Don't let people use connect watches via this mechanism for now as it
// needs thought about how to do securely and shouldn't be necessary. Note
// that if the type assertion fails an type is not a string then
// ParseExample below will error so we don't need to handle that case.
if typ, ok := params["type"].(string); ok {
if strings.HasPrefix(typ, "connect_") {
return fmt.Errorf("Watch type %s is not allowed in agent config", typ)
}
}
// Parse the watches, excluding 'handler' and 'args'
wp, err := watch.ParseExempt(params, []string{"handler", "args"})
if err != nil {
@ -878,6 +935,47 @@ func (a *Agent) consulConfig() (*consul.Config, error) {
base.TLSCipherSuites = a.config.TLSCipherSuites
base.TLSPreferServerCipherSuites = a.config.TLSPreferServerCipherSuites
// Copy the Connect CA bootstrap config
if a.config.ConnectEnabled {
base.ConnectEnabled = true
// Allow config to specify cluster_id provided it's a valid UUID. This is
// meant only for tests where a deterministic ID makes fixtures much simpler
// to work with but since it's only read on initial cluster bootstrap it's not
// that much of a liability in production. The worst a user could do is
// configure logically separate clusters with same ID by mistake but we can
// avoid documenting this is even an option.
if clusterID, ok := a.config.ConnectCAConfig["cluster_id"]; ok {
if cIDStr, ok := clusterID.(string); ok {
if _, err := uuid.ParseUUID(cIDStr); err == nil {
// Valid UUID configured, use that
base.CAConfig.ClusterID = cIDStr
}
}
if base.CAConfig.ClusterID == "" {
// If the tried to specify an ID but typoed it don't ignore as they will
// then bootstrap with a new ID and have to throw away the whole cluster
// and start again.
a.logger.Println("[ERR] connect CA config cluster_id specified but " +
"is not a valid UUID, aborting startup")
return nil, fmt.Errorf("cluster_id was supplied but was not a valid UUID")
}
}
if a.config.ConnectCAProvider != "" {
base.CAConfig.Provider = a.config.ConnectCAProvider
// Merge with the default config if it's the consul provider.
if a.config.ConnectCAProvider == "consul" {
for k, v := range a.config.ConnectCAConfig {
base.CAConfig.Config[k] = v
}
} else {
base.CAConfig.Config = a.config.ConnectCAConfig
}
}
}
// Setup the user event callback
base.UserEventHandler = func(e serf.UserEvent) {
select {
@ -1010,7 +1108,7 @@ func (a *Agent) setupNodeID(config *config.RuntimeConfig) error {
}
// For dev mode we have no filesystem access so just make one.
if a.config.DevMode {
if a.config.DataDir == "" {
id, err := a.makeNodeID()
if err != nil {
return err
@ -1224,6 +1322,24 @@ func (a *Agent) ShutdownAgent() error {
chk.Stop()
}
// Stop the proxy manager
if a.proxyManager != nil {
// If persistence is disabled (implies DevMode but a subset of DevMode) then
// don't leave the proxies running since the agent will not be able to
// recover them later.
if a.config.DataDir == "" {
a.logger.Printf("[WARN] agent: dev mode disabled persistence, killing " +
"all proxies since we can't recover them")
if err := a.proxyManager.Kill(); err != nil {
a.logger.Printf("[WARN] agent: error shutting down proxy manager: %s", err)
}
} else {
if err := a.proxyManager.Close(); err != nil {
a.logger.Printf("[WARN] agent: error shutting down proxy manager: %s", err)
}
}
}
var err error
if a.delegate != nil {
err = a.delegate.Shutdown()
@ -1493,7 +1609,7 @@ func (a *Agent) persistService(service *structs.NodeService) error {
return err
}
return writeFileAtomic(svcPath, encoded)
return file.WriteAtomic(svcPath, encoded)
}
// purgeService removes a persisted service definition file from the data dir
@ -1505,6 +1621,39 @@ func (a *Agent) purgeService(serviceID string) error {
return nil
}
// persistedProxy is used to wrap a proxy definition and bundle it with an Proxy
// token so we can continue to authenticate the running proxy after a restart.
type persistedProxy struct {
ProxyToken string
Proxy *structs.ConnectManagedProxy
}
// persistProxy saves a proxy definition to a JSON file in the data dir
func (a *Agent) persistProxy(proxy *local.ManagedProxy) error {
proxyPath := filepath.Join(a.config.DataDir, proxyDir,
stringHash(proxy.Proxy.ProxyService.ID))
wrapped := persistedProxy{
ProxyToken: proxy.ProxyToken,
Proxy: proxy.Proxy,
}
encoded, err := json.Marshal(wrapped)
if err != nil {
return err
}
return file.WriteAtomic(proxyPath, encoded)
}
// purgeProxy removes a persisted proxy definition file from the data dir
func (a *Agent) purgeProxy(proxyID string) error {
proxyPath := filepath.Join(a.config.DataDir, proxyDir, stringHash(proxyID))
if _, err := os.Stat(proxyPath); err == nil {
return os.Remove(proxyPath)
}
return nil
}
// persistCheck saves a check definition to the local agent's state directory
func (a *Agent) persistCheck(check *structs.HealthCheck, chkType *structs.CheckType) error {
checkPath := filepath.Join(a.config.DataDir, checksDir, checkIDHash(check.CheckID))
@ -1521,7 +1670,7 @@ func (a *Agent) persistCheck(check *structs.HealthCheck, chkType *structs.CheckT
return err
}
return writeFileAtomic(checkPath, encoded)
return file.WriteAtomic(checkPath, encoded)
}
// purgeCheck removes a persisted check definition file from the data dir
@ -1533,43 +1682,6 @@ func (a *Agent) purgeCheck(checkID types.CheckID) error {
return nil
}
// writeFileAtomic writes the given contents to a temporary file in the same
// directory, does an fsync and then renames the file to its real path
func writeFileAtomic(path string, contents []byte) error {
uuid, err := uuid.GenerateUUID()
if err != nil {
return err
}
tempPath := fmt.Sprintf("%s-%s.tmp", path, uuid)
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return err
}
fh, err := os.OpenFile(tempPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return err
}
if _, err := fh.Write(contents); err != nil {
fh.Close()
os.Remove(tempPath)
return err
}
if err := fh.Sync(); err != nil {
fh.Close()
os.Remove(tempPath)
return err
}
if err := fh.Close(); err != nil {
os.Remove(tempPath)
return err
}
if err := os.Rename(tempPath, path); err != nil {
os.Remove(tempPath)
return err
}
return nil
}
// AddService is used to add a service entry.
// This entry is persistent and the agent will make a best effort to
// ensure it is registered
@ -1623,7 +1735,7 @@ func (a *Agent) AddService(service *structs.NodeService, chkTypes []*structs.Che
a.State.AddService(service, token)
// Persist the service to a file
if persist && !a.config.DevMode {
if persist && a.config.DataDir != "" {
if err := a.persistService(service); err != nil {
return err
}
@ -1922,7 +2034,7 @@ func (a *Agent) AddCheck(check *structs.HealthCheck, chkType *structs.CheckType,
}
// Persist the check
if persist && !a.config.DevMode {
if persist && a.config.DataDir != "" {
return a.persistCheck(check, chkType)
}
@ -1974,6 +2086,277 @@ func (a *Agent) RemoveCheck(checkID types.CheckID, persist bool) error {
return nil
}
// AddProxy adds a new local Connect Proxy instance to be managed by the agent.
//
// It REQUIRES that the service that is being proxied is already present in the
// local state. Note that this is only used for agent-managed proxies so we can
// ensure that we always make this true. For externally managed and registered
// proxies we explicitly allow the proxy to be registered first to make
// bootstrap ordering of a new service simpler but the same is not true here
// since this is only ever called when setting up a _managed_ proxy which was
// registered as part of a service registration either from config or HTTP API
// call.
//
// The restoredProxyToken argument should only be used when restoring proxy
// definitions from disk; new proxies must leave it blank to get a new token
// assigned. We need to restore from disk to enable to continue authenticating
// running proxies that already had that credential injected.
func (a *Agent) AddProxy(proxy *structs.ConnectManagedProxy, persist bool,
restoredProxyToken string) error {
// Lookup the target service token in state if there is one.
token := a.State.ServiceToken(proxy.TargetServiceID)
// Copy the basic proxy structure so it isn't modified w/ defaults
proxyCopy := *proxy
proxy = &proxyCopy
if err := a.applyProxyDefaults(proxy); err != nil {
return err
}
// Add the proxy to local state first since we may need to assign a port which
// needs to be coordinate under state lock. AddProxy will generate the
// NodeService for the proxy populated with the allocated (or configured) port
// and an ID, but it doesn't add it to the agent directly since that could
// deadlock and we may need to coordinate adding it and persisting etc.
proxyState, err := a.State.AddProxy(proxy, token, restoredProxyToken)
if err != nil {
return err
}
proxyService := proxyState.Proxy.ProxyService
// Register proxy TCP check. The built in proxy doesn't listen publically
// until it's loaded certs so this ensures we won't route traffic until it's
// ready.
proxyCfg, err := a.applyProxyConfigDefaults(proxyState.Proxy)
if err != nil {
return err
}
chkTypes := []*structs.CheckType{
&structs.CheckType{
Name: "Connect Proxy Listening",
TCP: fmt.Sprintf("%s:%d", proxyCfg["bind_address"],
proxyCfg["bind_port"]),
Interval: 10 * time.Second,
},
}
err = a.AddService(proxyService, chkTypes, persist, token)
if err != nil {
// Remove the state too
a.State.RemoveProxy(proxyService.ID)
return err
}
// Persist the proxy
if persist && a.config.DataDir != "" {
return a.persistProxy(proxyState)
}
return nil
}
// applyProxyConfigDefaults takes a *structs.ConnectManagedProxy and returns
// it's Config map merged with any defaults from the Agent's config. It would be
// nicer if this were defined as a method on structs.ConnectManagedProxy but we
// can't do that because ot the import cycle it causes with agent/config.
func (a *Agent) applyProxyConfigDefaults(p *structs.ConnectManagedProxy) (map[string]interface{}, error) {
if p == nil || p.ProxyService == nil {
// Should never happen but protect from panic
return nil, fmt.Errorf("invalid proxy state")
}
// Lookup the target service
target := a.State.Service(p.TargetServiceID)
if target == nil {
// Can happen during deregistration race between proxy and scheduler.
return nil, fmt.Errorf("unknown target service ID: %s", p.TargetServiceID)
}
// Merge globals defaults
config := make(map[string]interface{})
for k, v := range a.config.ConnectProxyDefaultConfig {
if _, ok := config[k]; !ok {
config[k] = v
}
}
// Copy config from the proxy
for k, v := range p.Config {
config[k] = v
}
// Set defaults for anything that is still not specified but required.
// Note that these are not included in the content hash. Since we expect
// them to be static in general but some like the default target service
// port might not be. In that edge case services can set that explicitly
// when they re-register which will be caught though.
if _, ok := config["bind_port"]; !ok {
config["bind_port"] = p.ProxyService.Port
}
if _, ok := config["bind_address"]; !ok {
// Default to binding to the same address the agent is configured to
// bind to.
config["bind_address"] = a.config.BindAddr.String()
}
if _, ok := config["local_service_address"]; !ok {
// Default to localhost and the port the service registered with
config["local_service_address"] = fmt.Sprintf("127.0.0.1:%d", target.Port)
}
// Basic type conversions for expected types.
if raw, ok := config["bind_port"]; ok {
switch v := raw.(type) {
case float64:
// Common since HCL/JSON parse as float64
config["bind_port"] = int(v)
// NOTE(mitchellh): No default case since errors and validation
// are handled by the ServiceDefinition.Validate function.
}
}
return config, nil
}
// applyProxyDefaults modifies the given proxy by applying any configured
// defaults, such as the default execution mode, command, etc.
func (a *Agent) applyProxyDefaults(proxy *structs.ConnectManagedProxy) error {
// Set the default exec mode
if proxy.ExecMode == structs.ProxyExecModeUnspecified {
mode, err := structs.NewProxyExecMode(a.config.ConnectProxyDefaultExecMode)
if err != nil {
return err
}
proxy.ExecMode = mode
}
if proxy.ExecMode == structs.ProxyExecModeUnspecified {
proxy.ExecMode = structs.ProxyExecModeDaemon
}
// Set the default command to the globally configured default
if len(proxy.Command) == 0 {
switch proxy.ExecMode {
case structs.ProxyExecModeDaemon:
proxy.Command = a.config.ConnectProxyDefaultDaemonCommand
case structs.ProxyExecModeScript:
proxy.Command = a.config.ConnectProxyDefaultScriptCommand
}
}
// If there is no globally configured default we need to get the
// default command so we can do "consul connect proxy"
if len(proxy.Command) == 0 {
command, err := defaultProxyCommand(a.config)
if err != nil {
return err
}
proxy.Command = command
}
return nil
}
// RemoveProxy stops and removes a local proxy instance.
func (a *Agent) RemoveProxy(proxyID string, persist bool) error {
// Validate proxyID
if proxyID == "" {
return fmt.Errorf("proxyID missing")
}
// Remove the proxy from the local state
p, err := a.State.RemoveProxy(proxyID)
if err != nil {
return err
}
// Remove the proxy service as well. The proxy ID is also the ID
// of the servie, but we might as well use the service pointer.
if err := a.RemoveService(p.Proxy.ProxyService.ID, persist); err != nil {
return err
}
if persist && a.config.DataDir != "" {
return a.purgeProxy(proxyID)
}
return nil
}
// verifyProxyToken takes a token and attempts to verify it against the
// targetService name. If targetProxy is specified, then the local proxy token
// must exactly match the given proxy ID. cert, config, etc.).
//
// The given token may be a local-only proxy token or it may be an ACL token. We
// will attempt to verify the local proxy token first.
//
// The effective ACL token is returned along with a boolean which is true if the
// match was against a proxy token rather than an ACL token, and any error. In
// the case the token matches a proxy token, then the ACL token used to register
// that proxy's target service is returned for use in any RPC calls the proxy
// needs to make on behalf of that service. If the token was an ACL token
// already then it is always returned. Provided error is nil, a valid ACL token
// is always returned.
func (a *Agent) verifyProxyToken(token, targetService,
targetProxy string) (string, bool, error) {
// If we specify a target proxy, we look up that proxy directly. Otherwise,
// we resolve with any proxy we can find.
var proxy *local.ManagedProxy
if targetProxy != "" {
proxy = a.State.Proxy(targetProxy)
if proxy == nil {
return "", false, fmt.Errorf("unknown proxy service ID: %q", targetProxy)
}
// If the token DOESN'T match, then we reset the proxy which will
// cause the logic below to fall back to normal ACLs. Otherwise,
// we keep the proxy set because we also have to verify that the
// target service matches on the proxy.
if token != proxy.ProxyToken {
proxy = nil
}
} else {
proxy = a.resolveProxyToken(token)
}
// The existence of a token isn't enough, we also need to verify
// that the service name of the matching proxy matches our target
// service.
if proxy != nil {
// Get the target service since we only have the name. The nil
// check below should never be true since a proxy token always
// represents the existence of a local service.
target := a.State.Service(proxy.Proxy.TargetServiceID)
if target == nil {
return "", false, fmt.Errorf("proxy target service not found: %q",
proxy.Proxy.TargetServiceID)
}
if target.Service != targetService {
return "", false, acl.ErrPermissionDenied
}
// Resolve the actual ACL token used to register the proxy/service and
// return that for use in RPC calls.
return a.State.ServiceToken(proxy.Proxy.TargetServiceID), true, nil
}
// Doesn't match, we have to do a full token resolution. The required
// permission for any proxy-related endpoint is service:write, since
// to register a proxy you require that permission and sensitive data
// is usually present in the configuration.
rule, err := a.resolveToken(token)
if err != nil {
return "", false, err
}
if rule != nil && !rule.ServiceWrite(targetService, nil) {
return "", false, acl.ErrPermissionDenied
}
return token, false, nil
}
func (a *Agent) cancelCheckMonitors(checkID types.CheckID) {
// Stop any monitors
delete(a.checkReapAfter, checkID)
@ -2018,7 +2401,7 @@ func (a *Agent) updateTTLCheck(checkID types.CheckID, status, output string) err
check.SetStatus(status, output)
// We don't write any files in dev mode so bail here.
if a.config.DevMode {
if a.config.DataDir == "" {
return nil
}
@ -2367,6 +2750,96 @@ func (a *Agent) unloadChecks() error {
return nil
}
// loadProxies will load connect proxy definitions from configuration and
// persisted definitions on disk, and load them into the local agent.
func (a *Agent) loadProxies(conf *config.RuntimeConfig) error {
for _, svc := range conf.Services {
if svc.Connect != nil {
proxy, err := svc.ConnectManagedProxy()
if err != nil {
return fmt.Errorf("failed adding proxy: %s", err)
}
if proxy == nil {
continue
}
if err := a.AddProxy(proxy, false, ""); err != nil {
return fmt.Errorf("failed adding proxy: %s", err)
}
}
}
// Load any persisted proxies
proxyDir := filepath.Join(a.config.DataDir, proxyDir)
files, err := ioutil.ReadDir(proxyDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("Failed reading proxies dir %q: %s", proxyDir, err)
}
for _, fi := range files {
// Skip all dirs
if fi.IsDir() {
continue
}
// Skip all partially written temporary files
if strings.HasSuffix(fi.Name(), "tmp") {
a.logger.Printf("[WARN] agent: Ignoring temporary proxy file %v", fi.Name())
continue
}
// Open the file for reading
file := filepath.Join(proxyDir, fi.Name())
fh, err := os.Open(file)
if err != nil {
return fmt.Errorf("failed opening proxy file %q: %s", file, err)
}
// Read the contents into a buffer
buf, err := ioutil.ReadAll(fh)
fh.Close()
if err != nil {
return fmt.Errorf("failed reading proxy file %q: %s", file, err)
}
// Try decoding the proxy definition
var p persistedProxy
if err := json.Unmarshal(buf, &p); err != nil {
a.logger.Printf("[ERR] agent: Failed decoding proxy file %q: %s", file, err)
continue
}
proxyID := p.Proxy.ProxyService.ID
if a.State.Proxy(proxyID) != nil {
// Purge previously persisted proxy. This allows config to be preferred
// over services persisted from the API.
a.logger.Printf("[DEBUG] agent: proxy %q exists, not restoring from %q",
proxyID, file)
if err := a.purgeProxy(proxyID); err != nil {
return fmt.Errorf("failed purging proxy %q: %s", proxyID, err)
}
} else {
a.logger.Printf("[DEBUG] agent: restored proxy definition %q from %q",
proxyID, file)
if err := a.AddProxy(p.Proxy, false, p.ProxyToken); err != nil {
return fmt.Errorf("failed adding proxy %q: %s", proxyID, err)
}
}
}
return nil
}
// unloadProxies will deregister all proxies known to the local agent.
func (a *Agent) unloadProxies() error {
for id := range a.State.Proxies() {
if err := a.RemoveProxy(id, false); err != nil {
return fmt.Errorf("Failed deregistering proxy '%s': %s", id, err)
}
}
return nil
}
// snapshotCheckState is used to snapshot the current state of the health
// checks. This is done before we reload our checks, so that we can properly
// restore into the same state.
@ -2508,6 +2981,9 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
// First unload all checks, services, and metadata. This lets us begin the reload
// with a clean slate.
if err := a.unloadProxies(); err != nil {
return fmt.Errorf("Failed unloading proxies: %s", err)
}
if err := a.unloadServices(); err != nil {
return fmt.Errorf("Failed unloading services: %s", err)
}
@ -2520,6 +2996,9 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
if err := a.loadServices(newCfg); err != nil {
return fmt.Errorf("Failed reloading services: %s", err)
}
if err := a.loadProxies(newCfg); err != nil {
return fmt.Errorf("Failed reloading proxies: %s", err)
}
if err := a.loadChecks(newCfg); err != nil {
return fmt.Errorf("Failed reloading checks: %s", err)
}
@ -2544,9 +3023,62 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
}
// Update filtered metrics
metrics.UpdateFilter(newCfg.TelemetryAllowedPrefixes, newCfg.TelemetryBlockedPrefixes)
metrics.UpdateFilter(newCfg.Telemetry.AllowedPrefixes,
newCfg.Telemetry.BlockedPrefixes)
a.State.SetDiscardCheckOutput(newCfg.DiscardCheckOutput)
return nil
}
// registerCache configures the cache and registers all the supported
// types onto the cache. This is NOT safe to call multiple times so
// care should be taken to call this exactly once after the cache
// field has been initialized.
func (a *Agent) registerCache() {
a.cache.RegisterType(cachetype.ConnectCARootName, &cachetype.ConnectCARoot{
RPC: a.delegate,
}, &cache.RegisterOptions{
// Maintain a blocking query, retry dropped connections quickly
Refresh: true,
RefreshTimer: 0 * time.Second,
RefreshTimeout: 10 * time.Minute,
})
a.cache.RegisterType(cachetype.ConnectCALeafName, &cachetype.ConnectCALeaf{
RPC: a.delegate,
Cache: a.cache,
}, &cache.RegisterOptions{
// Maintain a blocking query, retry dropped connections quickly
Refresh: true,
RefreshTimer: 0 * time.Second,
RefreshTimeout: 10 * time.Minute,
})
a.cache.RegisterType(cachetype.IntentionMatchName, &cachetype.IntentionMatch{
RPC: a.delegate,
}, &cache.RegisterOptions{
// Maintain a blocking query, retry dropped connections quickly
Refresh: true,
RefreshTimer: 0 * time.Second,
RefreshTimeout: 10 * time.Minute,
})
}
// defaultProxyCommand returns the default Connect managed proxy command.
func defaultProxyCommand(agentCfg *config.RuntimeConfig) ([]string, error) {
// Get the path to the current exectuable. This is cached once by the
// library so this is effectively just a variable read.
execPath, err := os.Executable()
if err != nil {
return nil, err
}
// "consul connect proxy" default value for managed daemon proxy
cmd := []string{execPath, "connect", "proxy"}
if agentCfg != nil && agentCfg.LogLevel != "INFO" {
cmd = append(cmd, "-log-level", agentCfg.LogLevel)
}
return cmd, nil
}

View File

@ -4,12 +4,21 @@ import (
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/mitchellh/mapstructure"
"github.com/hashicorp/go-memdb"
"github.com/mitchellh/hashstructure"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/checks"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/ipaddr"
@ -97,7 +106,7 @@ func (s *HTTPServer) AgentMetrics(resp http.ResponseWriter, req *http.Request) (
return nil, acl.ErrPermissionDenied
}
if enablePrometheusOutput(req) {
if s.agent.config.TelemetryPrometheusRetentionTime < 1 {
if s.agent.config.Telemetry.PrometheusRetentionTime < 1 {
resp.WriteHeader(http.StatusUnsupportedMediaType)
fmt.Fprint(resp, "Prometheus is not enable since its retention time is not positive")
return nil, nil
@ -153,25 +162,49 @@ func (s *HTTPServer) AgentServices(resp http.ResponseWriter, req *http.Request)
return nil, err
}
proxies := s.agent.State.Proxies()
// Convert into api.AgentService since that includes Connect config but so far
// NodeService doesn't need to internally. They are otherwise identical since
// that is the struct used in client for reading the one we output here
// anyway.
agentSvcs := make(map[string]*api.AgentService)
// Use empty list instead of nil
for id, s := range services {
if s.Tags == nil || s.Meta == nil {
clone := *s
if s.Tags == nil {
clone.Tags = make([]string, 0)
} else {
clone.Tags = s.Tags
as := &api.AgentService{
Kind: api.ServiceKind(s.Kind),
ID: s.ID,
Service: s.Service,
Tags: s.Tags,
Meta: s.Meta,
Port: s.Port,
Address: s.Address,
EnableTagOverride: s.EnableTagOverride,
CreateIndex: s.CreateIndex,
ModifyIndex: s.ModifyIndex,
ProxyDestination: s.ProxyDestination,
}
if s.Meta == nil {
clone.Meta = make(map[string]string)
} else {
clone.Meta = s.Meta
if as.Tags == nil {
as.Tags = []string{}
}
services[id] = &clone
if as.Meta == nil {
as.Meta = map[string]string{}
}
// Attach Connect configs if the exist
if proxy, ok := proxies[id+"-proxy"]; ok {
as.Connect = &api.AgentServiceConnect{
Proxy: &api.AgentServiceConnectProxy{
ExecMode: api.ProxyExecMode(proxy.Proxy.ExecMode.String()),
Command: proxy.Proxy.Command,
Config: proxy.Proxy.Config,
},
}
}
agentSvcs[id] = as
}
return services, nil
return agentSvcs, nil
}
func (s *HTTPServer) AgentChecks(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
@ -554,6 +587,14 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
return nil, nil
}
// Run validation. This is the same validation that would happen on
// the catalog endpoint so it helps ensure the sync will work properly.
if err := ns.Validate(); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, err.Error())
return nil, nil
}
// Verify the check type.
chkTypes, err := args.CheckTypes()
if err != nil {
@ -576,10 +617,30 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
return nil, err
}
// Get any proxy registrations
proxy, err := args.ConnectManagedProxy()
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, err.Error())
return nil, nil
}
// If we have a proxy, verify that we're allowed to add a proxy via the API
if proxy != nil && !s.agent.config.ConnectProxyAllowManagedAPIRegistration {
return nil, &BadRequestError{
Reason: "Managed proxy registration via the API is disallowed."}
}
// Add the service.
if err := s.agent.AddService(ns, chkTypes, true, token); err != nil {
return nil, err
}
// Add proxy (which will add proxy service so do it before we trigger sync)
if proxy != nil {
if err := s.agent.AddProxy(proxy, true, ""); err != nil {
return nil, err
}
}
s.syncChanges()
return nil, nil
}
@ -594,9 +655,27 @@ func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http.
return nil, err
}
// Verify this isn't a proxy
if s.agent.State.Proxy(serviceID) != nil {
return nil, &BadRequestError{
Reason: "Managed proxy service cannot be deregistered directly. " +
"Deregister the service that has a managed proxy to automatically " +
"deregister the managed proxy itself."}
}
if err := s.agent.RemoveService(serviceID, true); err != nil {
return nil, err
}
// Remove the associated managed proxy if it exists
for proxyID, p := range s.agent.State.Proxies() {
if p.Proxy.TargetServiceID == serviceID {
if err := s.agent.RemoveProxy(proxyID, true); err != nil {
return nil, err
}
}
}
s.syncChanges()
return nil, nil
}
@ -828,3 +907,394 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in
s.agent.logger.Printf("[INFO] agent: Updated agent's ACL token %q", target)
return nil, nil
}
// AgentConnectCARoots returns the trusted CA roots.
func (s *HTTPServer) AgentConnectCARoots(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var args structs.DCSpecificRequest
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
raw, m, err := s.agent.cache.Get(cachetype.ConnectCARootName, &args)
if err != nil {
return nil, err
}
defer setCacheMeta(resp, &m)
// Add cache hit
reply, ok := raw.(*structs.IndexedCARoots)
if !ok {
// This should never happen, but we want to protect against panics
return nil, fmt.Errorf("internal error: response type not correct")
}
defer setMeta(resp, &reply.QueryMeta)
return *reply, nil
}
// AgentConnectCALeafCert returns the certificate bundle for a service
// instance. This supports blocking queries to update the returned bundle.
func (s *HTTPServer) AgentConnectCALeafCert(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Get the service name. Note that this is the name of the sevice,
// not the ID of the service instance.
serviceName := strings.TrimPrefix(req.URL.Path, "/v1/agent/connect/ca/leaf/")
args := cachetype.ConnectCALeafRequest{
Service: serviceName, // Need name not ID
}
var qOpts structs.QueryOptions
// Store DC in the ConnectCALeafRequest but query opts separately
if done := s.parse(resp, req, &args.Datacenter, &qOpts); done {
return nil, nil
}
args.MinQueryIndex = qOpts.MinQueryIndex
// Verify the proxy token. This will check both the local proxy token
// as well as the ACL if the token isn't local.
effectiveToken, _, err := s.agent.verifyProxyToken(qOpts.Token, serviceName, "")
if err != nil {
return nil, err
}
args.Token = effectiveToken
raw, m, err := s.agent.cache.Get(cachetype.ConnectCALeafName, &args)
if err != nil {
return nil, err
}
defer setCacheMeta(resp, &m)
reply, ok := raw.(*structs.IssuedCert)
if !ok {
// This should never happen, but we want to protect against panics
return nil, fmt.Errorf("internal error: response type not correct")
}
setIndex(resp, reply.ModifyIndex)
return reply, nil
}
// GET /v1/agent/connect/proxy/:proxy_service_id
//
// Returns the local proxy config for the identified proxy. Requires token=
// param with the correct local ProxyToken (not ACL token).
func (s *HTTPServer) AgentConnectProxyConfig(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Get the proxy ID. Note that this is the ID of a proxy's service instance.
id := strings.TrimPrefix(req.URL.Path, "/v1/agent/connect/proxy/")
// Maybe block
var queryOpts structs.QueryOptions
if parseWait(resp, req, &queryOpts) {
// parseWait returns an error itself
return nil, nil
}
// Parse the token
var token string
s.parseToken(req, &token)
// Parse hash specially since it's only this endpoint that uses it currently.
// Eventually this should happen in parseWait and end up in QueryOptions but I
// didn't want to make very general changes right away.
hash := req.URL.Query().Get("hash")
return s.agentLocalBlockingQuery(resp, hash, &queryOpts,
func(ws memdb.WatchSet) (string, interface{}, error) {
// Retrieve the proxy specified
proxy := s.agent.State.Proxy(id)
if proxy == nil {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp, "unknown proxy service ID: %s", id)
return "", nil, nil
}
// Lookup the target service as a convenience
target := s.agent.State.Service(proxy.Proxy.TargetServiceID)
if target == nil {
// Not found since this endpoint is only useful for agent-managed proxies so
// service missing means the service was deregistered racily with this call.
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp, "unknown target service ID: %s", proxy.Proxy.TargetServiceID)
return "", nil, nil
}
// Validate the ACL token
_, isProxyToken, err := s.agent.verifyProxyToken(token, target.Service, id)
if err != nil {
return "", nil, err
}
// Watch the proxy for changes
ws.Add(proxy.WatchCh)
hash, err := hashstructure.Hash(proxy.Proxy, nil)
if err != nil {
return "", nil, err
}
contentHash := fmt.Sprintf("%x", hash)
// Set defaults
config, err := s.agent.applyProxyConfigDefaults(proxy.Proxy)
if err != nil {
return "", nil, err
}
// Only merge in telemetry config from agent if the requested is
// authorized with a proxy token. This prevents us leaking potentially
// sensitive config like Circonus API token via a public endpoint. Proxy
// tokens are only ever generated in-memory and passed via ENV to a child
// proxy process so potential for abuse here seems small. This endpoint in
// general is only useful for managed proxies now so it should _always_ be
// true that auth is via a proxy token but inconvenient for testing if we
// lock it down so strictly.
if isProxyToken {
// Add telemetry config. Copy the global config so we can customize the
// prefix.
telemetryCfg := s.agent.config.Telemetry
telemetryCfg.MetricsPrefix = telemetryCfg.MetricsPrefix + ".proxy." + target.ID
// First see if the user has specified telemetry
if userRaw, ok := config["telemetry"]; ok {
// User specified domething, see if it is compatible with agent
// telemetry config:
var uCfg lib.TelemetryConfig
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: &uCfg,
// Make sure that if the user passes something that isn't just a
// simple override of a valid TelemetryConfig that we fail so that we
// don't clobber their custom config.
ErrorUnused: true,
})
if err == nil {
if err = dec.Decode(userRaw); err == nil {
// It did decode! Merge any unspecified fields from agent config.
uCfg.MergeDefaults(&telemetryCfg)
config["telemetry"] = uCfg
}
}
// Failed to decode, just keep user's config["telemetry"] verbatim
// with no agent merge.
} else {
// Add agent telemetry config.
config["telemetry"] = telemetryCfg
}
}
reply := &api.ConnectProxyConfig{
ProxyServiceID: proxy.Proxy.ProxyService.ID,
TargetServiceID: target.ID,
TargetServiceName: target.Service,
ContentHash: contentHash,
ExecMode: api.ProxyExecMode(proxy.Proxy.ExecMode.String()),
Command: proxy.Proxy.Command,
Config: config,
}
return contentHash, reply, nil
})
}
type agentLocalBlockingFunc func(ws memdb.WatchSet) (string, interface{}, error)
// agentLocalBlockingQuery performs a blocking query in a generic way against
// local agent state that has no RPC or raft to back it. It uses `hash` paramter
// instead of an `index`. The resp is needed to write the `X-Consul-ContentHash`
// header back on return no Status nor body content is ever written to it.
func (s *HTTPServer) agentLocalBlockingQuery(resp http.ResponseWriter, hash string,
queryOpts *structs.QueryOptions, fn agentLocalBlockingFunc) (interface{}, error) {
// If we are not blocking we can skip tracking and allocating - nil WatchSet
// is still valid to call Add on and will just be a no op.
var ws memdb.WatchSet
var timeout *time.Timer
if hash != "" {
// TODO(banks) at least define these defaults somewhere in a const. Would be
// nice not to duplicate the ones in consul/rpc.go too...
wait := queryOpts.MaxQueryTime
if wait == 0 {
wait = 5 * time.Minute
}
if wait > 10*time.Minute {
wait = 10 * time.Minute
}
// Apply a small amount of jitter to the request.
wait += lib.RandomStagger(wait / 16)
timeout = time.NewTimer(wait)
}
for {
// Must reset this every loop in case the Watch set is already closed but
// hash remains same. In that case we'll need to re-block on ws.Watch()
// again.
ws = memdb.NewWatchSet()
curHash, curResp, err := fn(ws)
if err != nil {
return curResp, err
}
// Return immediately if there is no timeout, the hash is different or the
// Watch returns true (indicating timeout fired). Note that Watch on a nil
// WatchSet immediately returns false which would incorrectly cause this to
// loop and repeat again, however we rely on the invariant that ws == nil
// IFF timeout == nil in which case the Watch call is never invoked.
if timeout == nil || hash != curHash || ws.Watch(timeout.C) {
resp.Header().Set("X-Consul-ContentHash", curHash)
return curResp, err
}
// Watch returned false indicating a change was detected, loop and repeat
// the callback to load the new value.
}
}
// AgentConnectAuthorize
//
// POST /v1/agent/connect/authorize
//
// Note: when this logic changes, consider if the Intention.Check RPC method
// also needs to be updated.
func (s *HTTPServer) AgentConnectAuthorize(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Fetch the token
var token string
s.parseToken(req, &token)
// Decode the request from the request body
var authReq structs.ConnectAuthorizeRequest
if err := decodeBody(req, &authReq, nil); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
}
// We need to have a target to check intentions
if authReq.Target == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Target service must be specified")
return nil, nil
}
// Parse the certificate URI from the client ID
uriRaw, err := url.Parse(authReq.ClientCertURI)
if err != nil {
return &connectAuthorizeResp{
Authorized: false,
Reason: fmt.Sprintf("Client ID must be a URI: %s", err),
}, nil
}
uri, err := connect.ParseCertURI(uriRaw)
if err != nil {
return &connectAuthorizeResp{
Authorized: false,
Reason: fmt.Sprintf("Invalid client ID: %s", err),
}, nil
}
uriService, ok := uri.(*connect.SpiffeIDService)
if !ok {
return &connectAuthorizeResp{
Authorized: false,
Reason: "Client ID must be a valid SPIFFE service URI",
}, nil
}
// We need to verify service:write permissions for the given token.
// We do this manually here since the RPC request below only verifies
// service:read.
rule, err := s.agent.resolveToken(token)
if err != nil {
return nil, err
}
if rule != nil && !rule.ServiceWrite(authReq.Target, nil) {
return nil, acl.ErrPermissionDenied
}
// Validate the trust domain matches ours. Later we will support explicit
// external federation but not built yet.
rootArgs := &structs.DCSpecificRequest{Datacenter: s.agent.config.Datacenter}
raw, _, err := s.agent.cache.Get(cachetype.ConnectCARootName, rootArgs)
if err != nil {
return nil, err
}
roots, ok := raw.(*structs.IndexedCARoots)
if !ok {
return nil, fmt.Errorf("internal error: roots response type not correct")
}
if roots.TrustDomain == "" {
return nil, fmt.Errorf("connect CA not bootstrapped yet")
}
if roots.TrustDomain != strings.ToLower(uriService.Host) {
return &connectAuthorizeResp{
Authorized: false,
Reason: fmt.Sprintf("Identity from an external trust domain: %s",
uriService.Host),
}, nil
}
// TODO(banks): Implement revocation list checking here.
// Get the intentions for this target service.
args := &structs.IntentionQueryRequest{
Datacenter: s.agent.config.Datacenter,
Match: &structs.IntentionQueryMatch{
Type: structs.IntentionMatchDestination,
Entries: []structs.IntentionMatchEntry{
{
Namespace: structs.IntentionDefaultNamespace,
Name: authReq.Target,
},
},
},
}
args.Token = token
raw, m, err := s.agent.cache.Get(cachetype.IntentionMatchName, args)
if err != nil {
return nil, err
}
setCacheMeta(resp, &m)
reply, ok := raw.(*structs.IndexedIntentionMatches)
if !ok {
return nil, fmt.Errorf("internal error: response type not correct")
}
if len(reply.Matches) != 1 {
return nil, fmt.Errorf("Internal error loading matches")
}
// Test the authorization for each match
for _, ixn := range reply.Matches[0] {
if auth, ok := uriService.Authorize(ixn); ok {
return &connectAuthorizeResp{
Authorized: auth,
Reason: fmt.Sprintf("Matched intention: %s", ixn.String()),
}, nil
}
}
// No match, we need to determine the default behavior. We do this by
// specifying the anonymous token token, which will get that behavior.
// The default behavior if ACLs are disabled is to allow connections
// to mimic the behavior of Consul itself: everything is allowed if
// ACLs are disabled.
rule, err = s.agent.resolveToken("")
if err != nil {
return nil, err
}
authz := true
reason := "ACLs disabled, access is allowed by default"
if rule != nil {
authz = rule.IntentionDefaultAllow()
reason = "Default behavior configured by ACLs"
}
return &connectAuthorizeResp{
Authorized: authz,
Reason: reason,
}, nil
}
// connectAuthorizeResp is the response format/structure for the
// /v1/agent/connect/authorize endpoint.
type connectAuthorizeResp struct {
Authorized bool // True if authorized, false if not
Reason string // Reason for the Authorized value (whether true or false)
}

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,7 @@ import (
"time"
"github.com/hashicorp/consul/agent/checks"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/testutil"
@ -23,6 +24,8 @@ import (
"github.com/hashicorp/consul/types"
uuid "github.com/hashicorp/go-uuid"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func externalIP() (string, error) {
@ -51,10 +54,62 @@ func TestAgent_MultiStartStop(t *testing.T) {
}
}
func TestAgent_ConnectClusterIDConfig(t *testing.T) {
tests := []struct {
name string
hcl string
wantClusterID string
wantPanic bool
}{
{
name: "default TestAgent has fixed cluster id",
hcl: "",
wantClusterID: connect.TestClusterID,
},
{
name: "no cluster ID specified sets to test ID",
hcl: "connect { enabled = true }",
wantClusterID: connect.TestClusterID,
},
{
name: "non-UUID cluster_id is fatal",
hcl: `connect {
enabled = true
ca_config {
cluster_id = "fake-id"
}
}`,
wantClusterID: "",
wantPanic: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Indirection to support panic recovery cleanly
testFn := func() {
a := &TestAgent{Name: "test", HCL: tt.hcl}
a.ExpectConfigError = tt.wantPanic
a.Start()
defer a.Shutdown()
cfg := a.consulConfig()
assert.Equal(t, tt.wantClusterID, cfg.CAConfig.ClusterID)
}
if tt.wantPanic {
require.Panics(t, testFn)
} else {
testFn()
}
})
}
}
func TestAgent_StartStop(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
// defer a.Shutdown()
defer a.Shutdown()
if err := a.Leave(); err != nil {
t.Fatalf("err: %v", err)
@ -1294,6 +1349,187 @@ func TestAgent_PurgeServiceOnDuplicate(t *testing.T) {
}
}
func TestAgent_PersistProxy(t *testing.T) {
t.Parallel()
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
cfg := `
server = false
bootstrap = false
data_dir = "` + dataDir + `"
`
a := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
a.Start()
defer os.RemoveAll(dataDir)
defer a.Shutdown()
require := require.New(t)
assert := assert.New(t)
// Add a service to proxy (precondition for AddProxy)
svc1 := &structs.NodeService{
ID: "redis",
Service: "redis",
Tags: []string{"foo"},
Port: 8000,
}
require.NoError(a.AddService(svc1, nil, true, ""))
// Add a proxy for it
proxy := &structs.ConnectManagedProxy{
TargetServiceID: svc1.ID,
Command: []string{"/bin/sleep", "3600"},
}
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash("redis-proxy"))
// Proxy is not persisted unless requested
require.NoError(a.AddProxy(proxy, false, ""))
_, err := os.Stat(file)
require.Error(err, "proxy should not be persisted")
// Proxy is persisted if requested
require.NoError(a.AddProxy(proxy, true, ""))
_, err = os.Stat(file)
require.NoError(err, "proxy should be persisted")
content, err := ioutil.ReadFile(file)
require.NoError(err)
var gotProxy persistedProxy
require.NoError(json.Unmarshal(content, &gotProxy))
assert.Equal(proxy.Command, gotProxy.Proxy.Command)
assert.Len(gotProxy.ProxyToken, 36) // sanity check for UUID
// Updates service definition on disk
proxy.Config = map[string]interface{}{
"foo": "bar",
}
require.NoError(a.AddProxy(proxy, true, ""))
content, err = ioutil.ReadFile(file)
require.NoError(err)
require.NoError(json.Unmarshal(content, &gotProxy))
assert.Equal(gotProxy.Proxy.Command, proxy.Command)
assert.Equal(gotProxy.Proxy.Config, proxy.Config)
assert.Len(gotProxy.ProxyToken, 36) // sanity check for UUID
a.Shutdown()
// Should load it back during later start
a2 := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
a2.Start()
defer a2.Shutdown()
restored := a2.State.Proxy("redis-proxy")
require.NotNil(restored)
assert.Equal(gotProxy.ProxyToken, restored.ProxyToken)
// Ensure the port that was auto picked at random is the same again
assert.Equal(gotProxy.Proxy.ProxyService.Port, restored.Proxy.ProxyService.Port)
assert.Equal(gotProxy.Proxy.Command, restored.Proxy.Command)
}
func TestAgent_PurgeProxy(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
require := require.New(t)
// Add a service to proxy (precondition for AddProxy)
svc1 := &structs.NodeService{
ID: "redis",
Service: "redis",
Tags: []string{"foo"},
Port: 8000,
}
require.NoError(a.AddService(svc1, nil, true, ""))
// Add a proxy for it
proxy := &structs.ConnectManagedProxy{
TargetServiceID: svc1.ID,
Command: []string{"/bin/sleep", "3600"},
}
proxyID := "redis-proxy"
require.NoError(a.AddProxy(proxy, true, ""))
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash("redis-proxy"))
// Not removed
require.NoError(a.RemoveProxy(proxyID, false))
_, err := os.Stat(file)
require.NoError(err, "should not be removed")
// Re-add the proxy
require.NoError(a.AddProxy(proxy, true, ""))
// Removed
require.NoError(a.RemoveProxy(proxyID, true))
_, err = os.Stat(file)
require.Error(err, "should be removed")
}
func TestAgent_PurgeProxyOnDuplicate(t *testing.T) {
t.Parallel()
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
cfg := `
data_dir = "` + dataDir + `"
server = false
bootstrap = false
`
a := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
a.Start()
defer a.Shutdown()
defer os.RemoveAll(dataDir)
require := require.New(t)
// Add a service to proxy (precondition for AddProxy)
svc1 := &structs.NodeService{
ID: "redis",
Service: "redis",
Tags: []string{"foo"},
Port: 8000,
}
require.NoError(a.AddService(svc1, nil, true, ""))
// Add a proxy for it
proxy := &structs.ConnectManagedProxy{
TargetServiceID: svc1.ID,
Command: []string{"/bin/sleep", "3600"},
}
proxyID := "redis-proxy"
require.NoError(a.AddProxy(proxy, true, ""))
a.Shutdown()
// Try bringing the agent back up with the service already
// existing in the config
a2 := &TestAgent{Name: t.Name() + "-a2", HCL: cfg + `
service = {
id = "redis"
name = "redis"
tags = ["bar"]
port = 9000
connect {
proxy {
command = ["/bin/sleep", "3600"]
}
}
}
`, DataDir: dataDir}
a2.Start()
defer a2.Shutdown()
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash(proxyID))
_, err := os.Stat(file)
require.Error(err, "should have removed remote state")
result := a2.State.Proxy(proxyID)
require.NotNil(result)
require.Equal(proxy.Command, result.Proxy.Command)
}
func TestAgent_PersistCheck(t *testing.T) {
t.Parallel()
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
@ -1629,6 +1865,96 @@ func TestAgent_unloadServices(t *testing.T) {
}
}
func TestAgent_loadProxies(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
service = {
id = "rabbitmq"
name = "rabbitmq"
port = 5672
token = "abc123"
connect {
proxy {
config {
bind_port = 1234
}
}
}
}
`)
defer a.Shutdown()
services := a.State.Services()
if _, ok := services["rabbitmq"]; !ok {
t.Fatalf("missing service")
}
if token := a.State.ServiceToken("rabbitmq"); token != "abc123" {
t.Fatalf("bad: %s", token)
}
if _, ok := services["rabbitmq-proxy"]; !ok {
t.Fatalf("missing proxy service")
}
if token := a.State.ServiceToken("rabbitmq-proxy"); token != "abc123" {
t.Fatalf("bad: %s", token)
}
proxies := a.State.Proxies()
if _, ok := proxies["rabbitmq-proxy"]; !ok {
t.Fatalf("missing proxy")
}
}
func TestAgent_loadProxies_nilProxy(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
service = {
id = "rabbitmq"
name = "rabbitmq"
port = 5672
token = "abc123"
connect {
}
}
`)
defer a.Shutdown()
services := a.State.Services()
require.Contains(t, services, "rabbitmq")
require.Equal(t, "abc123", a.State.ServiceToken("rabbitmq"))
require.NotContains(t, services, "rabbitme-proxy")
require.Empty(t, a.State.Proxies())
}
func TestAgent_unloadProxies(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
service = {
id = "rabbitmq"
name = "rabbitmq"
port = 5672
token = "abc123"
connect {
proxy {
config {
bind_port = 1234
}
}
}
}
`)
defer a.Shutdown()
// Sanity check it's there
require.NotNil(t, a.State.Proxy("rabbitmq-proxy"))
// Unload all proxies
if err := a.unloadProxies(); err != nil {
t.Fatalf("err: %s", err)
}
if len(a.State.Proxies()) != 0 {
t.Fatalf("should have unloaded proxies")
}
}
func TestAgent_Service_MaintenanceMode(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
@ -2179,6 +2505,18 @@ func TestAgent_reloadWatches(t *testing.T) {
t.Fatalf("bad: %s", err)
}
// Should fail to reload with connect watches
newConf.Watches = []map[string]interface{}{
{
"type": "connect_roots",
"key": "asdf",
"args": []interface{}{"ls"},
},
}
if err := a.reloadWatches(&newConf); err == nil || !strings.Contains(err.Error(), "not allowed in agent config") {
t.Fatalf("bad: %s", err)
}
// Should still succeed with only HTTPS addresses
newConf.HTTPSAddrs = newConf.HTTPAddrs
newConf.HTTPAddrs = make([]net.Addr, 0)
@ -2226,3 +2564,217 @@ func TestAgent_reloadWatchesHTTPS(t *testing.T) {
t.Fatalf("bad: %s", err)
}
}
func TestAgent_AddProxy(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
node_name = "node1"
connect {
proxy_defaults {
exec_mode = "script"
daemon_command = ["foo", "bar"]
script_command = ["bar", "foo"]
}
}
ports {
proxy_min_port = 20000
proxy_max_port = 20000
}
`)
defer a.Shutdown()
// Register a target service we can use
reg := &structs.NodeService{
Service: "web",
Port: 8080,
}
require.NoError(t, a.AddService(reg, nil, false, ""))
tests := []struct {
desc string
proxy, wantProxy *structs.ConnectManagedProxy
wantTCPCheck string
wantErr bool
}{
{
desc: "basic proxy adding, unregistered service",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
Config: map[string]interface{}{
"foo": "bar",
},
TargetServiceID: "db", // non-existent service.
},
// Target service must be registered.
wantErr: true,
},
{
desc: "basic proxy adding, registered service",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
Config: map[string]interface{}{
"foo": "bar",
},
TargetServiceID: "web",
},
wantErr: false,
},
{
desc: "default global exec mode",
proxy: &structs.ConnectManagedProxy{
Command: []string{"consul", "connect", "proxy"},
TargetServiceID: "web",
},
wantProxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeScript,
Command: []string{"consul", "connect", "proxy"},
TargetServiceID: "web",
},
wantErr: false,
},
{
desc: "default daemon command",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
TargetServiceID: "web",
},
wantProxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"foo", "bar"},
TargetServiceID: "web",
},
wantErr: false,
},
{
desc: "default script command",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeScript,
TargetServiceID: "web",
},
wantProxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeScript,
Command: []string{"bar", "foo"},
TargetServiceID: "web",
},
wantErr: false,
},
{
desc: "managed proxy with custom bind port",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
Config: map[string]interface{}{
"foo": "bar",
"bind_address": "127.10.10.10",
"bind_port": 1234,
},
TargetServiceID: "web",
},
wantTCPCheck: "127.10.10.10:1234",
wantErr: false,
},
{
// This test is necessary since JSON and HCL both will parse
// numbers as a float64.
desc: "managed proxy with custom bind port (float64)",
proxy: &structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
Config: map[string]interface{}{
"foo": "bar",
"bind_address": "127.10.10.10",
"bind_port": float64(1234),
},
TargetServiceID: "web",
},
wantTCPCheck: "127.10.10.10:1234",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
require := require.New(t)
err := a.AddProxy(tt.proxy, false, "")
if tt.wantErr {
require.Error(err)
return
}
require.NoError(err)
// Test the ID was created as we expect.
got := a.State.Proxy("web-proxy")
wantProxy := tt.wantProxy
if wantProxy == nil {
wantProxy = tt.proxy
}
wantProxy.ProxyService = got.Proxy.ProxyService
require.Equal(wantProxy, got.Proxy)
// Ensure a TCP check was created for the service.
gotCheck := a.State.Check("service:web-proxy")
require.NotNil(gotCheck)
require.Equal("Connect Proxy Listening", gotCheck.Name)
// Confusingly, a.State.Check("service:web-proxy") will return the state
// but it's Definition field will be empty. This appears to be expected
// when adding Checks as part of `AddService`. Notice how `AddService`
// tests in this file don't assert on that state but instead look at the
// agent's check state directly to ensure the right thing was registered.
// We'll do the same for now.
gotTCP, ok := a.checkTCPs["service:web-proxy"]
require.True(ok)
wantTCPCheck := tt.wantTCPCheck
if wantTCPCheck == "" {
wantTCPCheck = "127.0.0.1:20000"
}
require.Equal(wantTCPCheck, gotTCP.TCP)
require.Equal(10*time.Second, gotTCP.Interval)
})
}
}
func TestAgent_RemoveProxy(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
node_name = "node1"
`)
defer a.Shutdown()
require := require.New(t)
// Register a target service we can use
reg := &structs.NodeService{
Service: "web",
Port: 8080,
}
require.NoError(a.AddService(reg, nil, false, ""))
// Add a proxy for web
pReg := &structs.ConnectManagedProxy{
TargetServiceID: "web",
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"foo"},
}
require.NoError(a.AddProxy(pReg, false, ""))
// Test the ID was created as we expect.
gotProxy := a.State.Proxy("web-proxy")
require.NotNil(gotProxy)
err := a.RemoveProxy("web-proxy", false)
require.NoError(err)
gotProxy = a.State.Proxy("web-proxy")
require.Nil(gotProxy)
require.Nil(a.State.Service("web-proxy"), "web-proxy service")
// Removing invalid proxy should be an error
err = a.RemoveProxy("foobar", false)
require.Error(err)
}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,240 @@
package cachetype
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)
// Recommended name for registration.
const ConnectCALeafName = "connect-ca-leaf"
// ConnectCALeaf supports fetching and generating Connect leaf
// certificates.
type ConnectCALeaf struct {
caIndex uint64 // Current index for CA roots
issuedCertsLock sync.RWMutex
issuedCerts map[string]*structs.IssuedCert
RPC RPC // RPC client for remote requests
Cache *cache.Cache // Cache that has CA root certs via ConnectCARoot
}
func (c *ConnectCALeaf) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
var result cache.FetchResult
// Get the correct type
reqReal, ok := req.(*ConnectCALeafRequest)
if !ok {
return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req)
}
// This channel watches our overall timeout. The other goroutines
// launched in this function should end all around the same time so
// they clean themselves up.
timeoutCh := time.After(opts.Timeout)
// Kick off the goroutine that waits for new CA roots. The channel buffer
// is so that the goroutine doesn't block forever if we return for other
// reasons.
newRootCACh := make(chan error, 1)
go c.waitNewRootCA(reqReal.Datacenter, newRootCACh, opts.Timeout)
// Get our prior cert (if we had one) and use that to determine our
// expiration time. If no cert exists, we expire immediately since we
// need to generate.
c.issuedCertsLock.RLock()
lastCert := c.issuedCerts[reqReal.Service]
c.issuedCertsLock.RUnlock()
var leafExpiryCh <-chan time.Time
if lastCert != nil {
// Determine how long we wait until triggering. If we've already
// expired, we trigger immediately.
if expiryDur := lastCert.ValidBefore.Sub(time.Now()); expiryDur > 0 {
leafExpiryCh = time.After(expiryDur - 1*time.Hour)
// TODO(mitchellh): 1 hour buffer is hardcoded above
}
}
if leafExpiryCh == nil {
// If the channel is still nil then it means we need to generate
// a cert no matter what: we either don't have an existing one or
// it is expired.
leafExpiryCh = time.After(0)
}
// Block on the events that wake us up.
select {
case <-timeoutCh:
// On a timeout, we just return the empty result and no error.
// It isn't an error to timeout, its just the limit of time the
// caching system wants us to block for. By returning an empty result
// the caching system will ignore.
return result, nil
case err := <-newRootCACh:
// A new root CA triggers us to refresh the leaf certificate.
// If there was an error while getting the root CA then we return.
// Otherwise, we leave the select statement and move to generation.
if err != nil {
return result, err
}
case <-leafExpiryCh:
// The existing leaf certificate is expiring soon, so we generate a
// new cert with a healthy overlapping validity period (determined
// by the above channel).
}
// Need to lookup RootCAs response to discover trust domain. First just lookup
// with no blocking info - this should be a cache hit most of the time.
rawRoots, _, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{
Datacenter: reqReal.Datacenter,
})
if err != nil {
return result, err
}
roots, ok := rawRoots.(*structs.IndexedCARoots)
if !ok {
return result, errors.New("invalid RootCA response type")
}
if roots.TrustDomain == "" {
return result, errors.New("cluster has no CA bootstrapped")
}
// Build the service ID
serviceID := &connect.SpiffeIDService{
Host: roots.TrustDomain,
Datacenter: reqReal.Datacenter,
Namespace: "default",
Service: reqReal.Service,
}
// Create a new private key
pk, pkPEM, err := connect.GeneratePrivateKey()
if err != nil {
return result, err
}
// Create a CSR.
csr, err := connect.CreateCSR(serviceID, pk)
if err != nil {
return result, err
}
// Request signing
var reply structs.IssuedCert
args := structs.CASignRequest{
WriteRequest: structs.WriteRequest{Token: reqReal.Token},
Datacenter: reqReal.Datacenter,
CSR: csr,
}
if err := c.RPC.RPC("ConnectCA.Sign", &args, &reply); err != nil {
return result, err
}
reply.PrivateKeyPEM = pkPEM
// Lock the issued certs map so we can insert it. We only insert if
// we didn't happen to get a newer one. This should never happen since
// the Cache should ensure only one Fetch per service, but we sanity
// check just in case.
c.issuedCertsLock.Lock()
defer c.issuedCertsLock.Unlock()
lastCert = c.issuedCerts[reqReal.Service]
if lastCert == nil || lastCert.ModifyIndex < reply.ModifyIndex {
if c.issuedCerts == nil {
c.issuedCerts = make(map[string]*structs.IssuedCert)
}
c.issuedCerts[reqReal.Service] = &reply
lastCert = &reply
}
result.Value = lastCert
result.Index = lastCert.ModifyIndex
return result, nil
}
// waitNewRootCA blocks until a new root CA is available or the timeout is
// reached (on timeout ErrTimeout is returned on the channel).
func (c *ConnectCALeaf) waitNewRootCA(datacenter string, ch chan<- error,
timeout time.Duration) {
// We always want to block on at least an initial value. If this isn't
minIndex := atomic.LoadUint64(&c.caIndex)
if minIndex == 0 {
minIndex = 1
}
// Fetch some new roots. This will block until our MinQueryIndex is
// matched or the timeout is reached.
rawRoots, _, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{
Datacenter: datacenter,
QueryOptions: structs.QueryOptions{
MinQueryIndex: minIndex,
MaxQueryTime: timeout,
},
})
if err != nil {
ch <- err
return
}
roots, ok := rawRoots.(*structs.IndexedCARoots)
if !ok {
// This should never happen but we don't want to even risk a panic
ch <- fmt.Errorf(
"internal error: CA root cache returned bad type: %T", rawRoots)
return
}
// We do a loop here because there can be multiple waitNewRootCA calls
// happening simultaneously. Each Fetch kicks off one call. These are
// multiplexed through Cache.Get which should ensure we only ever
// actually make a single RPC call. However, there is a race to set
// the caIndex field so do a basic CAS loop here.
for {
// We only set our index if its newer than what is previously set.
old := atomic.LoadUint64(&c.caIndex)
if old == roots.Index || old > roots.Index {
break
}
// Set the new index atomically. If the caIndex value changed
// in the meantime, retry.
if atomic.CompareAndSwapUint64(&c.caIndex, old, roots.Index) {
break
}
}
// Trigger the channel since we updated.
ch <- nil
}
// ConnectCALeafRequest is the cache.Request implementation for the
// ConnectCALeaf cache type. This is implemented here and not in structs
// since this is only used for cache-related requests and not forwarded
// directly to any Consul servers.
type ConnectCALeafRequest struct {
Token string
Datacenter string
Service string // Service name, not ID
MinQueryIndex uint64
}
func (r *ConnectCALeafRequest) CacheInfo() cache.RequestInfo {
return cache.RequestInfo{
Token: r.Token,
Key: r.Service,
Datacenter: r.Datacenter,
MinIndex: r.MinQueryIndex,
}
}

View File

@ -0,0 +1,209 @@
package cachetype
import (
"fmt"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// Test that after an initial signing, new CA roots (new ID) will
// trigger a blocking query to execute.
func TestConnectCALeaf_changingRoots(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ, rootsCh := testCALeafType(t, rpc)
defer close(rootsCh)
rootsCh <- structs.IndexedCARoots{
ActiveRootID: "1",
TrustDomain: "fake-trust-domain.consul",
QueryMeta: structs.QueryMeta{Index: 1},
}
// Instrument ConnectCA.Sign to return signed cert
var resp *structs.IssuedCert
var idx uint64
rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
reply := args.Get(2).(*structs.IssuedCert)
reply.ValidBefore = time.Now().Add(12 * time.Hour)
reply.CreateIndex = atomic.AddUint64(&idx, 1)
reply.ModifyIndex = reply.CreateIndex
resp = reply
})
// We'll reuse the fetch options and request
opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second}
req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"}
// First fetch should return immediately
fetchCh := TestFetchCh(t, typ, opts, req)
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
require.Equal(cache.FetchResult{
Value: resp,
Index: 1,
}, result)
}
// Second fetch should block with set index
fetchCh = TestFetchCh(t, typ, opts, req)
select {
case result := <-fetchCh:
t.Fatalf("should not return: %#v", result)
case <-time.After(100 * time.Millisecond):
}
// Let's send in new roots, which should trigger the sign req
rootsCh <- structs.IndexedCARoots{
ActiveRootID: "2",
TrustDomain: "fake-trust-domain.consul",
QueryMeta: structs.QueryMeta{Index: 2},
}
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
require.Equal(cache.FetchResult{
Value: resp,
Index: 2,
}, result)
}
// Third fetch should block
fetchCh = TestFetchCh(t, typ, opts, req)
select {
case result := <-fetchCh:
t.Fatalf("should not return: %#v", result)
case <-time.After(100 * time.Millisecond):
}
}
// Test that after an initial signing, an expiringLeaf will trigger a
// blocking query to resign.
func TestConnectCALeaf_expiringLeaf(t *testing.T) {
t.Parallel()
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ, rootsCh := testCALeafType(t, rpc)
defer close(rootsCh)
rootsCh <- structs.IndexedCARoots{
ActiveRootID: "1",
TrustDomain: "fake-trust-domain.consul",
QueryMeta: structs.QueryMeta{Index: 1},
}
// Instrument ConnectCA.Sign to
var resp *structs.IssuedCert
var idx uint64
rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
reply := args.Get(2).(*structs.IssuedCert)
reply.CreateIndex = atomic.AddUint64(&idx, 1)
reply.ModifyIndex = reply.CreateIndex
// This sets the validity to 0 on the first call, and
// 12 hours+ on subsequent calls. This means that our first
// cert expires immediately.
reply.ValidBefore = time.Now().Add((12 * time.Hour) *
time.Duration(reply.CreateIndex-1))
resp = reply
})
// We'll reuse the fetch options and request
opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second}
req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"}
// First fetch should return immediately
fetchCh := TestFetchCh(t, typ, opts, req)
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
require.Equal(cache.FetchResult{
Value: resp,
Index: 1,
}, result)
}
// Second fetch should return immediately despite there being
// no updated CA roots, because we issued an expired cert.
fetchCh = TestFetchCh(t, typ, opts, req)
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("shouldn't block waiting for fetch")
case result := <-fetchCh:
require.Equal(cache.FetchResult{
Value: resp,
Index: 2,
}, result)
}
// Third fetch should block since the cert is not expiring and
// we also didn't update CA certs.
fetchCh = TestFetchCh(t, typ, opts, req)
select {
case result := <-fetchCh:
t.Fatalf("should not return: %#v", result)
case <-time.After(100 * time.Millisecond):
}
}
// testCALeafType returns a *ConnectCALeaf that is pre-configured to
// use the given RPC implementation for "ConnectCA.Sign" operations.
func testCALeafType(t *testing.T, rpc RPC) (*ConnectCALeaf, chan structs.IndexedCARoots) {
// This creates an RPC implementation that will block until the
// value is sent on the channel. This lets us control when the
// next values show up.
rootsCh := make(chan structs.IndexedCARoots, 10)
rootsRPC := &testGatedRootsRPC{ValueCh: rootsCh}
// Create a cache
c := cache.TestCache(t)
c.RegisterType(ConnectCARootName, &ConnectCARoot{RPC: rootsRPC}, &cache.RegisterOptions{
// Disable refresh so that the gated channel controls the
// request directly. Otherwise, we get background refreshes and
// it screws up the ordering of the channel reads of the
// testGatedRootsRPC implementation.
Refresh: false,
})
// Create the leaf type
return &ConnectCALeaf{RPC: rpc, Cache: c}, rootsCh
}
// testGatedRootsRPC will send each subsequent value on the channel as the
// RPC response, blocking if it is waiting for a value on the channel. This
// can be used to control when background fetches are returned and what they
// return.
//
// This should be used with Refresh = false for the registration options so
// automatic refreshes don't mess up the channel read ordering.
type testGatedRootsRPC struct {
ValueCh chan structs.IndexedCARoots
}
func (r *testGatedRootsRPC) RPC(method string, args interface{}, reply interface{}) error {
if method != "ConnectCA.Roots" {
return fmt.Errorf("invalid RPC method: %s", method)
}
replyReal := reply.(*structs.IndexedCARoots)
*replyReal = <-r.ValueCh
return nil
}

View File

@ -0,0 +1,43 @@
package cachetype
import (
"fmt"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
)
// Recommended name for registration.
const ConnectCARootName = "connect-ca-root"
// ConnectCARoot supports fetching the Connect CA roots. This is a
// straightforward cache type since it only has to block on the given
// index and return the data.
type ConnectCARoot struct {
RPC RPC
}
func (c *ConnectCARoot) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
var result cache.FetchResult
// The request should be a DCSpecificRequest.
reqReal, ok := req.(*structs.DCSpecificRequest)
if !ok {
return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req)
}
// Set the minimum query index to our current index so we block
reqReal.QueryOptions.MinQueryIndex = opts.MinIndex
reqReal.QueryOptions.MaxQueryTime = opts.Timeout
// Fetch
var reply structs.IndexedCARoots
if err := c.RPC.RPC("ConnectCA.Roots", reqReal, &reply); err != nil {
return result, err
}
result.Value = &reply
result.Index = reply.QueryMeta.Index
return result, nil
}

View File

@ -0,0 +1,57 @@
package cachetype
import (
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestConnectCARoot(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &ConnectCARoot{RPC: rpc}
// Expect the proper RPC call. This also sets the expected value
// since that is return-by-pointer in the arguments.
var resp *structs.IndexedCARoots
rpc.On("RPC", "ConnectCA.Roots", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.DCSpecificRequest)
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
reply := args.Get(2).(*structs.IndexedCARoots)
reply.QueryMeta.Index = 48
resp = reply
})
// Fetch
result, err := typ.Fetch(cache.FetchOptions{
MinIndex: 24,
Timeout: 1 * time.Second,
}, &structs.DCSpecificRequest{Datacenter: "dc1"})
require.Nil(err)
require.Equal(cache.FetchResult{
Value: resp,
Index: 48,
}, result)
}
func TestConnectCARoot_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &ConnectCARoot{RPC: rpc}
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.NotNil(err)
require.Contains(err.Error(), "wrong type")
}

View File

@ -0,0 +1,41 @@
package cachetype
import (
"fmt"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
)
// Recommended name for registration.
const IntentionMatchName = "intention-match"
// IntentionMatch supports fetching the intentions via match queries.
type IntentionMatch struct {
RPC RPC
}
func (c *IntentionMatch) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
var result cache.FetchResult
// The request should be an IntentionQueryRequest.
reqReal, ok := req.(*structs.IntentionQueryRequest)
if !ok {
return result, fmt.Errorf(
"Internal cache failure: request wrong type: %T", req)
}
// Set the minimum query index to our current index so we block
reqReal.MinQueryIndex = opts.MinIndex
reqReal.MaxQueryTime = opts.Timeout
// Fetch
var reply structs.IndexedIntentionMatches
if err := c.RPC.RPC("Intention.Match", reqReal, &reply); err != nil {
return result, err
}
result.Value = &reply
result.Index = reply.Index
return result, nil
}

View File

@ -0,0 +1,57 @@
package cachetype
import (
"testing"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestIntentionMatch(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &IntentionMatch{RPC: rpc}
// Expect the proper RPC call. This also sets the expected value
// since that is return-by-pointer in the arguments.
var resp *structs.IndexedIntentionMatches
rpc.On("RPC", "Intention.Match", mock.Anything, mock.Anything).Return(nil).
Run(func(args mock.Arguments) {
req := args.Get(1).(*structs.IntentionQueryRequest)
require.Equal(uint64(24), req.MinQueryIndex)
require.Equal(1*time.Second, req.MaxQueryTime)
reply := args.Get(2).(*structs.IndexedIntentionMatches)
reply.Index = 48
resp = reply
})
// Fetch
result, err := typ.Fetch(cache.FetchOptions{
MinIndex: 24,
Timeout: 1 * time.Second,
}, &structs.IntentionQueryRequest{Datacenter: "dc1"})
require.NoError(err)
require.Equal(cache.FetchResult{
Value: resp,
Index: 48,
}, result)
}
func TestIntentionMatch_badReqType(t *testing.T) {
require := require.New(t)
rpc := TestRPC(t)
defer rpc.AssertExpectations(t)
typ := &IntentionMatch{RPC: rpc}
// Fetch
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
require.Error(err)
require.Contains(err.Error(), "wrong type")
}

View File

@ -0,0 +1,23 @@
// Code generated by mockery v1.0.0
package cachetype
import mock "github.com/stretchr/testify/mock"
// MockRPC is an autogenerated mock type for the RPC type
type MockRPC struct {
mock.Mock
}
// RPC provides a mock function with given fields: method, args, reply
func (_m *MockRPC) RPC(method string, args interface{}, reply interface{}) error {
ret := _m.Called(method, args, reply)
var r0 error
if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok {
r0 = rf(method, args, reply)
} else {
r0 = ret.Error(0)
}
return r0
}

10
agent/cache-types/rpc.go Normal file
View File

@ -0,0 +1,10 @@
package cachetype
//go:generate mockery -all -inpkg
// RPC is an interface that an RPC client must implement. This is a helper
// interface that is implemented by the agent delegate so that Type
// implementations can request RPC access.
type RPC interface {
RPC(method string, args interface{}, reply interface{}) error
}

View File

@ -0,0 +1,60 @@
package cachetype
import (
"reflect"
"time"
"github.com/hashicorp/consul/agent/cache"
"github.com/mitchellh/go-testing-interface"
)
// TestRPC returns a mock implementation of the RPC interface.
func TestRPC(t testing.T) *MockRPC {
// This function is relatively useless but this allows us to perhaps
// perform some initialization later.
return &MockRPC{}
}
// TestFetchCh returns a channel that returns the result of the Fetch call.
// This is useful for testing timing and concurrency with Fetch calls.
// Errors will show up as an error type on the resulting channel so a
// type switch should be used.
func TestFetchCh(
t testing.T,
typ cache.Type,
opts cache.FetchOptions,
req cache.Request) <-chan interface{} {
resultCh := make(chan interface{})
go func() {
result, err := typ.Fetch(opts, req)
if err != nil {
resultCh <- err
return
}
resultCh <- result
}()
return resultCh
}
// TestFetchChResult tests that the result from TestFetchCh matches
// within a reasonable period of time (it expects it to be "immediate" but
// waits some milliseconds).
func TestFetchChResult(t testing.T, ch <-chan interface{}, expected interface{}) {
t.Helper()
select {
case result := <-ch:
if err, ok := result.(error); ok {
t.Fatalf("Result was error: %s", err)
return
}
if !reflect.DeepEqual(result, expected) {
t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected)
}
case <-time.After(50 * time.Millisecond):
}
}

536
agent/cache/cache.go vendored Normal file
View File

@ -0,0 +1,536 @@
// Package cache provides caching features for data from a Consul server.
//
// While this is similar in some ways to the "agent/ae" package, a key
// difference is that with anti-entropy, the agent is the authoritative
// source so it resolves differences the server may have. With caching (this
// package), the server is the authoritative source and we do our best to
// balance performance and correctness, depending on the type of data being
// requested.
//
// The types of data that can be cached is configurable via the Type interface.
// This allows specialized behavior for certain types of data. Each type of
// Consul data (CA roots, leaf certs, intentions, KV, catalog, etc.) will
// have to be manually implemented. This usually is not much work, see
// the "agent/cache-types" package.
package cache
import (
"container/heap"
"fmt"
"sync"
"time"
"github.com/armon/go-metrics"
)
//go:generate mockery -all -inpkg
// Constants related to refresh backoff. We probably don't ever need to
// make these configurable knobs since they primarily exist to lower load.
const (
CacheRefreshBackoffMin = 3 // 3 attempts before backing off
CacheRefreshMaxWait = 1 * time.Minute // maximum backoff wait time
)
// Cache is a agent-local cache of Consul data. Create a Cache using the
// New function. A zero-value Cache is not ready for usage and will result
// in a panic.
//
// The types of data to be cached must be registered via RegisterType. Then,
// calls to Get specify the type and a Request implementation. The
// implementation of Request is usually done directly on the standard RPC
// struct in agent/structs. This API makes cache usage a mostly drop-in
// replacement for non-cached RPC calls.
//
// The cache is partitioned by ACL and datacenter. This allows the cache
// to be safe for multi-DC queries and for queries where the data is modified
// due to ACLs all without the cache having to have any clever logic, at
// the slight expense of a less perfect cache.
//
// The Cache exposes various metrics via go-metrics. Please view the source
// searching for "metrics." to see the various metrics exposed. These can be
// used to explore the performance of the cache.
type Cache struct {
// types stores the list of data types that the cache knows how to service.
// These can be dynamically registered with RegisterType.
typesLock sync.RWMutex
types map[string]typeEntry
// entries contains the actual cache data. Access to entries and
// entriesExpiryHeap must be protected by entriesLock.
//
// entriesExpiryHeap is a heap of *cacheEntry values ordered by
// expiry, with the soonest to expire being first in the list (index 0).
//
// NOTE(mitchellh): The entry map key is currently a string in the format
// of "<DC>/<ACL token>/<Request key>" in order to properly partition
// requests to different datacenters and ACL tokens. This format has some
// big drawbacks: we can't evict by datacenter, ACL token, etc. For an
// initial implementation this works and the tests are agnostic to the
// internal storage format so changing this should be possible safely.
entriesLock sync.RWMutex
entries map[string]cacheEntry
entriesExpiryHeap *expiryHeap
}
// typeEntry is a single type that is registered with a Cache.
type typeEntry struct {
Type Type
Opts *RegisterOptions
}
// ResultMeta is returned from Get calls along with the value and can be used
// to expose information about the cache status for debugging or testing.
type ResultMeta struct {
// Return whether or not the request was a cache hit
Hit bool
}
// Options are options for the Cache.
type Options struct {
// Nothing currently, reserved.
}
// New creates a new cache with the given RPC client and reasonable defaults.
// Further settings can be tweaked on the returned value.
func New(*Options) *Cache {
// Initialize the heap. The buffer of 1 is really important because
// its possible for the expiry loop to trigger the heap to update
// itself and it'd block forever otherwise.
h := &expiryHeap{NotifyCh: make(chan struct{}, 1)}
heap.Init(h)
c := &Cache{
types: make(map[string]typeEntry),
entries: make(map[string]cacheEntry),
entriesExpiryHeap: h,
}
// Start the expiry watcher
go c.runExpiryLoop()
return c
}
// RegisterOptions are options that can be associated with a type being
// registered for the cache. This changes the behavior of the cache for
// this type.
type RegisterOptions struct {
// LastGetTTL is the time that the values returned by this type remain
// in the cache after the last get operation. If a value isn't accessed
// within this duration, the value is purged from the cache and
// background refreshing will cease.
LastGetTTL time.Duration
// Refresh configures whether the data is actively refreshed or if
// the data is only refreshed on an explicit Get. The default (false)
// is to only request data on explicit Get.
Refresh bool
// RefreshTimer is the time between attempting to refresh data.
// If this is zero, then data is refreshed immediately when a fetch
// is returned.
//
// RefreshTimeout determines the maximum query time for a refresh
// operation. This is specified as part of the query options and is
// expected to be implemented by the Type itself.
//
// Using these values, various "refresh" mechanisms can be implemented:
//
// * With a high timer duration and a low timeout, a timer-based
// refresh can be set that minimizes load on the Consul servers.
//
// * With a low timer and high timeout duration, a blocking-query-based
// refresh can be set so that changes in server data are recognized
// within the cache very quickly.
//
RefreshTimer time.Duration
RefreshTimeout time.Duration
}
// RegisterType registers a cacheable type.
//
// This makes the type available for Get but does not automatically perform
// any prefetching. In order to populate the cache, Get must be called.
func (c *Cache) RegisterType(n string, typ Type, opts *RegisterOptions) {
if opts == nil {
opts = &RegisterOptions{}
}
if opts.LastGetTTL == 0 {
opts.LastGetTTL = 72 * time.Hour // reasonable default is days
}
c.typesLock.Lock()
defer c.typesLock.Unlock()
c.types[n] = typeEntry{Type: typ, Opts: opts}
}
// Get loads the data for the given type and request. If data satisfying the
// minimum index is present in the cache, it is returned immediately. Otherwise,
// this will block until the data is available or the request timeout is
// reached.
//
// Multiple Get calls for the same Request (matching CacheKey value) will
// block on a single network request.
//
// The timeout specified by the Request will be the timeout on the cache
// Get, and does not correspond to the timeout of any background data
// fetching. If the timeout is reached before data satisfying the minimum
// index is retrieved, the last known value (maybe nil) is returned. No
// error is returned on timeout. This matches the behavior of Consul blocking
// queries.
func (c *Cache) Get(t string, r Request) (interface{}, ResultMeta, error) {
info := r.CacheInfo()
if info.Key == "" {
metrics.IncrCounter([]string{"consul", "cache", "bypass"}, 1)
// If no key is specified, then we do not cache this request.
// Pass directly through to the backend.
return c.fetchDirect(t, r)
}
// Get the actual key for our entry
key := c.entryKey(&info)
// First time through
first := true
// timeoutCh for watching our timeout
var timeoutCh <-chan time.Time
RETRY_GET:
// Get the current value
c.entriesLock.RLock()
entry, ok := c.entries[key]
c.entriesLock.RUnlock()
// If we have a current value and the index is greater than the
// currently stored index then we return that right away. If the
// index is zero and we have something in the cache we accept whatever
// we have.
if ok && entry.Valid {
if info.MinIndex == 0 || info.MinIndex < entry.Index {
meta := ResultMeta{}
if first {
metrics.IncrCounter([]string{"consul", "cache", t, "hit"}, 1)
meta.Hit = true
}
// Touch the expiration and fix the heap.
c.entriesLock.Lock()
entry.Expiry.Reset()
c.entriesExpiryHeap.Fix(entry.Expiry)
c.entriesLock.Unlock()
// We purposely do not return an error here since the cache
// only works with fetching values that either have a value
// or have an error, but not both. The Error may be non-nil
// in the entry because of this to note future fetch errors.
return entry.Value, meta, nil
}
}
// If this isn't our first time through and our last value has an error,
// then we return the error. This has the behavior that we don't sit in
// a retry loop getting the same error for the entire duration of the
// timeout. Instead, we make one effort to fetch a new value, and if
// there was an error, we return.
if !first && entry.Error != nil {
return entry.Value, ResultMeta{}, entry.Error
}
if first {
// We increment two different counters for cache misses depending on
// whether we're missing because we didn't have the data at all,
// or if we're missing because we're blocking on a set index.
if info.MinIndex == 0 {
metrics.IncrCounter([]string{"consul", "cache", t, "miss_new"}, 1)
} else {
metrics.IncrCounter([]string{"consul", "cache", t, "miss_block"}, 1)
}
}
// No longer our first time through
first = false
// Set our timeout channel if we must
if info.Timeout > 0 && timeoutCh == nil {
timeoutCh = time.After(info.Timeout)
}
// At this point, we know we either don't have a value at all or the
// value we have is too old. We need to wait for new data.
waiterCh, err := c.fetch(t, key, r, true, 0)
if err != nil {
return nil, ResultMeta{}, err
}
select {
case <-waiterCh:
// Our fetch returned, retry the get from the cache
goto RETRY_GET
case <-timeoutCh:
// Timeout on the cache read, just return whatever we have.
return entry.Value, ResultMeta{}, nil
}
}
// entryKey returns the key for the entry in the cache. See the note
// about the entry key format in the structure docs for Cache.
func (c *Cache) entryKey(r *RequestInfo) string {
return fmt.Sprintf("%s/%s/%s", r.Datacenter, r.Token, r.Key)
}
// fetch triggers a new background fetch for the given Request. If a
// background fetch is already running for a matching Request, the waiter
// channel for that request is returned. The effect of this is that there
// is only ever one blocking query for any matching requests.
//
// If allowNew is true then the fetch should create the cache entry
// if it doesn't exist. If this is false, then fetch will do nothing
// if the entry doesn't exist. This latter case is to support refreshing.
func (c *Cache) fetch(t, key string, r Request, allowNew bool, attempt uint) (<-chan struct{}, error) {
// Get the type that we're fetching
c.typesLock.RLock()
tEntry, ok := c.types[t]
c.typesLock.RUnlock()
if !ok {
return nil, fmt.Errorf("unknown type in cache: %s", t)
}
// We acquire a write lock because we may have to set Fetching to true.
c.entriesLock.Lock()
defer c.entriesLock.Unlock()
entry, ok := c.entries[key]
// If we aren't allowing new values and we don't have an existing value,
// return immediately. We return an immediately-closed channel so nothing
// blocks.
if !ok && !allowNew {
ch := make(chan struct{})
close(ch)
return ch, nil
}
// If we already have an entry and it is actively fetching, then return
// the currently active waiter.
if ok && entry.Fetching {
return entry.Waiter, nil
}
// If we don't have an entry, then create it. The entry must be marked
// as invalid so that it isn't returned as a valid value for a zero index.
if !ok {
entry = cacheEntry{Valid: false, Waiter: make(chan struct{})}
}
// Set that we're fetching to true, which makes it so that future
// identical calls to fetch will return the same waiter rather than
// perform multiple fetches.
entry.Fetching = true
c.entries[key] = entry
metrics.SetGauge([]string{"consul", "cache", "entries_count"}, float32(len(c.entries)))
// The actual Fetch must be performed in a goroutine.
go func() {
// Start building the new entry by blocking on the fetch.
result, err := tEntry.Type.Fetch(FetchOptions{
MinIndex: entry.Index,
Timeout: tEntry.Opts.RefreshTimeout,
}, r)
// Copy the existing entry to start.
newEntry := entry
newEntry.Fetching = false
if result.Value != nil {
// A new value was given, so we create a brand new entry.
newEntry.Value = result.Value
newEntry.Index = result.Index
if newEntry.Index < 1 {
// Less than one is invalid unless there was an error and in this case
// there wasn't since a value was returned. If a badly behaved RPC
// returns 0 when it has no data, we might get into a busy loop here. We
// set this to minimum of 1 which is safe because no valid user data can
// ever be written at raft index 1 due to the bootstrap process for
// raft. This insure that any subsequent background refresh request will
// always block, but allows the initial request to return immediately
// even if there is no data.
newEntry.Index = 1
}
// This is a valid entry with a result
newEntry.Valid = true
}
// Error handling
if err == nil {
metrics.IncrCounter([]string{"consul", "cache", "fetch_success"}, 1)
metrics.IncrCounter([]string{"consul", "cache", t, "fetch_success"}, 1)
if result.Index > 0 {
// Reset the attempts counter so we don't have any backoff
attempt = 0
} else {
// Result having a zero index is an implicit error case. There was no
// actual error but it implies the RPC found in index (nothing written
// yet for that type) but didn't take care to return safe "1" index. We
// don't want to actually treat it like an error by setting
// newEntry.Error to something non-nil, but we should guard against 100%
// CPU burn hot loops caused by that case which will never block but
// also won't backoff either. So we treat it as a failed attempt so that
// at least the failure backoff will save our CPU while still
// periodically refreshing so normal service can resume when the servers
// actually have something to return from the RPC. If we get in this
// state it can be considered a bug in the RPC implementation (to ever
// return a zero index) however since it can happen this is a safety net
// for the future.
attempt++
}
} else {
metrics.IncrCounter([]string{"consul", "cache", "fetch_error"}, 1)
metrics.IncrCounter([]string{"consul", "cache", t, "fetch_error"}, 1)
// Increment attempt counter
attempt++
// Always set the error. We don't override the value here because
// if Valid is true, then we can reuse the Value in the case a
// specific index isn't requested. However, for blocking queries,
// we want Error to be set so that we can return early with the
// error.
newEntry.Error = err
}
// Create a new waiter that will be used for the next fetch.
newEntry.Waiter = make(chan struct{})
// Set our entry
c.entriesLock.Lock()
// If this is a new entry (not in the heap yet), then setup the
// initial expiry information and insert. If we're already in
// the heap we do nothing since we're reusing the same entry.
if newEntry.Expiry == nil || newEntry.Expiry.HeapIndex == -1 {
newEntry.Expiry = &cacheEntryExpiry{
Key: key,
TTL: tEntry.Opts.LastGetTTL,
}
newEntry.Expiry.Reset()
heap.Push(c.entriesExpiryHeap, newEntry.Expiry)
}
c.entries[key] = newEntry
c.entriesLock.Unlock()
// Trigger the old waiter
close(entry.Waiter)
// If refresh is enabled, run the refresh in due time. The refresh
// below might block, but saves us from spawning another goroutine.
if tEntry.Opts.Refresh {
c.refresh(tEntry.Opts, attempt, t, key, r)
}
}()
return entry.Waiter, nil
}
// fetchDirect fetches the given request with no caching. Because this
// bypasses the caching entirely, multiple matching requests will result
// in multiple actual RPC calls (unlike fetch).
func (c *Cache) fetchDirect(t string, r Request) (interface{}, ResultMeta, error) {
// Get the type that we're fetching
c.typesLock.RLock()
tEntry, ok := c.types[t]
c.typesLock.RUnlock()
if !ok {
return nil, ResultMeta{}, fmt.Errorf("unknown type in cache: %s", t)
}
// Fetch it with the min index specified directly by the request.
result, err := tEntry.Type.Fetch(FetchOptions{
MinIndex: r.CacheInfo().MinIndex,
}, r)
if err != nil {
return nil, ResultMeta{}, err
}
// Return the result and ignore the rest
return result.Value, ResultMeta{}, nil
}
// refresh triggers a fetch for a specific Request according to the
// registration options.
func (c *Cache) refresh(opts *RegisterOptions, attempt uint, t string, key string, r Request) {
// Sanity-check, we should not schedule anything that has refresh disabled
if !opts.Refresh {
return
}
// If we're over the attempt minimum, start an exponential backoff.
if attempt > CacheRefreshBackoffMin {
waitTime := (1 << (attempt - CacheRefreshBackoffMin)) * time.Second
if waitTime > CacheRefreshMaxWait {
waitTime = CacheRefreshMaxWait
}
time.Sleep(waitTime)
}
// If we have a timer, wait for it
if opts.RefreshTimer > 0 {
time.Sleep(opts.RefreshTimer)
}
// Trigger. The "allowNew" field is false because in the time we were
// waiting to refresh we may have expired and got evicted. If that
// happened, we don't want to create a new entry.
c.fetch(t, key, r, false, attempt)
}
// runExpiryLoop is a blocking function that watches the expiration
// heap and invalidates entries that have expired.
func (c *Cache) runExpiryLoop() {
var expiryTimer *time.Timer
for {
// If we have a previous timer, stop it.
if expiryTimer != nil {
expiryTimer.Stop()
}
// Get the entry expiring soonest
var entry *cacheEntryExpiry
var expiryCh <-chan time.Time
c.entriesLock.RLock()
if len(c.entriesExpiryHeap.Entries) > 0 {
entry = c.entriesExpiryHeap.Entries[0]
expiryTimer = time.NewTimer(entry.Expires.Sub(time.Now()))
expiryCh = expiryTimer.C
}
c.entriesLock.RUnlock()
select {
case <-c.entriesExpiryHeap.NotifyCh:
// Entries changed, so the heap may have changed. Restart loop.
case <-expiryCh:
c.entriesLock.Lock()
// Entry expired! Remove it.
delete(c.entries, entry.Key)
heap.Remove(c.entriesExpiryHeap, entry.HeapIndex)
// This is subtle but important: if we race and simultaneously
// evict and fetch a new value, then we set this to -1 to
// have it treated as a new value so that the TTL is extended.
entry.HeapIndex = -1
// Set some metrics
metrics.IncrCounter([]string{"consul", "cache", "evict_expired"}, 1)
metrics.SetGauge([]string{"consul", "cache", "entries_count"}, float32(len(c.entries)))
c.entriesLock.Unlock()
}
}
}

760
agent/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,760 @@
package cache
import (
"fmt"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// Test a basic Get with no indexes (and therefore no blocking queries).
func TestCacheGet_noIndex(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
typ.Static(FetchResult{Value: 42}, nil).Times(1)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Get, should not fetch since we already have a satisfying value
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test a basic Get with no index and a failed fetch.
func TestCacheGet_initError(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
fetcherr := fmt.Errorf("error")
typ.Static(FetchResult{}, fetcherr).Times(2)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.Error(err)
require.Nil(result)
require.False(meta.Hit)
// Get, should fetch again since our last fetch was an error
result, meta, err = c.Get("t", req)
require.Error(err)
require.Nil(result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test a Get with a request that returns a blank cache key. This should
// force a backend request and skip the cache entirely.
func TestCacheGet_blankCacheKey(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
typ.Static(FetchResult{Value: 42}, nil).Times(2)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: ""})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Get, should not fetch since we already have a satisfying value
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test that Get blocks on the initial value
func TestCacheGet_blockingInitSameKey(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
triggerCh := make(chan time.Time)
typ.Static(FetchResult{Value: 42}, nil).WaitUntil(triggerCh).Times(1)
// Perform multiple gets
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
// They should block
select {
case <-getCh1:
t.Fatal("should block (ch1)")
case <-getCh2:
t.Fatal("should block (ch2)")
case <-time.After(50 * time.Millisecond):
}
// Trigger it
close(triggerCh)
// Should return
TestCacheGetChResult(t, getCh1, 42)
TestCacheGetChResult(t, getCh2, 42)
}
// Test that Get with different cache keys both block on initial value
// but that the fetches were both properly called.
func TestCacheGet_blockingInitDiffKeys(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Keep track of the keys
var keysLock sync.Mutex
var keys []string
// Configure the type
triggerCh := make(chan time.Time)
typ.Static(FetchResult{Value: 42}, nil).
WaitUntil(triggerCh).
Times(2).
Run(func(args mock.Arguments) {
keysLock.Lock()
defer keysLock.Unlock()
keys = append(keys, args.Get(1).(Request).CacheInfo().Key)
})
// Perform multiple gets
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "goodbye"}))
// They should block
select {
case <-getCh1:
t.Fatal("should block (ch1)")
case <-getCh2:
t.Fatal("should block (ch2)")
case <-time.After(50 * time.Millisecond):
}
// Trigger it
close(triggerCh)
// Should return both!
TestCacheGetChResult(t, getCh1, 42)
TestCacheGetChResult(t, getCh2, 42)
// Verify proper keys
sort.Strings(keys)
require.Equal([]string{"goodbye", "hello"}, keys)
}
// Test a get with an index set will wait until an index that is higher
// is set in the cache.
func TestCacheGet_blockingIndex(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
triggerCh := make(chan time.Time)
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
typ.Static(FetchResult{Value: 42, Index: 6}, nil).WaitUntil(triggerCh)
// Fetch should block
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 5}))
// Should block
select {
case <-resultCh:
t.Fatal("should block")
case <-time.After(50 * time.Millisecond):
}
// Wait a bit
close(triggerCh)
// Should return
TestCacheGetChResult(t, resultCh, 42)
}
// Test a get with an index set will timeout if the fetch doesn't return
// anything.
func TestCacheGet_blockingIndexTimeout(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
triggerCh := make(chan time.Time)
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
typ.Static(FetchResult{Value: 42, Index: 6}, nil).WaitUntil(triggerCh)
// Fetch should block
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 5, Timeout: 200 * time.Millisecond}))
// Should block
select {
case <-resultCh:
t.Fatal("should block")
case <-time.After(50 * time.Millisecond):
}
// Should return after more of the timeout
select {
case result := <-resultCh:
require.Equal(t, 12, result)
case <-time.After(300 * time.Millisecond):
t.Fatal("should've returned")
}
}
// Test a get with an index set with requests returning an error
// will return that error.
func TestCacheGet_blockingIndexError(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
var retries uint32
fetchErr := fmt.Errorf("test fetch error")
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: nil, Index: 5}, fetchErr).Run(func(args mock.Arguments) {
atomic.AddUint32(&retries, 1)
})
// First good fetch to populate catch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Fetch should not block and should return error
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 7, Timeout: 1 * time.Minute}))
TestCacheGetChResult(t, resultCh, nil)
// Wait a bit
time.Sleep(100 * time.Millisecond)
// Check the number
actual := atomic.LoadUint32(&retries)
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
}
// Test that if a Type returns an empty value on Fetch that the previous
// value is preserved.
func TestCacheGet_emptyFetchResult(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, nil)
// Configure the type
typ.Static(FetchResult{Value: 42, Index: 1}, nil).Times(1)
typ.Static(FetchResult{Value: nil}, nil)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Get, should not fetch since we already have a satisfying value
req = TestRequest(t, RequestInfo{
Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test that a type registered with a periodic refresh will perform
// that refresh after the timer is up.
func TestCacheGet_periodicRefresh(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 100 * time.Millisecond,
RefreshTimeout: 5 * time.Minute,
})
// This is a bit weird, but we do this to ensure that the final
// call to the Fetch (if it happens, depends on timing) just blocks.
triggerCh := make(chan time.Time)
defer close(triggerCh)
// Configure the type
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).WaitUntil(triggerCh)
// Fetch should block
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Fetch again almost immediately should return old result
time.Sleep(5 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Wait for the timer
time.Sleep(200 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 12)
}
// Test that a type registered with a periodic refresh will perform
// that refresh after the timer is up.
func TestCacheGet_periodicRefreshMultiple(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 0 * time.Millisecond,
RefreshTimeout: 5 * time.Minute,
})
// This is a bit weird, but we do this to ensure that the final
// call to the Fetch (if it happens, depends on timing) just blocks.
trigger := make([]chan time.Time, 3)
for i := range trigger {
trigger[i] = make(chan time.Time)
}
// Configure the type
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once().WaitUntil(trigger[0])
typ.Static(FetchResult{Value: 24, Index: 6}, nil).Once().WaitUntil(trigger[1])
typ.Static(FetchResult{Value: 42, Index: 7}, nil).WaitUntil(trigger[2])
// Fetch should block
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Fetch again almost immediately should return old result
time.Sleep(5 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Trigger the next, sleep a bit, and verify we get the next result
close(trigger[0])
time.Sleep(100 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 12)
// Trigger the next, sleep a bit, and verify we get the next result
close(trigger[1])
time.Sleep(100 * time.Millisecond)
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 24)
}
// Test that a refresh performs a backoff.
func TestCacheGet_periodicRefreshErrorBackoff(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 0,
RefreshTimeout: 5 * time.Minute,
})
// Configure the type
var retries uint32
fetchErr := fmt.Errorf("test fetch error")
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
typ.Static(FetchResult{Value: nil, Index: 5}, fetchErr).Run(func(args mock.Arguments) {
atomic.AddUint32(&retries, 1)
})
// Fetch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Sleep a bit. The refresh will quietly fail in the background. What we
// want to verify is that it doesn't retry too much. "Too much" is hard
// to measure since its CPU dependent if this test is failing. But due
// to the short sleep below, we can calculate about what we'd expect if
// backoff IS working.
time.Sleep(500 * time.Millisecond)
// Fetch should work, we should get a 1 still. Errors are ignored.
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 1)
// Check the number
actual := atomic.LoadUint32(&retries)
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
}
// Test that a badly behaved RPC that returns 0 index will perform a backoff.
func TestCacheGet_periodicRefreshBadRPCZeroIndexErrorBackoff(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 0,
RefreshTimeout: 5 * time.Minute,
})
// Configure the type
var retries uint32
typ.Static(FetchResult{Value: 0, Index: 0}, nil).Run(func(args mock.Arguments) {
atomic.AddUint32(&retries, 1)
})
// Fetch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 0)
// Sleep a bit. The refresh will quietly fail in the background. What we
// want to verify is that it doesn't retry too much. "Too much" is hard
// to measure since its CPU dependent if this test is failing. But due
// to the short sleep below, we can calculate about what we'd expect if
// backoff IS working.
time.Sleep(500 * time.Millisecond)
// Fetch should work, we should get a 0 still. Errors are ignored.
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 0)
// Check the number
actual := atomic.LoadUint32(&retries)
require.True(t, actual < 10, fmt.Sprintf("%d retries, should be < 10", actual))
}
// Test that fetching with no index makes an initial request with no index, but
// then ensures all background refreshes have > 0. This ensures we don't end up
// with any index 0 loops from background refreshed while also returning
// immediately on the initial request if there is no data written to that table
// yet.
func TestCacheGet_noIndexSetsOne(t *testing.T) {
t.Parallel()
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
c.RegisterType("t", typ, &RegisterOptions{
Refresh: true,
RefreshTimer: 0,
RefreshTimeout: 5 * time.Minute,
})
// Simulate "well behaved" RPC with no data yet but returning 1
{
first := int32(1)
typ.Static(FetchResult{Value: 0, Index: 1}, nil).Run(func(args mock.Arguments) {
opts := args.Get(0).(FetchOptions)
isFirst := atomic.SwapInt32(&first, 0)
if isFirst == 1 {
assert.Equal(t, uint64(0), opts.MinIndex)
} else {
assert.True(t, opts.MinIndex > 0, "minIndex > 0")
}
})
// Fetch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 0)
// Sleep a bit so background refresh happens
time.Sleep(100 * time.Millisecond)
}
// Same for "badly behaved" RPC that returns 0 index and no data
{
first := int32(1)
typ.Static(FetchResult{Value: 0, Index: 0}, nil).Run(func(args mock.Arguments) {
opts := args.Get(0).(FetchOptions)
isFirst := atomic.SwapInt32(&first, 0)
if isFirst == 1 {
assert.Equal(t, uint64(0), opts.MinIndex)
} else {
assert.True(t, opts.MinIndex > 0, "minIndex > 0")
}
})
// Fetch
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
TestCacheGetChResult(t, resultCh, 0)
// Sleep a bit so background refresh happens
time.Sleep(100 * time.Millisecond)
}
}
// Test that the backend fetch sets the proper timeout.
func TestCacheGet_fetchTimeout(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
// Register the type with a timeout
timeout := 10 * time.Minute
c.RegisterType("t", typ, &RegisterOptions{
RefreshTimeout: timeout,
})
// Configure the type
var actual time.Duration
typ.Static(FetchResult{Value: 42}, nil).Times(1).Run(func(args mock.Arguments) {
opts := args.Get(0).(FetchOptions)
actual = opts.Timeout
})
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Test the timeout
require.Equal(timeout, actual)
}
// Test that entries expire
func TestCacheGet_expire(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
// Register the type with a timeout
c.RegisterType("t", typ, &RegisterOptions{
LastGetTTL: 400 * time.Millisecond,
})
// Configure the type
typ.Static(FetchResult{Value: 42}, nil).Times(2)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Get, should not fetch, verified via the mock assertions above
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
// Sleep for the expiry
time.Sleep(500 * time.Millisecond)
// Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test that entries reset their TTL on Get
func TestCacheGet_expireResetGet(t *testing.T) {
t.Parallel()
require := require.New(t)
typ := TestType(t)
defer typ.AssertExpectations(t)
c := TestCache(t)
// Register the type with a timeout
c.RegisterType("t", typ, &RegisterOptions{
LastGetTTL: 150 * time.Millisecond,
})
// Configure the type
typ.Static(FetchResult{Value: 42}, nil).Times(2)
// Get, should fetch
req := TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err := c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Fetch multiple times, where the total time is well beyond
// the TTL. We should not trigger any fetches during this time.
for i := 0; i < 5; i++ {
// Sleep a bit
time.Sleep(50 * time.Millisecond)
// Get, should not fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.True(meta.Hit)
}
time.Sleep(200 * time.Millisecond)
// Get, should fetch
req = TestRequest(t, RequestInfo{Key: "hello"})
result, meta, err = c.Get("t", req)
require.NoError(err)
require.Equal(42, result)
require.False(meta.Hit)
// Sleep a tiny bit just to let maybe some background calls happen
// then verify that we still only got the one call
time.Sleep(20 * time.Millisecond)
typ.AssertExpectations(t)
}
// Test that Get partitions the caches based on DC so two equivalent requests
// to different datacenters are automatically cached even if their keys are
// the same.
func TestCacheGet_partitionDC(t *testing.T) {
t.Parallel()
c := TestCache(t)
c.RegisterType("t", &testPartitionType{}, nil)
// Perform multiple gets
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Datacenter: "dc1", Key: "hello"}))
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Datacenter: "dc9", Key: "hello"}))
// Should return both!
TestCacheGetChResult(t, getCh1, "dc1")
TestCacheGetChResult(t, getCh2, "dc9")
}
// Test that Get partitions the caches based on token so two equivalent requests
// with different ACL tokens do not return the same result.
func TestCacheGet_partitionToken(t *testing.T) {
t.Parallel()
c := TestCache(t)
c.RegisterType("t", &testPartitionType{}, nil)
// Perform multiple gets
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Token: "", Key: "hello"}))
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
Token: "foo", Key: "hello"}))
// Should return both!
TestCacheGetChResult(t, getCh1, "")
TestCacheGetChResult(t, getCh2, "foo")
}
// testPartitionType implements Type for testing that simply returns a value
// comprised of the request DC and ACL token, used for testing cache
// partitioning.
type testPartitionType struct{}
func (t *testPartitionType) Fetch(opts FetchOptions, r Request) (FetchResult, error) {
info := r.CacheInfo()
return FetchResult{
Value: fmt.Sprintf("%s%s", info.Datacenter, info.Token),
}, nil
}

143
agent/cache/entry.go vendored Normal file
View File

@ -0,0 +1,143 @@
package cache
import (
"container/heap"
"time"
)
// cacheEntry stores a single cache entry.
//
// Note that this isn't a very optimized structure currently. There are
// a lot of improvements that can be made here in the long term.
type cacheEntry struct {
// Fields pertaining to the actual value
Value interface{}
Error error
Index uint64
// Metadata that is used for internal accounting
Valid bool // True if the Value is set
Fetching bool // True if a fetch is already active
Waiter chan struct{} // Closed when this entry is invalidated
// Expiry contains information about the expiration of this
// entry. This is a pointer as its shared as a value in the
// expiryHeap as well.
Expiry *cacheEntryExpiry
}
// cacheEntryExpiry contains the expiration information for a cache
// entry. Any modifications to this struct should be done only while
// the Cache entriesLock is held.
type cacheEntryExpiry struct {
Key string // Key in the cache map
Expires time.Time // Time when entry expires (monotonic clock)
TTL time.Duration // TTL for this entry to extend when resetting
HeapIndex int // Index in the heap
}
// Reset resets the expiration to be the ttl duration from now.
func (e *cacheEntryExpiry) Reset() {
e.Expires = time.Now().Add(e.TTL)
}
// expiryHeap is a heap implementation that stores information about
// when entires expire. Implements container/heap.Interface.
//
// All operations on the heap and read/write of the heap contents require
// the proper entriesLock to be held on Cache.
type expiryHeap struct {
Entries []*cacheEntryExpiry
// NotifyCh is sent a value whenever the 0 index value of the heap
// changes. This can be used to detect when the earliest value
// changes.
//
// There is a single edge case where the heap will not automatically
// send a notification: if heap.Fix is called manually and the index
// changed is 0 and the change doesn't result in any moves (stays at index
// 0), then we won't detect the change. To work around this, please
// always call the expiryHeap.Fix method instead.
NotifyCh chan struct{}
}
// Identical to heap.Fix for this heap instance but will properly handle
// the edge case where idx == 0 and no heap modification is necessary,
// and still notify the NotifyCh.
//
// This is important for cache expiry since the expiry time may have been
// extended and if we don't send a message to the NotifyCh then we'll never
// reset the timer and the entry will be evicted early.
func (h *expiryHeap) Fix(entry *cacheEntryExpiry) {
idx := entry.HeapIndex
heap.Fix(h, idx)
// This is the edge case we handle: if the prev (idx) and current (HeapIndex)
// is zero, it means the head-of-line didn't change while the value
// changed. Notify to reset our expiry worker.
if idx == 0 && entry.HeapIndex == 0 {
h.notify()
}
}
func (h *expiryHeap) Len() int { return len(h.Entries) }
func (h *expiryHeap) Swap(i, j int) {
h.Entries[i], h.Entries[j] = h.Entries[j], h.Entries[i]
h.Entries[i].HeapIndex = i
h.Entries[j].HeapIndex = j
// If we're moving the 0 index, update the channel since we need
// to re-update the timer we're waiting on for the soonest expiring
// value.
if i == 0 || j == 0 {
h.notify()
}
}
func (h *expiryHeap) Less(i, j int) bool {
// The usage of Before here is important (despite being obvious):
// this function uses the monotonic time that should be available
// on the time.Time value so the heap is immune to wall clock changes.
return h.Entries[i].Expires.Before(h.Entries[j].Expires)
}
// heap.Interface, this isn't expected to be called directly.
func (h *expiryHeap) Push(x interface{}) {
entry := x.(*cacheEntryExpiry)
// Set initial heap index, if we're going to the end then Swap
// won't be called so we need to initialize
entry.HeapIndex = len(h.Entries)
// For the first entry, we need to trigger a channel send because
// Swap won't be called; nothing to swap! We can call it right away
// because all heap operations are within a lock.
if len(h.Entries) == 0 {
h.notify()
}
h.Entries = append(h.Entries, entry)
}
// heap.Interface, this isn't expected to be called directly.
func (h *expiryHeap) Pop() interface{} {
old := h.Entries
n := len(old)
x := old[n-1]
h.Entries = old[0 : n-1]
return x
}
func (h *expiryHeap) notify() {
select {
case h.NotifyCh <- struct{}{}:
// Good
default:
// If the send would've blocked, we just ignore it. The reason this
// is safe is because NotifyCh should always be a buffered channel.
// If this blocks, it means that there is a pending message anyways
// so the receiver will restart regardless.
}
}

91
agent/cache/entry_test.go vendored Normal file
View File

@ -0,0 +1,91 @@
package cache
import (
"container/heap"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestExpiryHeap_impl(t *testing.T) {
var _ heap.Interface = new(expiryHeap)
}
func TestExpiryHeap(t *testing.T) {
require := require.New(t)
now := time.Now()
ch := make(chan struct{}, 10) // buffered to prevent blocking in tests
h := &expiryHeap{NotifyCh: ch}
// Init, shouldn't trigger anything
heap.Init(h)
testNoMessage(t, ch)
// Push an initial value, expect one message
entry := &cacheEntryExpiry{Key: "foo", HeapIndex: -1, Expires: now.Add(100)}
heap.Push(h, entry)
require.Equal(0, entry.HeapIndex)
testMessage(t, ch)
testNoMessage(t, ch) // exactly one asserted above
// Push another that goes earlier than entry
entry2 := &cacheEntryExpiry{Key: "bar", HeapIndex: -1, Expires: now.Add(50)}
heap.Push(h, entry2)
require.Equal(0, entry2.HeapIndex)
require.Equal(1, entry.HeapIndex)
testMessage(t, ch)
testNoMessage(t, ch) // exactly one asserted above
// Push another that goes at the end
entry3 := &cacheEntryExpiry{Key: "bar", HeapIndex: -1, Expires: now.Add(1000)}
heap.Push(h, entry3)
require.Equal(2, entry3.HeapIndex)
testNoMessage(t, ch) // no notify cause index 0 stayed the same
// Remove the first entry (not Pop, since we don't use Pop, but that works too)
remove := h.Entries[0]
heap.Remove(h, remove.HeapIndex)
require.Equal(0, entry.HeapIndex)
require.Equal(1, entry3.HeapIndex)
testMessage(t, ch)
testMessage(t, ch) // we have two because two swaps happen
testNoMessage(t, ch)
// Let's change entry 3 to be early, and fix it
entry3.Expires = now.Add(10)
h.Fix(entry3)
require.Equal(1, entry.HeapIndex)
require.Equal(0, entry3.HeapIndex)
testMessage(t, ch)
testNoMessage(t, ch)
// Let's change entry 3 again, this is an edge case where if the 0th
// element changed, we didn't trigger the channel. Our Fix func should.
entry.Expires = now.Add(20)
h.Fix(entry3)
require.Equal(1, entry.HeapIndex) // no move
require.Equal(0, entry3.HeapIndex)
testMessage(t, ch)
testNoMessage(t, ch) // one message
}
func testNoMessage(t *testing.T, ch <-chan struct{}) {
t.Helper()
select {
case <-ch:
t.Fatal("should not have a message")
default:
}
}
func testMessage(t *testing.T, ch <-chan struct{}) {
t.Helper()
select {
case <-ch:
default:
t.Fatal("should have a message")
}
}

23
agent/cache/mock_Request.go vendored Normal file
View File

@ -0,0 +1,23 @@
// Code generated by mockery v1.0.0
package cache
import mock "github.com/stretchr/testify/mock"
// MockRequest is an autogenerated mock type for the Request type
type MockRequest struct {
mock.Mock
}
// CacheInfo provides a mock function with given fields:
func (_m *MockRequest) CacheInfo() RequestInfo {
ret := _m.Called()
var r0 RequestInfo
if rf, ok := ret.Get(0).(func() RequestInfo); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(RequestInfo)
}
return r0
}

30
agent/cache/mock_Type.go vendored Normal file
View File

@ -0,0 +1,30 @@
// Code generated by mockery v1.0.0
package cache
import mock "github.com/stretchr/testify/mock"
// MockType is an autogenerated mock type for the Type type
type MockType struct {
mock.Mock
}
// Fetch provides a mock function with given fields: _a0, _a1
func (_m *MockType) Fetch(_a0 FetchOptions, _a1 Request) (FetchResult, error) {
ret := _m.Called(_a0, _a1)
var r0 FetchResult
if rf, ok := ret.Get(0).(func(FetchOptions, Request) FetchResult); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Get(0).(FetchResult)
}
var r1 error
if rf, ok := ret.Get(1).(func(FetchOptions, Request) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

51
agent/cache/request.go vendored Normal file
View File

@ -0,0 +1,51 @@
package cache
import (
"time"
)
// Request is a cacheable request.
//
// This interface is typically implemented by request structures in
// the agent/structs package.
type Request interface {
// CacheInfo returns information used for caching this request.
CacheInfo() RequestInfo
}
// RequestInfo represents cache information for a request. The caching
// framework uses this to control the behavior of caching and to determine
// cacheability.
type RequestInfo struct {
// Key is a unique cache key for this request. This key should
// be globally unique to identify this request, since any conflicting
// cache keys could result in invalid data being returned from the cache.
// The Key does not need to include ACL or DC information, since the
// cache already partitions by these values prior to using this key.
Key string
// Token is the ACL token associated with this request.
//
// Datacenter is the datacenter that the request is targeting.
//
// Both of these values are used to partition the cache. The cache framework
// today partitions data on these values to simplify behavior: by
// partitioning ACL tokens, the cache doesn't need to be smart about
// filtering results. By filtering datacenter results, the cache can
// service the multi-DC nature of Consul. This comes at the expense of
// working set size, but in general the effect is minimal.
Token string
Datacenter string
// MinIndex is the minimum index being queried. This is used to
// determine if we already have data satisfying the query or if we need
// to block until new data is available. If no index is available, the
// default value (zero) is acceptable.
MinIndex uint64
// Timeout is the timeout for waiting on a blocking query. When the
// timeout is reached, the last known value is returned (or maybe nil
// if there was no prior value). This "last known value" behavior matches
// normal Consul blocking queries.
Timeout time.Duration
}

78
agent/cache/testing.go vendored Normal file
View File

@ -0,0 +1,78 @@
package cache
import (
"reflect"
"time"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/mock"
)
// TestCache returns a Cache instance configuring for testing.
func TestCache(t testing.T) *Cache {
// Simple but lets us do some fine-tuning later if we want to.
return New(nil)
}
// TestCacheGetCh returns a channel that returns the result of the Get call.
// This is useful for testing timing and concurrency with Get calls. Any
// error will be logged, so the result value should always be asserted.
func TestCacheGetCh(t testing.T, c *Cache, typ string, r Request) <-chan interface{} {
resultCh := make(chan interface{})
go func() {
result, _, err := c.Get(typ, r)
if err != nil {
t.Logf("Error: %s", err)
close(resultCh)
return
}
resultCh <- result
}()
return resultCh
}
// TestCacheGetChResult tests that the result from TestCacheGetCh matches
// within a reasonable period of time (it expects it to be "immediate" but
// waits some milliseconds).
func TestCacheGetChResult(t testing.T, ch <-chan interface{}, expected interface{}) {
t.Helper()
select {
case result := <-ch:
if !reflect.DeepEqual(result, expected) {
t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected)
}
case <-time.After(50 * time.Millisecond):
t.Fatalf("Result not sent on channel")
}
}
// TestRequest returns a Request that returns the given cache key and index.
// The Reset method can be called to reset it for custom usage.
func TestRequest(t testing.T, info RequestInfo) *MockRequest {
req := &MockRequest{}
req.On("CacheInfo").Return(info)
return req
}
// TestType returns a MockType that can be used to setup expectations
// on data fetching.
func TestType(t testing.T) *MockType {
typ := &MockType{}
return typ
}
// A bit weird, but we add methods to the auto-generated structs here so that
// they don't get clobbered. The helper methods are conveniences.
// Static sets a static value to return for a call to Fetch.
func (m *MockType) Static(r FetchResult, err error) *mock.Call {
return m.Mock.On("Fetch", mock.Anything, mock.Anything).Return(r, err)
}
func (m *MockRequest) Reset() {
m.Mock = mock.Mock{}
}

48
agent/cache/type.go vendored Normal file
View File

@ -0,0 +1,48 @@
package cache
import (
"time"
)
// Type implements the logic to fetch certain types of data.
type Type interface {
// Fetch fetches a single unique item.
//
// The FetchOptions contain the index and timeouts for blocking queries.
// The MinIndex value on the Request itself should NOT be used
// as the blocking index since a request may be reused multiple times
// as part of Refresh behavior.
//
// The return value is a FetchResult which contains information about
// the fetch. If an error is given, the FetchResult is ignored. The
// cache does not support backends that return partial values.
//
// On timeout, FetchResult can behave one of two ways. First, it can
// return the last known value. This is the default behavior of blocking
// RPC calls in Consul so this allows cache types to be implemented with
// no extra logic. Second, FetchResult can return an unset value and index.
// In this case, the cache will reuse the last value automatically.
Fetch(FetchOptions, Request) (FetchResult, error)
}
// FetchOptions are various settable options when a Fetch is called.
type FetchOptions struct {
// MinIndex is the minimum index to be used for blocking queries.
// If blocking queries aren't supported for data being returned,
// this value can be ignored.
MinIndex uint64
// Timeout is the maximum time for the query. This must be implemented
// in the Fetch itself.
Timeout time.Duration
}
// FetchResult is the result of a Type Fetch operation and contains the
// data along with metadata gathered from that operation.
type FetchResult struct {
// Value is the result of the fetch.
Value interface{}
// Index is the corresponding index value for this data.
Index uint64
}

View File

@ -157,12 +157,27 @@ RETRY_ONCE:
return out.Services, nil
}
func (s *HTTPServer) CatalogConnectServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.catalogServiceNodes(resp, req, true)
}
func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_service_nodes"}, 1,
return s.catalogServiceNodes(resp, req, false)
}
func (s *HTTPServer) catalogServiceNodes(resp http.ResponseWriter, req *http.Request, connect bool) (interface{}, error) {
metricsKey := "catalog_service_nodes"
pathPrefix := "/v1/catalog/service/"
if connect {
metricsKey = "catalog_connect_service_nodes"
pathPrefix = "/v1/catalog/connect/"
}
metrics.IncrCounterWithLabels([]string{"client", "api", metricsKey}, 1,
[]metrics.Label{{Name: "node", Value: s.nodeName()}})
// Set default DC
args := structs.ServiceSpecificRequest{}
args := structs.ServiceSpecificRequest{Connect: connect}
s.parseSource(req, &args.Source)
args.NodeMetaFilters = s.parseMetaFilter(req)
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
@ -177,7 +192,7 @@ func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Req
}
// Pull out the service name
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/catalog/service/")
args.ServiceName = strings.TrimPrefix(req.URL.Path, pathPrefix)
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/testutil/retry"
"github.com/hashicorp/serf/coordinate"
"github.com/stretchr/testify/assert"
)
func TestCatalogRegister_Service_InvalidAddress(t *testing.T) {
@ -750,6 +751,60 @@ func TestCatalogServiceNodes_DistanceSort(t *testing.T) {
}
}
// Test that connect proxies can be queried via /v1/catalog/service/:service
// directly and that their results contain the proxy fields.
func TestCatalogServiceNodes_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/service/%s", args.Service.Service), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1)
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
}
// Test that the Connect-compatible endpoints can be queried for a
// service via /v1/catalog/connect/:service.
func TestCatalogConnectServiceNodes_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
args.Service.Address = "127.0.0.55"
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/connect/%s", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
nodes := obj.(structs.ServiceNodes)
assert.Len(nodes, 1)
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
assert.Equal(args.Service.Address, nodes[0].ServiceAddress)
}
func TestCatalogNodeServices(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
@ -785,6 +840,33 @@ func TestCatalogNodeServices(t *testing.T) {
}
}
// Test that the services on a node contain all the Connect proxies on
// the node as well with their fields properly populated.
func TestCatalogNodeServices_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/catalog/node/%s", args.Node), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodeServices(resp, req)
assert.Nil(err)
assertIndex(t, resp)
ns := obj.(*structs.NodeServices)
assert.Len(ns.Services, 1)
v := ns.Services[args.Service.Service]
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
}
func TestCatalogNodeServices_WanTranslation(t *testing.T) {
t.Parallel()
a1 := NewTestAgent(t.Name(), `

View File

@ -14,9 +14,11 @@ import (
"strings"
"time"
"github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
multierror "github.com/hashicorp/go-multierror"
@ -340,6 +342,12 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
serverPort := b.portVal("ports.server", c.Ports.Server)
serfPortLAN := b.portVal("ports.serf_lan", c.Ports.SerfLAN)
serfPortWAN := b.portVal("ports.serf_wan", c.Ports.SerfWAN)
proxyMinPort := b.portVal("ports.proxy_min_port", c.Ports.ProxyMinPort)
proxyMaxPort := b.portVal("ports.proxy_max_port", c.Ports.ProxyMaxPort)
if proxyMaxPort < proxyMinPort {
return RuntimeConfig{}, fmt.Errorf(
"proxy_min_port must be less than proxy_max_port. To disable, set both to zero.")
}
// determine the default bind and advertise address
//
@ -520,6 +528,30 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
consulRaftHeartbeatTimeout := b.durationVal("consul.raft.heartbeat_timeout", c.Consul.Raft.HeartbeatTimeout) * time.Duration(performanceRaftMultiplier)
consulRaftLeaderLeaseTimeout := b.durationVal("consul.raft.leader_lease_timeout", c.Consul.Raft.LeaderLeaseTimeout) * time.Duration(performanceRaftMultiplier)
// Connect proxy defaults.
connectEnabled := b.boolVal(c.Connect.Enabled)
connectCAProvider := b.stringVal(c.Connect.CAProvider)
connectCAConfig := c.Connect.CAConfig
if connectCAConfig != nil {
TranslateKeys(connectCAConfig, map[string]string{
// Consul CA config
"private_key": "PrivateKey",
"root_cert": "RootCert",
"rotation_period": "RotationPeriod",
// Vault CA config
"address": "Address",
"token": "Token",
"root_pki_path": "RootPKIPath",
"intermediate_pki_path": "IntermediatePKIPath",
})
}
proxyDefaultExecMode := b.stringVal(c.Connect.ProxyDefaults.ExecMode)
proxyDefaultDaemonCommand := c.Connect.ProxyDefaults.DaemonCommand
proxyDefaultScriptCommand := c.Connect.ProxyDefaults.ScriptCommand
proxyDefaultConfig := c.Connect.ProxyDefaults.Config
// ----------------------------------------------------------------
// build runtime config
//
@ -603,29 +635,31 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
HTTPResponseHeaders: c.HTTPConfig.ResponseHeaders,
// Telemetry
TelemetryCirconusAPIApp: b.stringVal(c.Telemetry.CirconusAPIApp),
TelemetryCirconusAPIToken: b.stringVal(c.Telemetry.CirconusAPIToken),
TelemetryCirconusAPIURL: b.stringVal(c.Telemetry.CirconusAPIURL),
TelemetryCirconusBrokerID: b.stringVal(c.Telemetry.CirconusBrokerID),
TelemetryCirconusBrokerSelectTag: b.stringVal(c.Telemetry.CirconusBrokerSelectTag),
TelemetryCirconusCheckDisplayName: b.stringVal(c.Telemetry.CirconusCheckDisplayName),
TelemetryCirconusCheckForceMetricActivation: b.stringVal(c.Telemetry.CirconusCheckForceMetricActivation),
TelemetryCirconusCheckID: b.stringVal(c.Telemetry.CirconusCheckID),
TelemetryCirconusCheckInstanceID: b.stringVal(c.Telemetry.CirconusCheckInstanceID),
TelemetryCirconusCheckSearchTag: b.stringVal(c.Telemetry.CirconusCheckSearchTag),
TelemetryCirconusCheckTags: b.stringVal(c.Telemetry.CirconusCheckTags),
TelemetryCirconusSubmissionInterval: b.stringVal(c.Telemetry.CirconusSubmissionInterval),
TelemetryCirconusSubmissionURL: b.stringVal(c.Telemetry.CirconusSubmissionURL),
TelemetryDisableHostname: b.boolVal(c.Telemetry.DisableHostname),
TelemetryDogstatsdAddr: b.stringVal(c.Telemetry.DogstatsdAddr),
TelemetryDogstatsdTags: c.Telemetry.DogstatsdTags,
TelemetryPrometheusRetentionTime: b.durationVal("prometheus_retention_time", c.Telemetry.PrometheusRetentionTime),
TelemetryFilterDefault: b.boolVal(c.Telemetry.FilterDefault),
TelemetryAllowedPrefixes: telemetryAllowedPrefixes,
TelemetryBlockedPrefixes: telemetryBlockedPrefixes,
TelemetryMetricsPrefix: b.stringVal(c.Telemetry.MetricsPrefix),
TelemetryStatsdAddr: b.stringVal(c.Telemetry.StatsdAddr),
TelemetryStatsiteAddr: b.stringVal(c.Telemetry.StatsiteAddr),
Telemetry: lib.TelemetryConfig{
CirconusAPIApp: b.stringVal(c.Telemetry.CirconusAPIApp),
CirconusAPIToken: b.stringVal(c.Telemetry.CirconusAPIToken),
CirconusAPIURL: b.stringVal(c.Telemetry.CirconusAPIURL),
CirconusBrokerID: b.stringVal(c.Telemetry.CirconusBrokerID),
CirconusBrokerSelectTag: b.stringVal(c.Telemetry.CirconusBrokerSelectTag),
CirconusCheckDisplayName: b.stringVal(c.Telemetry.CirconusCheckDisplayName),
CirconusCheckForceMetricActivation: b.stringVal(c.Telemetry.CirconusCheckForceMetricActivation),
CirconusCheckID: b.stringVal(c.Telemetry.CirconusCheckID),
CirconusCheckInstanceID: b.stringVal(c.Telemetry.CirconusCheckInstanceID),
CirconusCheckSearchTag: b.stringVal(c.Telemetry.CirconusCheckSearchTag),
CirconusCheckTags: b.stringVal(c.Telemetry.CirconusCheckTags),
CirconusSubmissionInterval: b.stringVal(c.Telemetry.CirconusSubmissionInterval),
CirconusSubmissionURL: b.stringVal(c.Telemetry.CirconusSubmissionURL),
DisableHostname: b.boolVal(c.Telemetry.DisableHostname),
DogstatsdAddr: b.stringVal(c.Telemetry.DogstatsdAddr),
DogstatsdTags: c.Telemetry.DogstatsdTags,
PrometheusRetentionTime: b.durationVal("prometheus_retention_time", c.Telemetry.PrometheusRetentionTime),
FilterDefault: b.boolVal(c.Telemetry.FilterDefault),
AllowedPrefixes: telemetryAllowedPrefixes,
BlockedPrefixes: telemetryBlockedPrefixes,
MetricsPrefix: b.stringVal(c.Telemetry.MetricsPrefix),
StatsdAddr: b.stringVal(c.Telemetry.StatsdAddr),
StatsiteAddr: b.stringVal(c.Telemetry.StatsiteAddr),
},
// Agent
AdvertiseAddrLAN: advertiseAddrLAN,
@ -639,6 +673,17 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
CheckUpdateInterval: b.durationVal("check_update_interval", c.CheckUpdateInterval),
Checks: checks,
ClientAddrs: clientAddrs,
ConnectEnabled: connectEnabled,
ConnectCAProvider: connectCAProvider,
ConnectCAConfig: connectCAConfig,
ConnectProxyAllowManagedRoot: b.boolVal(c.Connect.Proxy.AllowManagedRoot),
ConnectProxyAllowManagedAPIRegistration: b.boolVal(c.Connect.Proxy.AllowManagedAPIRegistration),
ConnectProxyBindMinPort: proxyMinPort,
ConnectProxyBindMaxPort: proxyMaxPort,
ConnectProxyDefaultExecMode: proxyDefaultExecMode,
ConnectProxyDefaultDaemonCommand: proxyDefaultDaemonCommand,
ConnectProxyDefaultScriptCommand: proxyDefaultScriptCommand,
ConnectProxyDefaultConfig: proxyDefaultConfig,
DataDir: b.stringVal(c.DataDir),
Datacenter: strings.ToLower(b.stringVal(c.Datacenter)),
DevMode: b.boolVal(b.Flags.DevMode),
@ -882,6 +927,34 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
return b.err
}
// Check for errors in the service definitions
for _, s := range rt.Services {
if err := s.Validate(); err != nil {
return fmt.Errorf("service %q: %s", s.Name, err)
}
}
// Validate the given Connect CA provider config
validCAProviders := map[string]bool{
"": true,
structs.ConsulCAProvider: true,
structs.VaultCAProvider: true,
}
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)
} else {
switch rt.ConnectCAProvider {
case structs.ConsulCAProvider:
if _, err := ca.ParseConsulCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
case structs.VaultCAProvider:
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
return err
}
}
}
// ----------------------------------------------------------------
// warnings
//
@ -1008,6 +1081,26 @@ func (b *Builder) serviceVal(v *ServiceDefinition) *structs.ServiceDefinition {
Token: b.stringVal(v.Token),
EnableTagOverride: b.boolVal(v.EnableTagOverride),
Checks: checks,
Connect: b.serviceConnectVal(v.Connect),
}
}
func (b *Builder) serviceConnectVal(v *ServiceConnect) *structs.ServiceConnect {
if v == nil {
return nil
}
var proxy *structs.ServiceDefinitionConnectProxy
if v.Proxy != nil {
proxy = &structs.ServiceDefinitionConnectProxy{
ExecMode: b.stringVal(v.Proxy.ExecMode),
Command: v.Proxy.Command,
Config: v.Proxy.Config,
}
}
return &structs.ServiceConnect{
Proxy: proxy,
}
}

View File

@ -84,6 +84,7 @@ func Parse(data string, format string) (c Config, err error) {
"services",
"services.checks",
"watches",
"service.connect.proxy.config.upstreams",
})
// There is a difference of representation of some fields depending on
@ -159,6 +160,7 @@ type Config struct {
CheckUpdateInterval *string `json:"check_update_interval,omitempty" hcl:"check_update_interval" mapstructure:"check_update_interval"`
Checks []CheckDefinition `json:"checks,omitempty" hcl:"checks" mapstructure:"checks"`
ClientAddr *string `json:"client_addr,omitempty" hcl:"client_addr" mapstructure:"client_addr"`
Connect Connect `json:"connect,omitempty" hcl:"connect" mapstructure:"connect"`
DNS DNS `json:"dns_config,omitempty" hcl:"dns_config" mapstructure:"dns_config"`
DNSDomain *string `json:"domain,omitempty" hcl:"domain" mapstructure:"domain"`
DNSRecursors []string `json:"recursors,omitempty" hcl:"recursors" mapstructure:"recursors"`
@ -324,6 +326,7 @@ type ServiceDefinition struct {
Checks []CheckDefinition `json:"checks,omitempty" hcl:"checks" mapstructure:"checks"`
Token *string `json:"token,omitempty" hcl:"token" mapstructure:"token"`
EnableTagOverride *bool `json:"enable_tag_override,omitempty" hcl:"enable_tag_override" mapstructure:"enable_tag_override"`
Connect *ServiceConnect `json:"connect,omitempty" hcl:"connect" mapstructure:"connect"`
}
type CheckDefinition struct {
@ -349,6 +352,58 @@ type CheckDefinition struct {
DeregisterCriticalServiceAfter *string `json:"deregister_critical_service_after,omitempty" hcl:"deregister_critical_service_after" mapstructure:"deregister_critical_service_after"`
}
// ServiceConnect is the connect block within a service registration
type ServiceConnect struct {
// TODO(banks) add way to specify that the app is connect-native
// Proxy configures a connect proxy instance for the service
Proxy *ServiceConnectProxy `json:"proxy,omitempty" hcl:"proxy" mapstructure:"proxy"`
}
type ServiceConnectProxy struct {
Command []string `json:"command,omitempty" hcl:"command" mapstructure:"command"`
ExecMode *string `json:"exec_mode,omitempty" hcl:"exec_mode" mapstructure:"exec_mode"`
Config map[string]interface{} `json:"config,omitempty" hcl:"config" mapstructure:"config"`
}
// Connect is the agent-global connect configuration.
type Connect struct {
// Enabled opts the agent into connect. It should be set on all clients and
// servers in a cluster for correct connect operation.
Enabled *bool `json:"enabled,omitempty" hcl:"enabled" mapstructure:"enabled"`
Proxy ConnectProxy `json:"proxy,omitempty" hcl:"proxy" mapstructure:"proxy"`
ProxyDefaults ConnectProxyDefaults `json:"proxy_defaults,omitempty" hcl:"proxy_defaults" mapstructure:"proxy_defaults"`
CAProvider *string `json:"ca_provider,omitempty" hcl:"ca_provider" mapstructure:"ca_provider"`
CAConfig map[string]interface{} `json:"ca_config,omitempty" hcl:"ca_config" mapstructure:"ca_config"`
}
// ConnectProxy is the agent-global connect proxy configuration.
type ConnectProxy struct {
// Consul will not execute managed proxies if its EUID is 0 (root).
// If this is true, then Consul will execute proxies if Consul is
// running as root. This is not recommended.
AllowManagedRoot *bool `json:"allow_managed_root" hcl:"allow_managed_root" mapstructure:"allow_managed_root"`
// AllowManagedAPIRegistration enables managed proxy registration
// via the agent HTTP API. If this is false, only file configurations
// can be used.
AllowManagedAPIRegistration *bool `json:"allow_managed_api_registration" hcl:"allow_managed_api_registration" mapstructure:"allow_managed_api_registration"`
}
// ConnectProxyDefaults is the agent-global defaults for managed Connect proxies.
type ConnectProxyDefaults struct {
// ExecMode is used where a registration doesn't include an exec_mode.
// Defaults to daemon.
ExecMode *string `json:"exec_mode,omitempty" hcl:"exec_mode" mapstructure:"exec_mode"`
// DaemonCommand is used to start proxy in exec_mode = daemon if not specified
// at registration time.
DaemonCommand []string `json:"daemon_command,omitempty" hcl:"daemon_command" mapstructure:"daemon_command"`
// ScriptCommand is used to start proxy in exec_mode = script if not specified
// at registration time.
ScriptCommand []string `json:"script_command,omitempty" hcl:"script_command" mapstructure:"script_command"`
// Config is merged into an Config specified at registration time.
Config map[string]interface{} `json:"config,omitempty" hcl:"config" mapstructure:"config"`
}
type DNS struct {
AllowStale *bool `json:"allow_stale,omitempty" hcl:"allow_stale" mapstructure:"allow_stale"`
ARecordLimit *int `json:"a_record_limit,omitempty" hcl:"a_record_limit" mapstructure:"a_record_limit"`
@ -406,6 +461,8 @@ type Ports struct {
SerfLAN *int `json:"serf_lan,omitempty" hcl:"serf_lan" mapstructure:"serf_lan"`
SerfWAN *int `json:"serf_wan,omitempty" hcl:"serf_wan" mapstructure:"serf_wan"`
Server *int `json:"server,omitempty" hcl:"server" mapstructure:"server"`
ProxyMinPort *int `json:"proxy_min_port,omitempty" hcl:"proxy_min_port" mapstructure:"proxy_min_port"`
ProxyMaxPort *int `json:"proxy_max_port,omitempty" hcl:"proxy_max_port" mapstructure:"proxy_max_port"`
}
type UnixSocket struct {

View File

@ -85,6 +85,8 @@ func DefaultSource() Source {
serf_lan = ` + strconv.Itoa(consul.DefaultLANSerfPort) + `
serf_wan = ` + strconv.Itoa(consul.DefaultWANSerfPort) + `
server = ` + strconv.Itoa(consul.DefaultRPCPort) + `
proxy_min_port = 20000
proxy_max_port = 20255
}
telemetry = {
metrics_prefix = "consul"
@ -108,6 +110,10 @@ func DevSource() Source {
ui = true
log_level = "DEBUG"
server = true
connect = {
enabled = true
}
performance = {
raft_multiplier = 1
}

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
"golang.org/x/time/rate"
@ -304,177 +305,8 @@ type RuntimeConfig struct {
// hcl: http_config { response_headers = map[string]string }
HTTPResponseHeaders map[string]string
// TelemetryCirconus*: see https://github.com/circonus-labs/circonus-gometrics
// for more details on the various configuration options.
// Valid configuration combinations:
// - CirconusAPIToken
// metric management enabled (search for existing check or create a new one)
// - CirconusSubmissionUrl
// metric management disabled (use check with specified submission_url,
// broker must be using a public SSL certificate)
// - CirconusAPIToken + CirconusCheckSubmissionURL
// metric management enabled (use check with specified submission_url)
// - CirconusAPIToken + CirconusCheckID
// metric management enabled (use check with specified id)
// TelemetryCirconusAPIApp is an app name associated with API token.
// Default: "consul"
//
// hcl: telemetry { circonus_api_app = string }
TelemetryCirconusAPIApp string
// TelemetryCirconusAPIToken is a valid API Token used to create/manage check. If provided,
// metric management is enabled.
// Default: none
//
// hcl: telemetry { circonus_api_token = string }
TelemetryCirconusAPIToken string
// TelemetryCirconusAPIURL is the base URL to use for contacting the Circonus API.
// Default: "https://api.circonus.com/v2"
//
// hcl: telemetry { circonus_api_url = string }
TelemetryCirconusAPIURL string
// TelemetryCirconusBrokerID is an explicit broker to use when creating a new check. The numeric portion
// of broker._cid. If metric management is enabled and neither a Submission URL nor Check ID
// is provided, an attempt will be made to search for an existing check using Instance ID and
// Search Tag. If one is not found, a new HTTPTRAP check will be created.
// Default: use Select Tag if provided, otherwise, a random Enterprise Broker associated
// with the specified API token or the default Circonus Broker.
// Default: none
//
// hcl: telemetry { circonus_broker_id = string }
TelemetryCirconusBrokerID string
// TelemetryCirconusBrokerSelectTag is a special tag which will be used to select a broker when
// a Broker ID is not provided. The best use of this is to as a hint for which broker
// should be used based on *where* this particular instance is running.
// (e.g. a specific geo location or datacenter, dc:sfo)
// Default: none
//
// hcl: telemetry { circonus_broker_select_tag = string }
TelemetryCirconusBrokerSelectTag string
// TelemetryCirconusCheckDisplayName is the name for the check which will be displayed in the Circonus UI.
// Default: value of CirconusCheckInstanceID
//
// hcl: telemetry { circonus_check_display_name = string }
TelemetryCirconusCheckDisplayName string
// TelemetryCirconusCheckForceMetricActivation will force enabling metrics, as they are encountered,
// if the metric already exists and is NOT active. If check management is enabled, the default
// behavior is to add new metrics as they are encountered. If the metric already exists in the
// check, it will *NOT* be activated. This setting overrides that behavior.
// Default: "false"
//
// hcl: telemetry { circonus_check_metrics_activation = (true|false)
TelemetryCirconusCheckForceMetricActivation string
// TelemetryCirconusCheckID is the check id (not check bundle id) from a previously created
// HTTPTRAP check. The numeric portion of the check._cid field.
// Default: none
//
// hcl: telemetry { circonus_check_id = string }
TelemetryCirconusCheckID string
// TelemetryCirconusCheckInstanceID serves to uniquely identify the metrics coming from this "instance".
// It can be used to maintain metric continuity with transient or ephemeral instances as
// they move around within an infrastructure.
// Default: hostname:app
//
// hcl: telemetry { circonus_check_instance_id = string }
TelemetryCirconusCheckInstanceID string
// TelemetryCirconusCheckSearchTag is a special tag which, when coupled with the instance id, helps to
// narrow down the search results when neither a Submission URL or Check ID is provided.
// Default: service:app (e.g. service:consul)
//
// hcl: telemetry { circonus_check_search_tag = string }
TelemetryCirconusCheckSearchTag string
// TelemetryCirconusCheckSearchTag is a special tag which, when coupled with the instance id, helps to
// narrow down the search results when neither a Submission URL or Check ID is provided.
// Default: service:app (e.g. service:consul)
//
// hcl: telemetry { circonus_check_tags = string }
TelemetryCirconusCheckTags string
// TelemetryCirconusSubmissionInterval is the interval at which metrics are submitted to Circonus.
// Default: 10s
//
// hcl: telemetry { circonus_submission_interval = "duration" }
TelemetryCirconusSubmissionInterval string
// TelemetryCirconusCheckSubmissionURL is the check.config.submission_url field from a
// previously created HTTPTRAP check.
// Default: none
//
// hcl: telemetry { circonus_submission_url = string }
TelemetryCirconusSubmissionURL string
// DisableHostname will disable hostname prefixing for all metrics.
//
// hcl: telemetry { disable_hostname = (true|false)
TelemetryDisableHostname bool
// TelemetryDogStatsdAddr is the address of a dogstatsd instance. If provided,
// metrics will be sent to that instance
//
// hcl: telemetry { dogstatsd_addr = string }
TelemetryDogstatsdAddr string
// TelemetryDogStatsdTags are the global tags that should be sent with each packet to dogstatsd
// It is a list of strings, where each string looks like "my_tag_name:my_tag_value"
//
// hcl: telemetry { dogstatsd_tags = []string }
TelemetryDogstatsdTags []string
// PrometheusRetentionTime is the retention time for prometheus metrics if greater than 0.
// A value of 0 disable Prometheus support. Regarding Prometheus, it is considered a good
// practice to put large values here (such as a few days), and at least the interval between
// prometheus requests.
//
// hcl: telemetry { prometheus_retention_time = "duration" }
TelemetryPrometheusRetentionTime time.Duration
// TelemetryFilterDefault is the default for whether to allow a metric that's not
// covered by the filter.
//
// hcl: telemetry { filter_default = (true|false) }
TelemetryFilterDefault bool
// TelemetryAllowedPrefixes is a list of filter rules to apply for allowing metrics
// by prefix. Use the 'prefix_filter' option and prefix rules with '+' to be
// included.
//
// hcl: telemetry { prefix_filter = []string{"+<expr>", "+<expr>", ...} }
TelemetryAllowedPrefixes []string
// TelemetryBlockedPrefixes is a list of filter rules to apply for blocking metrics
// by prefix. Use the 'prefix_filter' option and prefix rules with '-' to be
// excluded.
//
// hcl: telemetry { prefix_filter = []string{"-<expr>", "-<expr>", ...} }
TelemetryBlockedPrefixes []string
// TelemetryMetricsPrefix is the prefix used to write stats values to.
// Default: "consul."
//
// hcl: telemetry { metrics_prefix = string }
TelemetryMetricsPrefix string
// TelemetryStatsdAddr is the address of a statsd instance. If provided,
// metrics will be sent to that instance.
//
// hcl: telemetry { statsd_addr = string }
TelemetryStatsdAddr string
// TelemetryStatsiteAddr is the address of a statsite instance. If provided,
// metrics will be streamed to that instance.
//
// hcl: telemetry { statsite_addr = string }
TelemetryStatsiteAddr string
// Embed Telemetry Config
Telemetry lib.TelemetryConfig
// Datacenter is the datacenter this node is in. Defaults to "dc1".
//
@ -621,6 +453,61 @@ type RuntimeConfig struct {
// flag: -client string
ClientAddrs []*net.IPAddr
// ConnectEnabled opts the agent into connect. It should be set on all clients
// and servers in a cluster for correct connect operation.
ConnectEnabled bool
// ConnectProxyBindMinPort is the inclusive start of the range of ports
// allocated to the agent for starting proxy listeners on where no explicit
// port is specified.
ConnectProxyBindMinPort int
// ConnectProxyBindMaxPort is the inclusive end of the range of ports
// allocated to the agent for starting proxy listeners on where no explicit
// port is specified.
ConnectProxyBindMaxPort int
// ConnectProxyAllowManagedRoot is true if Consul can execute managed
// proxies when running as root (EUID == 0).
ConnectProxyAllowManagedRoot bool
// ConnectProxyAllowManagedAPIRegistration enables managed proxy registration
// via the agent HTTP API. If this is false, only file configurations
// can be used.
ConnectProxyAllowManagedAPIRegistration bool
// ConnectProxyDefaultExecMode is used where a registration doesn't include an
// exec_mode. Defaults to daemon.
ConnectProxyDefaultExecMode string
// ConnectProxyDefaultDaemonCommand is used to start proxy in exec_mode =
// daemon if not specified at registration time.
ConnectProxyDefaultDaemonCommand []string
// ConnectProxyDefaultScriptCommand is used to start proxy in exec_mode =
// script if not specified at registration time.
ConnectProxyDefaultScriptCommand []string
// ConnectProxyDefaultConfig is merged with any config specified at
// registration time to allow global control of defaults.
ConnectProxyDefaultConfig map[string]interface{}
// ConnectCAProvider is the type of CA provider to use with Connect.
ConnectCAProvider string
// ConnectCAConfig is the config to use for the CA provider.
ConnectCAConfig map[string]interface{}
// ConnectTestDisableManagedProxies is not exposed to public config but us
// used by TestAgent to prevent self-executing the test binary in the
// background if a managed proxy is created for a test. The only place we
// actually want to test processes really being spun up and managed is in
// `agent/proxy` and it does it at a lower level. Note that this still allows
// registering managed proxies via API and other methods, and still creates
// all the agent state for them, just doesn't actually start external
// processes up.
ConnectTestDisableManagedProxies bool
// DNSAddrs contains the list of TCP and UDP addresses the DNS server will
// bind to. If the DNS endpoint is disabled (ports.dns <= 0) the list is
// empty.

View File

@ -19,10 +19,11 @@ import (
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/testutil"
"github.com/hashicorp/consul/types"
"github.com/pascaldekloe/goe/verify"
"github.com/sergi/go-diff/diffmatchpatch"
"github.com/stretchr/testify/require"
)
type configTest struct {
@ -31,6 +32,7 @@ type configTest struct {
pre, post func()
json, jsontail []string
hcl, hcltail []string
skipformat bool
privatev4 func() ([]*net.IPAddr, error)
publicv6 func() ([]*net.IPAddr, error)
patch func(rt *RuntimeConfig)
@ -263,6 +265,7 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
rt.AdvertiseAddrLAN = ipAddr("127.0.0.1")
rt.AdvertiseAddrWAN = ipAddr("127.0.0.1")
rt.BindAddr = ipAddr("127.0.0.1")
rt.ConnectEnabled = true
rt.DevMode = true
rt.DisableAnonymousSignature = true
rt.DisableKeyringFile = true
@ -1850,8 +1853,8 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
`},
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.TelemetryAllowedPrefixes = []string{"foo"}
rt.TelemetryBlockedPrefixes = []string{"bar"}
rt.Telemetry.AllowedPrefixes = []string{"foo"}
rt.Telemetry.BlockedPrefixes = []string{"bar"}
},
warns: []string{`Filter rule must begin with either '+' or '-': "nix"`},
},
@ -2068,6 +2071,127 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
rt.DataDir = dataDir
},
},
{
desc: "HCL service managed proxy 'upstreams'",
args: []string{
`-data-dir=` + dataDir,
},
hcl: []string{
`service {
name = "web"
port = 8080
connect {
proxy {
config {
upstreams {
local_bind_port = 1234
}
}
}
}
}`,
},
skipformat: true, // skipping JSON cause we get slightly diff types (okay)
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.Services = []*structs.ServiceDefinition{
&structs.ServiceDefinition{
Name: "web",
Port: 8080,
Connect: &structs.ServiceConnect{
Proxy: &structs.ServiceDefinitionConnectProxy{
Config: map[string]interface{}{
"upstreams": []map[string]interface{}{
map[string]interface{}{
"local_bind_port": 1234,
},
},
},
},
},
},
}
},
},
{
desc: "JSON service managed proxy 'upstreams'",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{
`{
"service": {
"name": "web",
"port": 8080,
"connect": {
"proxy": {
"config": {
"upstreams": [{
"local_bind_port": 1234
}]
}
}
}
}
}`,
},
skipformat: true, // skipping HCL cause we get slightly diff types (okay)
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.Services = []*structs.ServiceDefinition{
&structs.ServiceDefinition{
Name: "web",
Port: 8080,
Connect: &structs.ServiceConnect{
Proxy: &structs.ServiceDefinitionConnectProxy{
Config: map[string]interface{}{
"upstreams": []interface{}{
map[string]interface{}{
"local_bind_port": float64(1234),
},
},
},
},
},
},
}
},
},
{
desc: "enabling Connect allow_managed_root",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{
`{ "connect": { "proxy": { "allow_managed_root": true } } }`,
},
hcl: []string{
`connect { proxy { allow_managed_root = true } }`,
},
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.ConnectProxyAllowManagedRoot = true
},
},
{
desc: "enabling Connect allow_managed_api_registration",
args: []string{
`-data-dir=` + dataDir,
},
json: []string{
`{ "connect": { "proxy": { "allow_managed_api_registration": true } } }`,
},
hcl: []string{
`connect { proxy { allow_managed_api_registration = true } }`,
},
patch: func(rt *RuntimeConfig) {
rt.DataDir = dataDir
rt.ConnectProxyAllowManagedAPIRegistration = true
},
},
}
testConfig(t, tests, dataDir)
@ -2089,7 +2213,7 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
// json and hcl sources need to be in sync
// to make sure we're generating the same config
if len(tt.json) != len(tt.hcl) {
if len(tt.json) != len(tt.hcl) && !tt.skipformat {
t.Fatal(tt.desc, ": JSON and HCL test case out of sync")
}
@ -2099,6 +2223,12 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
srcs, tails = tt.hcl, tt.hcltail
}
// If we're skipping a format and the current format is empty,
// then skip it!
if tt.skipformat && len(srcs) == 0 {
continue
}
// build the description
var desc []string
if !flagsOnly {
@ -2353,6 +2483,23 @@ func TestFullConfig(t *testing.T) {
],
"check_update_interval": "16507s",
"client_addr": "93.83.18.19",
"connect": {
"ca_provider": "consul",
"ca_config": {
"RotationPeriod": "90h"
},
"enabled": true,
"proxy_defaults": {
"exec_mode": "script",
"daemon_command": ["consul", "connect", "proxy"],
"script_command": ["proxyctl.sh"],
"config": {
"foo": "bar",
"connect_timeout_ms": 1000,
"pedantic_mode": true
}
}
},
"data_dir": "` + dataDir + `",
"datacenter": "rzo029wg",
"disable_anonymous_signature": true,
@ -2417,7 +2564,9 @@ func TestFullConfig(t *testing.T) {
"dns": 7001,
"http": 7999,
"https": 15127,
"server": 3757
"server": 3757,
"proxy_min_port": 2000,
"proxy_max_port": 3000
},
"protocol": 30793,
"raft_protocol": 19016,
@ -2613,7 +2762,16 @@ func TestFullConfig(t *testing.T) {
"ttl": "11222s",
"deregister_critical_service_after": "68482s"
}
]
],
"connect": {
"proxy": {
"exec_mode": "daemon",
"command": ["awesome-proxy"],
"config": {
"foo": "qux"
}
}
}
}
],
"session_ttl_min": "26627s",
@ -2786,6 +2944,25 @@ func TestFullConfig(t *testing.T) {
]
check_update_interval = "16507s"
client_addr = "93.83.18.19"
connect {
ca_provider = "consul"
ca_config {
"RotationPeriod" = "90h"
}
enabled = true
proxy_defaults {
exec_mode = "script"
daemon_command = ["consul", "connect", "proxy"]
script_command = ["proxyctl.sh"]
config = {
foo = "bar"
# hack float since json parses numbers as float and we have to
# assert against the same thing
connect_timeout_ms = 1000.0
pedantic_mode = true
}
}
}
data_dir = "` + dataDir + `"
datacenter = "rzo029wg"
disable_anonymous_signature = true
@ -2851,6 +3028,8 @@ func TestFullConfig(t *testing.T) {
http = 7999,
https = 15127
server = 3757
proxy_min_port = 2000
proxy_max_port = 3000
}
protocol = 30793
raft_protocol = 19016
@ -3047,6 +3226,15 @@ func TestFullConfig(t *testing.T) {
deregister_critical_service_after = "68482s"
}
]
connect {
proxy {
exec_mode = "daemon"
command = ["awesome-proxy"]
config = {
foo = "qux"
}
}
}
}
]
session_ttl_min = "26627s"
@ -3357,6 +3545,23 @@ func TestFullConfig(t *testing.T) {
},
CheckUpdateInterval: 16507 * time.Second,
ClientAddrs: []*net.IPAddr{ipAddr("93.83.18.19")},
ConnectEnabled: true,
ConnectProxyBindMinPort: 2000,
ConnectProxyBindMaxPort: 3000,
ConnectCAProvider: "consul",
ConnectCAConfig: map[string]interface{}{
"RotationPeriod": "90h",
},
ConnectProxyAllowManagedRoot: false,
ConnectProxyAllowManagedAPIRegistration: false,
ConnectProxyDefaultExecMode: "script",
ConnectProxyDefaultDaemonCommand: []string{"consul", "connect", "proxy"},
ConnectProxyDefaultScriptCommand: []string{"proxyctl.sh"},
ConnectProxyDefaultConfig: map[string]interface{}{
"foo": "bar",
"connect_timeout_ms": float64(1000),
"pedantic_mode": true,
},
DNSAddrs: []net.Addr{tcpAddr("93.95.95.81:7001"), udpAddr("93.95.95.81:7001")},
DNSARecordLimit: 29907,
DNSAllowStale: true,
@ -3530,6 +3735,15 @@ func TestFullConfig(t *testing.T) {
DeregisterCriticalServiceAfter: 68482 * time.Second,
},
},
Connect: &structs.ServiceConnect{
Proxy: &structs.ServiceDefinitionConnectProxy{
ExecMode: "daemon",
Command: []string{"awesome-proxy"},
Config: map[string]interface{}{
"foo": "qux",
},
},
},
},
{
ID: "dLOXpSCI",
@ -3616,29 +3830,31 @@ func TestFullConfig(t *testing.T) {
StartJoinAddrsLAN: []string{"LR3hGDoG", "MwVpZ4Up"},
StartJoinAddrsWAN: []string{"EbFSc3nA", "kwXTh623"},
SyslogFacility: "hHv79Uia",
TelemetryCirconusAPIApp: "p4QOTe9j",
TelemetryCirconusAPIToken: "E3j35V23",
TelemetryCirconusAPIURL: "mEMjHpGg",
TelemetryCirconusBrokerID: "BHlxUhed",
TelemetryCirconusBrokerSelectTag: "13xy1gHm",
TelemetryCirconusCheckDisplayName: "DRSlQR6n",
TelemetryCirconusCheckForceMetricActivation: "Ua5FGVYf",
TelemetryCirconusCheckID: "kGorutad",
TelemetryCirconusCheckInstanceID: "rwoOL6R4",
TelemetryCirconusCheckSearchTag: "ovT4hT4f",
TelemetryCirconusCheckTags: "prvO4uBl",
TelemetryCirconusSubmissionInterval: "DolzaflP",
TelemetryCirconusSubmissionURL: "gTcbS93G",
TelemetryDisableHostname: true,
TelemetryDogstatsdAddr: "0wSndumK",
TelemetryDogstatsdTags: []string{"3N81zSUB", "Xtj8AnXZ"},
TelemetryFilterDefault: true,
TelemetryAllowedPrefixes: []string{"oJotS8XJ"},
TelemetryBlockedPrefixes: []string{"cazlEhGn"},
TelemetryMetricsPrefix: "ftO6DySn",
TelemetryPrometheusRetentionTime: 15 * time.Second,
TelemetryStatsdAddr: "drce87cy",
TelemetryStatsiteAddr: "HpFwKB8R",
Telemetry: lib.TelemetryConfig{
CirconusAPIApp: "p4QOTe9j",
CirconusAPIToken: "E3j35V23",
CirconusAPIURL: "mEMjHpGg",
CirconusBrokerID: "BHlxUhed",
CirconusBrokerSelectTag: "13xy1gHm",
CirconusCheckDisplayName: "DRSlQR6n",
CirconusCheckForceMetricActivation: "Ua5FGVYf",
CirconusCheckID: "kGorutad",
CirconusCheckInstanceID: "rwoOL6R4",
CirconusCheckSearchTag: "ovT4hT4f",
CirconusCheckTags: "prvO4uBl",
CirconusSubmissionInterval: "DolzaflP",
CirconusSubmissionURL: "gTcbS93G",
DisableHostname: true,
DogstatsdAddr: "0wSndumK",
DogstatsdTags: []string{"3N81zSUB", "Xtj8AnXZ"},
FilterDefault: true,
AllowedPrefixes: []string{"oJotS8XJ"},
BlockedPrefixes: []string{"cazlEhGn"},
MetricsPrefix: "ftO6DySn",
PrometheusRetentionTime: 15 * time.Second,
StatsdAddr: "drce87cy",
StatsiteAddr: "HpFwKB8R",
},
TLSCipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
TLSMinVersion: "pAOWafkR",
TLSPreferServerCipherSuites: true,
@ -4019,6 +4235,18 @@ func TestSanitize(t *testing.T) {
}
],
"ClientAddrs": [],
"ConnectCAConfig": {},
"ConnectCAProvider": "",
"ConnectEnabled": false,
"ConnectProxyAllowManagedAPIRegistration": false,
"ConnectProxyAllowManagedRoot": false,
"ConnectProxyBindMaxPort": 0,
"ConnectProxyBindMinPort": 0,
"ConnectProxyDefaultConfig": {},
"ConnectProxyDefaultDaemonCommand": [],
"ConnectProxyDefaultExecMode": "",
"ConnectProxyDefaultScriptCommand": [],
"ConnectTestDisableManagedProxies": false,
"ConsulCoordinateUpdateBatchSize": 0,
"ConsulCoordinateUpdateMaxBatches": 0,
"ConsulCoordinateUpdatePeriod": "15s",
@ -4150,11 +4378,14 @@ func TestSanitize(t *testing.T) {
"Timeout": "0s"
},
"Checks": [],
"Connect": null,
"EnableTagOverride": false,
"ID": "",
"Kind": "",
"Meta": {},
"Name": "foo",
"Port": 0,
"ProxyDestination": "",
"Tags": [],
"Token": "hidden"
}
@ -4170,29 +4401,31 @@ func TestSanitize(t *testing.T) {
"TLSMinVersion": "",
"TLSPreferServerCipherSuites": false,
"TaggedAddresses": {},
"TelemetryAllowedPrefixes": [],
"TelemetryBlockedPrefixes": [],
"TelemetryCirconusAPIApp": "",
"TelemetryCirconusAPIToken": "hidden",
"TelemetryCirconusAPIURL": "",
"TelemetryCirconusBrokerID": "",
"TelemetryCirconusBrokerSelectTag": "",
"TelemetryCirconusCheckDisplayName": "",
"TelemetryCirconusCheckForceMetricActivation": "",
"TelemetryCirconusCheckID": "",
"TelemetryCirconusCheckInstanceID": "",
"TelemetryCirconusCheckSearchTag": "",
"TelemetryCirconusCheckTags": "",
"TelemetryCirconusSubmissionInterval": "",
"TelemetryCirconusSubmissionURL": "",
"TelemetryDisableHostname": false,
"TelemetryDogstatsdAddr": "",
"TelemetryDogstatsdTags": [],
"TelemetryFilterDefault": false,
"TelemetryMetricsPrefix": "",
"TelemetryPrometheusRetentionTime": "0s",
"TelemetryStatsdAddr": "",
"TelemetryStatsiteAddr": "",
"Telemetry":{
"AllowedPrefixes": [],
"BlockedPrefixes": [],
"CirconusAPIApp": "",
"CirconusAPIToken": "hidden",
"CirconusAPIURL": "",
"CirconusBrokerID": "",
"CirconusBrokerSelectTag": "",
"CirconusCheckDisplayName": "",
"CirconusCheckForceMetricActivation": "",
"CirconusCheckID": "",
"CirconusCheckInstanceID": "",
"CirconusCheckSearchTag": "",
"CirconusCheckTags": "",
"CirconusSubmissionInterval": "",
"CirconusSubmissionURL": "",
"DisableHostname": false,
"DogstatsdAddr": "",
"DogstatsdTags": [],
"FilterDefault": false,
"MetricsPrefix": "",
"PrometheusRetentionTime": "0s",
"StatsdAddr": "",
"StatsiteAddr": ""
},
"TranslateWANAddrs": false,
"UIDir": "",
"UnixSocketGroup": "",
@ -4212,11 +4445,7 @@ func TestSanitize(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if got, want := string(b), rtJSON; got != want {
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(want, got, false)
t.Fatal(dmp.DiffPrettyText(diffs))
}
require.JSONEq(t, rtJSON, string(b))
}
func splitIPPort(hostport string) (net.IP, int) {

View File

@ -0,0 +1,49 @@
package ca
import (
"crypto/x509"
)
// Provider is the interface for Consul to interact with
// an external CA that provides leaf certificate signing for
// given SpiffeIDServices.
type Provider interface {
// Active root returns the currently active root CA for this
// provider. This should be a parent of the certificate returned by
// ActiveIntermediate()
ActiveRoot() (string, error)
// ActiveIntermediate returns the current signing cert used by this provider
// for generating SPIFFE leaf certs. Note that this must not change except
// when Consul requests the change via GenerateIntermediate. Changing the
// signing cert will break Consul's assumptions about which validation paths
// are active.
ActiveIntermediate() (string, error)
// GenerateIntermediate returns a new intermediate signing cert and sets it to
// the active intermediate. If multiple intermediates are needed to complete
// the chain from the signing certificate back to the active root, they should
// all by bundled here.
GenerateIntermediate() (string, error)
// Sign signs a leaf certificate used by Connect proxies from a CSR. The PEM
// returned should include only the leaf certificate as all Intermediates
// needed to validate it will be added by Consul based on the active
// intemediate and any cross-signed intermediates managed by Consul.
Sign(*x509.CertificateRequest) (string, error)
// CrossSignCA must accept a CA certificate from another CA provider
// and cross sign it exactly as it is such that it forms a chain back the the
// CAProvider's current root. Specifically, the Distinguished Name, Subject
// Alternative Name, SubjectKeyID and other relevant extensions must be kept.
// The resulting certificate must have a distinct Serial Number and the
// AuthorityKeyID set to the CAProvider's current signing key as well as the
// Issuer related fields changed as necessary. The resulting certificate is
// returned as a PEM formatted string.
CrossSignCA(*x509.Certificate) (string, error)
// Cleanup performs any necessary cleanup that should happen when the provider
// is shut down permanently, such as removing a temporary PKI backend in Vault
// created for an intermediate CA.
Cleanup() error
}

View File

@ -0,0 +1,379 @@
package ca
import (
"bytes"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/url"
"sync"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
)
type ConsulProvider struct {
config *structs.ConsulCAProviderConfig
id string
delegate ConsulProviderStateDelegate
sync.RWMutex
}
type ConsulProviderStateDelegate interface {
State() *state.Store
ApplyCARequest(*structs.CARequest) error
}
// NewConsulProvider returns a new instance of the Consul CA provider,
// bootstrapping its state in the state store necessary
func NewConsulProvider(rawConfig map[string]interface{}, delegate ConsulProviderStateDelegate) (*ConsulProvider, error) {
conf, err := ParseConsulCAConfig(rawConfig)
if err != nil {
return nil, err
}
provider := &ConsulProvider{
config: conf,
delegate: delegate,
id: fmt.Sprintf("%s,%s", conf.PrivateKey, conf.RootCert),
}
// Check if this configuration of the provider has already been
// initialized in the state store.
state := delegate.State()
_, providerState, err := state.CAProviderState(provider.id)
if err != nil {
return nil, err
}
// Exit early if the state store has already been populated for this config.
if providerState != nil {
return provider, nil
}
newState := structs.CAConsulProviderState{
ID: provider.id,
}
// Write the initial provider state to get the index to use for the
// CA serial number.
{
args := &structs.CARequest{
Op: structs.CAOpSetProviderState,
ProviderState: &newState,
}
if err := delegate.ApplyCARequest(args); err != nil {
return nil, err
}
}
idx, _, err := state.CAProviderState(provider.id)
if err != nil {
return nil, err
}
// Generate a private key if needed
if conf.PrivateKey == "" {
_, pk, err := connect.GeneratePrivateKey()
if err != nil {
return nil, err
}
newState.PrivateKey = pk
} else {
newState.PrivateKey = conf.PrivateKey
}
// Generate the root CA if necessary
if conf.RootCert == "" {
ca, err := provider.generateCA(newState.PrivateKey, idx+1)
if err != nil {
return nil, fmt.Errorf("error generating CA: %v", err)
}
newState.RootCert = ca
} else {
newState.RootCert = conf.RootCert
}
// Write the provider state
args := &structs.CARequest{
Op: structs.CAOpSetProviderState,
ProviderState: &newState,
}
if err := delegate.ApplyCARequest(args); err != nil {
return nil, err
}
return provider, nil
}
// Return the active root CA and generate a new one if needed
func (c *ConsulProvider) ActiveRoot() (string, error) {
state := c.delegate.State()
_, providerState, err := state.CAProviderState(c.id)
if err != nil {
return "", err
}
return providerState.RootCert, nil
}
// We aren't maintaining separate root/intermediate CAs for the builtin
// provider, so just return the root.
func (c *ConsulProvider) ActiveIntermediate() (string, error) {
return c.ActiveRoot()
}
// We aren't maintaining separate root/intermediate CAs for the builtin
// provider, so just return the root.
func (c *ConsulProvider) GenerateIntermediate() (string, error) {
return c.ActiveIntermediate()
}
// Remove the state store entry for this provider instance.
func (c *ConsulProvider) Cleanup() error {
args := &structs.CARequest{
Op: structs.CAOpDeleteProviderState,
ProviderState: &structs.CAConsulProviderState{ID: c.id},
}
if err := c.delegate.ApplyCARequest(args); err != nil {
return err
}
return nil
}
// Sign returns a new certificate valid for the given SpiffeIDService
// using the current CA.
func (c *ConsulProvider) Sign(csr *x509.CertificateRequest) (string, error) {
// Lock during the signing so we don't use the same index twice
// for different cert serial numbers.
c.Lock()
defer c.Unlock()
// Get the provider state
state := c.delegate.State()
idx, providerState, err := state.CAProviderState(c.id)
if err != nil {
return "", err
}
// Create the keyId for the cert from the signing private key.
signer, err := connect.ParseSigner(providerState.PrivateKey)
if err != nil {
return "", err
}
if signer == nil {
return "", fmt.Errorf("error signing cert: Consul CA not initialized yet")
}
keyId, err := connect.KeyId(signer.Public())
if err != nil {
return "", err
}
// Parse the SPIFFE ID
spiffeId, err := connect.ParseCertURI(csr.URIs[0])
if err != nil {
return "", err
}
serviceId, ok := spiffeId.(*connect.SpiffeIDService)
if !ok {
return "", fmt.Errorf("SPIFFE ID in CSR must be a service ID")
}
// Parse the CA cert
caCert, err := connect.ParseCert(providerState.RootCert)
if err != nil {
return "", fmt.Errorf("error parsing CA cert: %s", err)
}
// Cert template for generation
sn := &big.Int{}
sn.SetUint64(idx + 1)
// Sign the certificate valid from 1 minute in the past, this helps it be
// accepted right away even when nodes are not in close time sync accross the
// cluster. A minute is more than enough for typical DC clock drift.
effectiveNow := time.Now().Add(-1 * time.Minute)
template := x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{CommonName: serviceId.Service},
URIs: csr.URIs,
Signature: csr.Signature,
SignatureAlgorithm: csr.SignatureAlgorithm,
PublicKeyAlgorithm: csr.PublicKeyAlgorithm,
PublicKey: csr.PublicKey,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageDataEncipherment |
x509.KeyUsageKeyAgreement |
x509.KeyUsageDigitalSignature |
x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
// todo(kyhavlov): add a way to set the cert lifetime here from the CA config
NotAfter: effectiveNow.Add(3 * 24 * time.Hour),
NotBefore: effectiveNow,
AuthorityKeyId: keyId,
SubjectKeyId: keyId,
}
// Create the certificate, PEM encode it and return that value.
var buf bytes.Buffer
bs, err := x509.CreateCertificate(
rand.Reader, &template, caCert, csr.PublicKey, signer)
if err != nil {
return "", fmt.Errorf("error generating certificate: %s", err)
}
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
return "", fmt.Errorf("error encoding certificate: %s", err)
}
err = c.incrementProviderIndex(providerState)
if err != nil {
return "", err
}
// Set the response
return buf.String(), nil
}
// CrossSignCA returns the given CA cert signed by the current active root.
func (c *ConsulProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
c.Lock()
defer c.Unlock()
// Get the provider state
state := c.delegate.State()
idx, providerState, err := state.CAProviderState(c.id)
if err != nil {
return "", err
}
privKey, err := connect.ParseSigner(providerState.PrivateKey)
if err != nil {
return "", fmt.Errorf("error parsing private key %q: %s", providerState.PrivateKey, err)
}
rootCA, err := connect.ParseCert(providerState.RootCert)
if err != nil {
return "", err
}
keyId, err := connect.KeyId(privKey.Public())
if err != nil {
return "", err
}
// Create the cross-signing template from the existing root CA
serialNum := &big.Int{}
serialNum.SetUint64(idx + 1)
template := *cert
template.SerialNumber = serialNum
template.SignatureAlgorithm = rootCA.SignatureAlgorithm
template.AuthorityKeyId = keyId
// Sign the certificate valid from 1 minute in the past, this helps it be
// accepted right away even when nodes are not in close time sync accross the
// cluster. A minute is more than enough for typical DC clock drift.
effectiveNow := time.Now().Add(-1 * time.Minute)
template.NotBefore = effectiveNow
// This cross-signed cert is only needed during rotation, and only while old
// leaf certs are still in use. They expire within 3 days currently so 7 is
// safe. TODO(banks): make this be based on leaf expiry time when that is
// configurable.
template.NotAfter = effectiveNow.Add(7 * 24 * time.Hour)
bs, err := x509.CreateCertificate(
rand.Reader, &template, rootCA, cert.PublicKey, privKey)
if err != nil {
return "", fmt.Errorf("error generating CA certificate: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
return "", fmt.Errorf("error encoding private key: %s", err)
}
err = c.incrementProviderIndex(providerState)
if err != nil {
return "", err
}
return buf.String(), nil
}
// incrementProviderIndex does a write to increment the provider state store table index
// used for serial numbers when generating certificates.
func (c *ConsulProvider) incrementProviderIndex(providerState *structs.CAConsulProviderState) error {
newState := *providerState
args := &structs.CARequest{
Op: structs.CAOpSetProviderState,
ProviderState: &newState,
}
if err := c.delegate.ApplyCARequest(args); err != nil {
return err
}
return nil
}
// generateCA makes a new root CA using the current private key
func (c *ConsulProvider) generateCA(privateKey string, sn uint64) (string, error) {
state := c.delegate.State()
_, config, err := state.CAConfig()
if err != nil {
return "", err
}
privKey, err := connect.ParseSigner(privateKey)
if err != nil {
return "", fmt.Errorf("error parsing private key %q: %s", privateKey, err)
}
name := fmt.Sprintf("Consul CA %d", sn)
// The URI (SPIFFE compatible) for the cert
id := connect.SpiffeIDSigningForCluster(config)
keyId, err := connect.KeyId(privKey.Public())
if err != nil {
return "", err
}
// Create the CA cert
serialNum := &big.Int{}
serialNum.SetUint64(sn)
template := x509.Certificate{
SerialNumber: serialNum,
Subject: pkix.Name{CommonName: name},
URIs: []*url.URL{id.URI()},
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageCertSign |
x509.KeyUsageCRLSign |
x509.KeyUsageDigitalSignature,
IsCA: true,
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
NotBefore: time.Now(),
AuthorityKeyId: keyId,
SubjectKeyId: keyId,
}
bs, err := x509.CreateCertificate(
rand.Reader, &template, &template, privKey.Public(), privKey)
if err != nil {
return "", fmt.Errorf("error generating CA certificate: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
return "", fmt.Errorf("error encoding private key: %s", err)
}
return buf.String(), nil
}

View File

@ -0,0 +1,77 @@
package ca
import (
"fmt"
"reflect"
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/mitchellh/mapstructure"
)
func ParseConsulCAConfig(raw map[string]interface{}) (*structs.ConsulCAProviderConfig, error) {
var config structs.ConsulCAProviderConfig
decodeConf := &mapstructure.DecoderConfig{
DecodeHook: ParseDurationFunc(),
ErrorUnused: true,
Result: &config,
WeaklyTypedInput: true,
}
decoder, err := mapstructure.NewDecoder(decodeConf)
if err != nil {
return nil, err
}
if err := decoder.Decode(raw); err != nil {
return nil, fmt.Errorf("error decoding config: %s", err)
}
if config.PrivateKey == "" && config.RootCert != "" {
return nil, fmt.Errorf("must provide a private key when providing a root cert")
}
return &config, nil
}
// ParseDurationFunc is a mapstructure hook for decoding a string or
// []uint8 into a time.Duration value.
func ParseDurationFunc() mapstructure.DecodeHookFunc {
return func(
f reflect.Type,
t reflect.Type,
data interface{}) (interface{}, error) {
var v time.Duration
if t != reflect.TypeOf(v) {
return data, nil
}
switch {
case f.Kind() == reflect.String:
if dur, err := time.ParseDuration(data.(string)); err != nil {
return nil, err
} else {
v = dur
}
return v, nil
case f == reflect.SliceOf(reflect.TypeOf(uint8(0))):
s := Uint8ToString(data.([]uint8))
if dur, err := time.ParseDuration(s); err != nil {
return nil, err
} else {
v = dur
}
return v, nil
default:
return data, nil
}
}
}
func Uint8ToString(bs []uint8) string {
b := make([]byte, len(bs))
for i, v := range bs {
b[i] = byte(v)
}
return string(b)
}

View File

@ -0,0 +1,266 @@
package ca
import (
"crypto/x509"
"fmt"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type consulCAMockDelegate struct {
state *state.Store
}
func (c *consulCAMockDelegate) State() *state.Store {
return c.state
}
func (c *consulCAMockDelegate) ApplyCARequest(req *structs.CARequest) error {
idx, _, err := c.state.CAConfig()
if err != nil {
return err
}
switch req.Op {
case structs.CAOpSetProviderState:
_, err := c.state.CASetProviderState(idx+1, req.ProviderState)
if err != nil {
return err
}
return nil
case structs.CAOpDeleteProviderState:
if err := c.state.CADeleteProviderState(req.ProviderState.ID); err != nil {
return err
}
return nil
default:
return fmt.Errorf("Invalid CA operation '%s'", req.Op)
}
}
func newMockDelegate(t *testing.T, conf *structs.CAConfiguration) *consulCAMockDelegate {
s, err := state.NewStateStore(nil)
if err != nil {
t.Fatalf("err: %s", err)
}
if s == nil {
t.Fatalf("missing state store")
}
if err := s.CASetConfig(conf.RaftIndex.CreateIndex, conf); err != nil {
t.Fatalf("err: %s", err)
}
return &consulCAMockDelegate{s}
}
func testConsulCAConfig() *structs.CAConfiguration {
return &structs.CAConfiguration{
ClusterID: "asdf",
Provider: "consul",
Config: map[string]interface{}{},
}
}
func TestConsulCAProvider_Bootstrap(t *testing.T) {
t.Parallel()
assert := assert.New(t)
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
provider, err := NewConsulProvider(conf.Config, delegate)
assert.NoError(err)
root, err := provider.ActiveRoot()
assert.NoError(err)
// Intermediate should be the same cert.
inter, err := provider.ActiveIntermediate()
assert.NoError(err)
assert.Equal(root, inter)
// Should be a valid cert
parsed, err := connect.ParseCert(root)
assert.NoError(err)
assert.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", conf.ClusterID))
}
func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) {
t.Parallel()
// Make sure setting a custom private key/root cert works.
assert := assert.New(t)
rootCA := connect.TestCA(t, nil)
conf := testConsulCAConfig()
conf.Config = map[string]interface{}{
"PrivateKey": rootCA.SigningKey,
"RootCert": rootCA.RootCert,
}
delegate := newMockDelegate(t, conf)
provider, err := NewConsulProvider(conf.Config, delegate)
assert.NoError(err)
root, err := provider.ActiveRoot()
assert.NoError(err)
assert.Equal(root, rootCA.RootCert)
}
func TestConsulCAProvider_SignLeaf(t *testing.T) {
t.Parallel()
assert := assert.New(t)
conf := testConsulCAConfig()
delegate := newMockDelegate(t, conf)
provider, err := NewConsulProvider(conf.Config, delegate)
assert.NoError(err)
spiffeService := &connect.SpiffeIDService{
Host: "node1",
Namespace: "default",
Datacenter: "dc1",
Service: "foo",
}
// Generate a leaf cert for the service.
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
assert.NoError(err)
cert, err := provider.Sign(csr)
assert.NoError(err)
parsed, err := connect.ParseCert(cert)
assert.NoError(err)
assert.Equal(parsed.URIs[0], spiffeService.URI())
assert.Equal(parsed.Subject.CommonName, "foo")
assert.Equal(uint64(2), parsed.SerialNumber.Uint64())
// Ensure the cert is valid now and expires within the correct limit.
assert.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
assert.True(parsed.NotBefore.Before(time.Now()))
}
// Generate a new cert for another service and make sure
// the serial number is incremented.
spiffeService.Service = "bar"
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
assert.NoError(err)
cert, err := provider.Sign(csr)
assert.NoError(err)
parsed, err := connect.ParseCert(cert)
assert.NoError(err)
assert.Equal(parsed.URIs[0], spiffeService.URI())
assert.Equal(parsed.Subject.CommonName, "bar")
assert.Equal(parsed.SerialNumber.Uint64(), uint64(2))
// Ensure the cert is valid now and expires within the correct limit.
assert.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
assert.True(parsed.NotBefore.Before(time.Now()))
}
}
func TestConsulCAProvider_CrossSignCA(t *testing.T) {
t.Parallel()
conf1 := testConsulCAConfig()
delegate1 := newMockDelegate(t, conf1)
provider1, err := NewConsulProvider(conf1.Config, delegate1)
require.NoError(t, err)
conf2 := testConsulCAConfig()
conf2.CreateIndex = 10
delegate2 := newMockDelegate(t, conf2)
provider2, err := NewConsulProvider(conf2.Config, delegate2)
require.NoError(t, err)
testCrossSignProviders(t, provider1, provider2)
}
func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
require := require.New(t)
// Get the root from the new provider to be cross-signed.
newRootPEM, err := provider2.ActiveRoot()
require.NoError(err)
newRoot, err := connect.ParseCert(newRootPEM)
require.NoError(err)
oldSubject := newRoot.Subject.CommonName
newInterPEM, err := provider2.ActiveIntermediate()
require.NoError(err)
newIntermediate, err := connect.ParseCert(newInterPEM)
require.NoError(err)
// Have provider1 cross sign our new root cert.
xcPEM, err := provider1.CrossSignCA(newRoot)
require.NoError(err)
xc, err := connect.ParseCert(xcPEM)
require.NoError(err)
oldRootPEM, err := provider1.ActiveRoot()
require.NoError(err)
oldRoot, err := connect.ParseCert(oldRootPEM)
require.NoError(err)
// AuthorityKeyID should now be the signing root's, SubjectKeyId should be kept.
require.Equal(oldRoot.AuthorityKeyId, xc.AuthorityKeyId)
require.Equal(newRoot.SubjectKeyId, xc.SubjectKeyId)
// Subject name should not have changed.
require.Equal(oldSubject, xc.Subject.CommonName)
// Issuer should be the signing root.
require.Equal(oldRoot.Issuer.CommonName, xc.Issuer.CommonName)
// Get a leaf cert so we can verify against the cross-signed cert.
spiffeService := &connect.SpiffeIDService{
Host: "node1",
Namespace: "default",
Datacenter: "dc1",
Service: "foo",
}
raw, _ := connect.TestCSR(t, spiffeService)
leafCsr, err := connect.ParseCSR(raw)
require.NoError(err)
leafPEM, err := provider2.Sign(leafCsr)
require.NoError(err)
cert, err := connect.ParseCert(leafPEM)
require.NoError(err)
// Check that the leaf signed by the new cert can be verified by either root
// certificate by using the new intermediate + cross-signed cert.
intermediatePool := x509.NewCertPool()
intermediatePool.AddCert(newIntermediate)
intermediatePool.AddCert(xc)
for _, root := range []*x509.Certificate{oldRoot, newRoot} {
rootPool := x509.NewCertPool()
rootPool.AddCert(root)
_, err = cert.Verify(x509.VerifyOptions{
Intermediates: intermediatePool,
Roots: rootPool,
})
require.NoError(err)
}
}

View File

@ -0,0 +1,322 @@
package ca
import (
"bytes"
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"net/http"
"strings"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-uuid"
vaultapi "github.com/hashicorp/vault/api"
"github.com/mitchellh/mapstructure"
)
const VaultCALeafCertRole = "leaf-cert"
var ErrBackendNotMounted = fmt.Errorf("backend not mounted")
var ErrBackendNotInitialized = fmt.Errorf("backend not initialized")
type VaultProvider struct {
config *structs.VaultCAProviderConfig
client *vaultapi.Client
clusterId string
}
// NewVaultProvider returns a vault provider with its root and intermediate PKI
// backends mounted and initialized. If the root backend is not set up already,
// it will be mounted/generated as needed, but any existing state will not be
// overwritten.
func NewVaultProvider(rawConfig map[string]interface{}, clusterId string) (*VaultProvider, error) {
conf, err := ParseVaultCAConfig(rawConfig)
if err != nil {
return nil, err
}
// todo(kyhavlov): figure out the right way to pass the TLS config
clientConf := &vaultapi.Config{
Address: conf.Address,
}
client, err := vaultapi.NewClient(clientConf)
if err != nil {
return nil, err
}
client.SetToken(conf.Token)
provider := &VaultProvider{
config: conf,
client: client,
clusterId: clusterId,
}
// Set up the root PKI backend if necessary.
_, err = provider.ActiveRoot()
switch err {
case ErrBackendNotMounted:
err := client.Sys().Mount(conf.RootPKIPath, &vaultapi.MountInput{
Type: "pki",
Description: "root CA backend for Consul Connect",
Config: vaultapi.MountConfigInput{
MaxLeaseTTL: "8760h",
},
})
if err != nil {
return nil, err
}
fallthrough
case ErrBackendNotInitialized:
spiffeID := connect.SpiffeIDSigning{ClusterID: clusterId, Domain: "consul"}
uuid, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
_, err = client.Logical().Write(conf.RootPKIPath+"root/generate/internal", map[string]interface{}{
"common_name": fmt.Sprintf("Vault CA Root Authority %s", uuid),
"uri_sans": spiffeID.URI().String(),
})
if err != nil {
return nil, err
}
default:
if err != nil {
return nil, err
}
}
// Set up the intermediate backend.
if _, err := provider.GenerateIntermediate(); err != nil {
return nil, err
}
return provider, nil
}
func (v *VaultProvider) ActiveRoot() (string, error) {
return v.getCA(v.config.RootPKIPath)
}
func (v *VaultProvider) ActiveIntermediate() (string, error) {
return v.getCA(v.config.IntermediatePKIPath)
}
// getCA returns the raw CA cert for the given endpoint if there is one.
// We have to use the raw NewRequest call here instead of Logical().Read
// because the endpoint only returns the raw PEM contents of the CA cert
// and not the typical format of the secrets endpoints.
func (v *VaultProvider) getCA(path string) (string, error) {
req := v.client.NewRequest("GET", "/v1/"+path+"/ca/pem")
resp, err := v.client.RawRequest(req)
if resp != nil {
defer resp.Body.Close()
}
if resp != nil && resp.StatusCode == http.StatusNotFound {
return "", ErrBackendNotMounted
}
if err != nil {
return "", err
}
bytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
root := string(bytes)
if root == "" {
return "", ErrBackendNotInitialized
}
return root, nil
}
// GenerateIntermediate mounts the configured intermediate PKI backend if
// necessary, then generates and signs a new CA CSR using the root PKI backend
// and updates the intermediate backend to use that new certificate.
func (v *VaultProvider) GenerateIntermediate() (string, error) {
mounts, err := v.client.Sys().ListMounts()
if err != nil {
return "", err
}
// Mount the backend if it isn't mounted already.
if _, ok := mounts[v.config.IntermediatePKIPath]; !ok {
err := v.client.Sys().Mount(v.config.IntermediatePKIPath, &vaultapi.MountInput{
Type: "pki",
Description: "intermediate CA backend for Consul Connect",
Config: vaultapi.MountConfigInput{
MaxLeaseTTL: "2160h",
},
})
if err != nil {
return "", err
}
}
// Create the role for issuing leaf certs if it doesn't exist yet
rolePath := v.config.IntermediatePKIPath + "roles/" + VaultCALeafCertRole
role, err := v.client.Logical().Read(rolePath)
if err != nil {
return "", err
}
spiffeID := connect.SpiffeIDSigning{ClusterID: v.clusterId, Domain: "consul"}
if role == nil {
_, err := v.client.Logical().Write(rolePath, map[string]interface{}{
"allow_any_name": true,
"allowed_uri_sans": "spiffe://*",
"key_type": "any",
"max_ttl": "72h",
"require_cn": false,
})
if err != nil {
return "", err
}
}
// Generate a new intermediate CSR for the root to sign.
csr, err := v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/generate/internal", map[string]interface{}{
"common_name": "Vault CA Intermediate Authority",
"uri_sans": spiffeID.URI().String(),
})
if err != nil {
return "", err
}
if csr == nil || csr.Data["csr"] == "" {
return "", fmt.Errorf("got empty value when generating intermediate CSR")
}
// Sign the CSR with the root backend.
intermediate, err := v.client.Logical().Write(v.config.RootPKIPath+"root/sign-intermediate", map[string]interface{}{
"csr": csr.Data["csr"],
"format": "pem_bundle",
})
if err != nil {
return "", err
}
if intermediate == nil || intermediate.Data["certificate"] == "" {
return "", fmt.Errorf("got empty value when generating intermediate certificate")
}
// Set the intermediate backend to use the new certificate.
_, err = v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/set-signed", map[string]interface{}{
"certificate": intermediate.Data["certificate"],
})
if err != nil {
return "", err
}
return v.ActiveIntermediate()
}
// Sign calls the configured role in the intermediate PKI backend to issue
// a new leaf certificate based on the provided CSR, with the issuing
// intermediate CA cert attached.
func (v *VaultProvider) Sign(csr *x509.CertificateRequest) (string, error) {
var pemBuf bytes.Buffer
if err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
return "", err
}
// Use the leaf cert role to sign a new cert for this CSR.
response, err := v.client.Logical().Write(v.config.IntermediatePKIPath+"sign/"+VaultCALeafCertRole, map[string]interface{}{
"csr": pemBuf.String(),
})
if err != nil {
return "", fmt.Errorf("error issuing cert: %v", err)
}
if response == nil || response.Data["certificate"] == "" || response.Data["issuing_ca"] == "" {
return "", fmt.Errorf("certificate info returned from Vault was blank")
}
cert, ok := response.Data["certificate"].(string)
if !ok {
return "", fmt.Errorf("certificate was not a string")
}
ca, ok := response.Data["issuing_ca"].(string)
if !ok {
return "", fmt.Errorf("issuing_ca was not a string")
}
return fmt.Sprintf("%s\n%s", cert, ca), nil
}
// CrossSignCA takes a CA certificate and cross-signs it to form a trust chain
// back to our active root.
func (v *VaultProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
var pemBuf bytes.Buffer
err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if err != nil {
return "", err
}
// Have the root PKI backend sign this cert.
response, err := v.client.Logical().Write(v.config.RootPKIPath+"root/sign-self-issued", map[string]interface{}{
"certificate": pemBuf.String(),
})
if err != nil {
return "", fmt.Errorf("error having Vault cross-sign cert: %v", err)
}
if response == nil || response.Data["certificate"] == "" {
return "", fmt.Errorf("certificate info returned from Vault was blank")
}
xcCert, ok := response.Data["certificate"].(string)
if !ok {
return "", fmt.Errorf("certificate was not a string")
}
return xcCert, nil
}
// Cleanup unmounts the configured intermediate PKI backend. It's fine to tear
// this down and recreate it on small config changes because the intermediate
// certs get bundled with the leaf certs, so there's no cost to the CA changing.
func (v *VaultProvider) Cleanup() error {
return v.client.Sys().Unmount(v.config.IntermediatePKIPath)
}
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
var config structs.VaultCAProviderConfig
decodeConf := &mapstructure.DecoderConfig{
ErrorUnused: true,
Result: &config,
WeaklyTypedInput: true,
}
decoder, err := mapstructure.NewDecoder(decodeConf)
if err != nil {
return nil, err
}
if err := decoder.Decode(raw); err != nil {
return nil, fmt.Errorf("error decoding config: %s", err)
}
if config.Token == "" {
return nil, fmt.Errorf("must provide a Vault token")
}
if config.RootPKIPath == "" {
return nil, fmt.Errorf("must provide a valid path to a root PKI backend")
}
if !strings.HasSuffix(config.RootPKIPath, "/") {
config.RootPKIPath += "/"
}
if config.IntermediatePKIPath == "" {
return nil, fmt.Errorf("must provide a valid path for the intermediate PKI backend")
}
if !strings.HasSuffix(config.IntermediatePKIPath, "/") {
config.IntermediatePKIPath += "/"
}
return &config, nil
}

View File

@ -0,0 +1,162 @@
package ca
import (
"fmt"
"io/ioutil"
"net"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
vaultapi "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require"
)
func testVaultCluster(t *testing.T) (*VaultProvider, *vault.Core, net.Listener) {
if err := vault.AddTestLogicalBackend("pki", pki.Factory); err != nil {
t.Fatal(err)
}
core, _, token := vault.TestCoreUnsealedRaw(t)
ln, addr := vaulthttp.TestServer(t, core)
provider, err := NewVaultProvider(map[string]interface{}{
"Address": addr,
"Token": token,
"RootPKIPath": "pki-root/",
"IntermediatePKIPath": "pki-intermediate/",
}, "asdf")
if err != nil {
t.Fatal(err)
}
return provider, core, ln
}
func TestVaultCAProvider_Bootstrap(t *testing.T) {
t.Parallel()
require := require.New(t)
provider, core, listener := testVaultCluster(t)
defer core.Shutdown()
defer listener.Close()
client, err := vaultapi.NewClient(&vaultapi.Config{
Address: "http://" + listener.Addr().String(),
})
require.NoError(err)
client.SetToken(provider.config.Token)
cases := []struct {
certFunc func() (string, error)
backendPath string
}{
{
certFunc: provider.ActiveRoot,
backendPath: "pki-root/",
},
{
certFunc: provider.ActiveIntermediate,
backendPath: "pki-intermediate/",
},
}
// Verify the root and intermediate certs match the ones in the vault backends
for _, tc := range cases {
cert, err := tc.certFunc()
require.NoError(err)
req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem")
resp, err := client.RawRequest(req)
require.NoError(err)
bytes, err := ioutil.ReadAll(resp.Body)
require.NoError(err)
require.Equal(cert, string(bytes))
// Should be a valid CA cert
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.True(parsed.IsCA)
require.Len(parsed.URIs, 1)
require.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", provider.clusterId))
}
}
func TestVaultCAProvider_SignLeaf(t *testing.T) {
t.Parallel()
require := require.New(t)
provider, core, listener := testVaultCluster(t)
defer core.Shutdown()
defer listener.Close()
client, err := vaultapi.NewClient(&vaultapi.Config{
Address: "http://" + listener.Addr().String(),
})
require.NoError(err)
client.SetToken(provider.config.Token)
spiffeService := &connect.SpiffeIDService{
Host: "node1",
Namespace: "default",
Datacenter: "dc1",
Service: "foo",
}
// Generate a leaf cert for the service.
var firstSerial uint64
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
cert, err := provider.Sign(csr)
require.NoError(err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(parsed.URIs[0], spiffeService.URI())
firstSerial = parsed.SerialNumber.Uint64()
// Ensure the cert is valid now and expires within the correct limit.
require.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
require.True(parsed.NotBefore.Before(time.Now()))
}
// Generate a new cert for another service and make sure
// the serial number is unique.
spiffeService.Service = "bar"
{
raw, _ := connect.TestCSR(t, spiffeService)
csr, err := connect.ParseCSR(raw)
require.NoError(err)
cert, err := provider.Sign(csr)
require.NoError(err)
parsed, err := connect.ParseCert(cert)
require.NoError(err)
require.Equal(parsed.URIs[0], spiffeService.URI())
require.NotEqual(firstSerial, parsed.SerialNumber.Uint64())
// Ensure the cert is valid now and expires within the correct limit.
require.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
require.True(parsed.NotBefore.Before(time.Now()))
}
}
func TestVaultCAProvider_CrossSignCA(t *testing.T) {
t.Parallel()
provider1, core1, listener1 := testVaultCluster(t)
defer core1.Shutdown()
defer listener1.Close()
provider2, core2, listener2 := testVaultCluster(t)
defer core2.Shutdown()
defer listener2.Close()
testCrossSignProviders(t, provider1, provider2)
}

33
agent/connect/csr.go Normal file
View File

@ -0,0 +1,33 @@
package connect
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"net/url"
)
// CreateCSR returns a CSR to sign the given service along with the PEM-encoded
// private key for this certificate.
func CreateCSR(uri CertURI, privateKey crypto.Signer) (string, error) {
template := &x509.CertificateRequest{
URIs: []*url.URL{uri.URI()},
SignatureAlgorithm: x509.ECDSAWithSHA256,
}
// Create the CSR itself
var csrBuf bytes.Buffer
bs, err := x509.CreateCertificateRequest(rand.Reader, template, privateKey)
if err != nil {
return "", err
}
err = pem.Encode(&csrBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: bs})
if err != nil {
return "", err
}
return csrBuf.String(), nil
}

35
agent/connect/generate.go Normal file
View File

@ -0,0 +1,35 @@
package connect
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"fmt"
)
// GeneratePrivateKey generates a new Private key
func GeneratePrivateKey() (crypto.Signer, string, error) {
var pk *ecdsa.PrivateKey
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, "", fmt.Errorf("error generating private key: %s", err)
}
bs, err := x509.MarshalECPrivateKey(pk)
if err != nil {
return nil, "", fmt.Errorf("error generating private key: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: bs})
if err != nil {
return nil, "", fmt.Errorf("error encoding private key: %s", err)
}
return pk, buf.String(), nil
}

121
agent/connect/parsing.go Normal file
View File

@ -0,0 +1,121 @@
package connect
import (
"crypto"
"crypto/ecdsa"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"fmt"
"strings"
)
// ParseCert parses the x509 certificate from a PEM-encoded value.
func ParseCert(pemValue string) (*x509.Certificate, error) {
// The _ result below is not an error but the remaining PEM bytes.
block, _ := pem.Decode([]byte(pemValue))
if block == nil {
return nil, fmt.Errorf("no PEM-encoded data found")
}
if block.Type != "CERTIFICATE" {
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE type")
}
return x509.ParseCertificate(block.Bytes)
}
// CalculateCertFingerprint parses the x509 certificate from a PEM-encoded value
// and calculates the SHA-1 fingerprint.
func CalculateCertFingerprint(pemValue string) (string, error) {
// The _ result below is not an error but the remaining PEM bytes.
block, _ := pem.Decode([]byte(pemValue))
if block == nil {
return "", fmt.Errorf("no PEM-encoded data found")
}
if block.Type != "CERTIFICATE" {
return "", fmt.Errorf("first PEM-block should be CERTIFICATE type")
}
hash := sha1.Sum(block.Bytes)
return HexString(hash[:]), nil
}
// ParseSigner parses a crypto.Signer from a PEM-encoded key. The private key
// is expected to be the first block in the PEM value.
func ParseSigner(pemValue string) (crypto.Signer, error) {
// The _ result below is not an error but the remaining PEM bytes.
block, _ := pem.Decode([]byte(pemValue))
if block == nil {
return nil, fmt.Errorf("no PEM-encoded data found")
}
switch block.Type {
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(block.Bytes)
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes)
case "PRIVATE KEY":
signer, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
pk, ok := signer.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("private key is not a valid format")
}
return pk, nil
default:
return nil, fmt.Errorf("unknown PEM block type for signing key: %s", block.Type)
}
}
// ParseCSR parses a CSR from a PEM-encoded value. The certificate request
// must be the the first block in the PEM value.
func ParseCSR(pemValue string) (*x509.CertificateRequest, error) {
// The _ result below is not an error but the remaining PEM bytes.
block, _ := pem.Decode([]byte(pemValue))
if block == nil {
return nil, fmt.Errorf("no PEM-encoded data found")
}
if block.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE REQUEST type")
}
return x509.ParseCertificateRequest(block.Bytes)
}
// KeyId returns a x509 KeyId from the given signing key. The key must be
// an *ecdsa.PublicKey currently, but may support more types in the future.
func KeyId(raw interface{}) ([]byte, error) {
switch raw.(type) {
case *ecdsa.PublicKey:
default:
return nil, fmt.Errorf("invalid key type: %T", raw)
}
// This is not standard; RFC allows any unique identifier as long as they
// match in subject/authority chains but suggests specific hashing of DER
// bytes of public key including DER tags.
bs, err := x509.MarshalPKIXPublicKey(raw)
if err != nil {
return nil, err
}
// String formatted
kID := sha256.Sum256(bs)
return []byte(strings.Replace(fmt.Sprintf("% x", kID), " ", ":", -1)), nil
}
// HexString returns a standard colon-separated hex value for the input
// byte slice. This should be used with cert serial numbers and so on.
func HexString(input []byte) string {
return strings.Replace(fmt.Sprintf("% x", input), " ", ":", -1)
}

332
agent/connect/testing_ca.go Normal file
View File

@ -0,0 +1,332 @@
package connect
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/url"
"sync/atomic"
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-uuid"
"github.com/mitchellh/go-testing-interface"
)
// TestClusterID is the Consul cluster ID for testing.
const TestClusterID = "11111111-2222-3333-4444-555555555555"
// testCACounter is just an atomically incremented counter for creating
// unique names for the CA certs.
var testCACounter uint64
// TestCA creates a test CA certificate and signing key and returns it
// in the CARoot structure format. The returned CA will be set as Active = true.
//
// If xc is non-nil, then the returned certificate will have a signing cert
// that is cross-signed with the previous cert, and this will be set as
// SigningCert.
func TestCA(t testing.T, xc *structs.CARoot) *structs.CARoot {
var result structs.CARoot
result.Active = true
result.Name = fmt.Sprintf("Test CA %d", atomic.AddUint64(&testCACounter, 1))
// Create the private key we'll use for this CA cert.
signer, keyPEM := testPrivateKey(t)
result.SigningKey = keyPEM
// The serial number for the cert
sn, err := testSerialNumber()
if err != nil {
t.Fatalf("error generating serial number: %s", err)
}
// The URI (SPIFFE compatible) for the cert
id := &SpiffeIDSigning{ClusterID: TestClusterID, Domain: "consul"}
// Create the CA cert
template := x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{CommonName: result.Name},
URIs: []*url.URL{id.URI()},
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageCertSign |
x509.KeyUsageCRLSign |
x509.KeyUsageDigitalSignature,
IsCA: true,
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
NotBefore: time.Now(),
AuthorityKeyId: testKeyID(t, signer.Public()),
SubjectKeyId: testKeyID(t, signer.Public()),
}
bs, err := x509.CreateCertificate(
rand.Reader, &template, &template, signer.Public(), signer)
if err != nil {
t.Fatalf("error generating CA certificate: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
t.Fatalf("error encoding private key: %s", err)
}
result.RootCert = buf.String()
result.ID, err = CalculateCertFingerprint(result.RootCert)
if err != nil {
t.Fatalf("error generating CA ID fingerprint: %s", err)
}
// If there is a prior CA to cross-sign with, then we need to create that
// and set it as the signing cert.
if xc != nil {
xccert, err := ParseCert(xc.RootCert)
if err != nil {
t.Fatalf("error parsing CA cert: %s", err)
}
xcsigner, err := ParseSigner(xc.SigningKey)
if err != nil {
t.Fatalf("error parsing signing key: %s", err)
}
// Set the authority key to be the previous one.
// NOTE(mitchellh): From Paul Banks: if we have to cross-sign a cert
// that came from outside (e.g. vault) we can't rely on them using the
// same KeyID hashing algo we do so we'd need to actually copy this
// from the xc cert's subjectKeyIdentifier extension.
template.AuthorityKeyId = testKeyID(t, xcsigner.Public())
// Create the new certificate where the parent is the previous
// CA, the public key is the new public key, and the signing private
// key is the old private key.
bs, err := x509.CreateCertificate(
rand.Reader, &template, xccert, signer.Public(), xcsigner)
if err != nil {
t.Fatalf("error generating CA certificate: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
t.Fatalf("error encoding private key: %s", err)
}
result.SigningCert = buf.String()
}
return &result
}
// TestLeaf returns a valid leaf certificate and it's private key for the named
// service with the given CA Root.
func TestLeaf(t testing.T, service string, root *structs.CARoot) (string, string) {
// Parse the CA cert and signing key from the root
cert := root.SigningCert
if cert == "" {
cert = root.RootCert
}
caCert, err := ParseCert(cert)
if err != nil {
t.Fatalf("error parsing CA cert: %s", err)
}
caSigner, err := ParseSigner(root.SigningKey)
if err != nil {
t.Fatalf("error parsing signing key: %s", err)
}
// Build the SPIFFE ID
spiffeId := &SpiffeIDService{
Host: fmt.Sprintf("%s.consul", TestClusterID),
Namespace: "default",
Datacenter: "dc1",
Service: service,
}
// The serial number for the cert
sn, err := testSerialNumber()
if err != nil {
t.Fatalf("error generating serial number: %s", err)
}
// Generate fresh private key
pkSigner, pkPEM, err := GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key: %s", err)
}
// Cert template for generation
template := x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{CommonName: service},
URIs: []*url.URL{spiffeId.URI()},
SignatureAlgorithm: x509.ECDSAWithSHA256,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageDataEncipherment |
x509.KeyUsageKeyAgreement |
x509.KeyUsageDigitalSignature |
x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
NotBefore: time.Now(),
AuthorityKeyId: testKeyID(t, caSigner.Public()),
SubjectKeyId: testKeyID(t, pkSigner.Public()),
}
// Create the certificate, PEM encode it and return that value.
var buf bytes.Buffer
bs, err := x509.CreateCertificate(
rand.Reader, &template, caCert, pkSigner.Public(), caSigner)
if err != nil {
t.Fatalf("error generating certificate: %s", err)
}
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
if err != nil {
t.Fatalf("error encoding private key: %s", err)
}
return buf.String(), pkPEM
}
// TestCSR returns a CSR to sign the given service along with the PEM-encoded
// private key for this certificate.
func TestCSR(t testing.T, uri CertURI) (string, string) {
template := &x509.CertificateRequest{
URIs: []*url.URL{uri.URI()},
SignatureAlgorithm: x509.ECDSAWithSHA256,
}
// Create the private key we'll use
signer, pkPEM := testPrivateKey(t)
// Create the CSR itself
var csrBuf bytes.Buffer
bs, err := x509.CreateCertificateRequest(rand.Reader, template, signer)
if err != nil {
t.Fatalf("error creating CSR: %s", err)
}
err = pem.Encode(&csrBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: bs})
if err != nil {
t.Fatalf("error encoding CSR: %s", err)
}
return csrBuf.String(), pkPEM
}
// testKeyID returns a KeyID from the given public key. This just calls
// KeyId but handles errors for tests.
func testKeyID(t testing.T, raw interface{}) []byte {
result, err := KeyId(raw)
if err != nil {
t.Fatalf("KeyId error: %s", err)
}
return result
}
// testPrivateKey creates an ECDSA based private key. Both a crypto.Signer and
// the key in PEM form are returned.
//
// NOTE(banks): this was memoized to save entropy during tests but it turns out
// crypto/rand will never block and always reads from /dev/urandom on unix OSes
// which does not consume entropy.
//
// If we find by profiling it's taking a lot of cycles we could optimise/cache
// again but we at least need to use different keys for each distinct CA (when
// multiple CAs are generated at once e.g. to test cross-signing) and a
// different one again for the leafs otherwise we risk tests that have false
// positives since signatures from different logical cert's keys are
// indistinguishable, but worse we build validation chains using AuthorityKeyID
// which will be the same for multiple CAs/Leafs. Also note that our UUID
// generator also reads from crypto rand and is called far more often during
// tests than this will be.
func testPrivateKey(t testing.T) (crypto.Signer, string) {
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("error generating private key: %s", err)
}
bs, err := x509.MarshalECPrivateKey(pk)
if err != nil {
t.Fatalf("error generating private key: %s", err)
}
var buf bytes.Buffer
err = pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: bs})
if err != nil {
t.Fatalf("error encoding private key: %s", err)
}
return pk, buf.String()
}
// testSerialNumber generates a serial number suitable for a certificate.
// For testing, this just sets it to a random number.
//
// This function is taken directly from the Vault implementation.
func testSerialNumber() (*big.Int, error) {
return rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil))
}
// testUUID generates a UUID for testing.
func testUUID(t testing.T) string {
ret, err := uuid.GenerateUUID()
if err != nil {
t.Fatalf("Unable to generate a UUID, %s", err)
}
return ret
}
// TestAgentRPC is an interface that an RPC client must implement. This is a
// helper interface that is implemented by the agent delegate so that test
// helpers can make RPCs without introducing an import cycle on `agent`.
type TestAgentRPC interface {
RPC(method string, args interface{}, reply interface{}) error
}
// TestCAConfigSet sets a CARoot returned by TestCA into the TestAgent state. It
// requires that TestAgent had connect enabled in it's config. If ca is nil, a
// new CA is created.
//
// It returns the CARoot passed or created.
//
// Note that we have to use an interface for the TestAgent.RPC method since we
// can't introduce an import cycle by importing `agent.TestAgent` here directly.
// It also means this will work in a few other places we mock that method.
func TestCAConfigSet(t testing.T, a TestAgentRPC,
ca *structs.CARoot) *structs.CARoot {
t.Helper()
if ca == nil {
ca = TestCA(t, nil)
}
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": ca.SigningKey,
"RootCert": ca.RootCert,
"RotationPeriod": 180 * 24 * time.Hour,
},
}
args := &structs.CARequest{
Datacenter: "dc1",
Config: newConfig,
}
var reply interface{}
err := a.RPC("ConnectCA.ConfigurationSet", args, &reply)
if err != nil {
t.Fatalf("failed to set test CA config: %s", err)
}
return ca
}

View File

@ -0,0 +1,100 @@
package connect
import (
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
// hasOpenSSL is used to determine if the openssl CLI exists for unit tests.
var hasOpenSSL bool
func init() {
_, err := exec.LookPath("openssl")
hasOpenSSL = err == nil
}
// Test that the TestCA and TestLeaf functions generate valid certificates.
func TestTestCAAndLeaf(t *testing.T) {
if !hasOpenSSL {
t.Skip("openssl not found")
return
}
assert := assert.New(t)
// Create the certs
ca := TestCA(t, nil)
leaf, _ := TestLeaf(t, "web", ca)
// Create a temporary directory for storing the certs
td, err := ioutil.TempDir("", "consul")
assert.Nil(err)
defer os.RemoveAll(td)
// Write the cert
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), []byte(ca.RootCert), 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf.pem"), []byte(leaf), 0644))
// Use OpenSSL to verify so we have an external, known-working process
// that can verify this outside of our own implementations.
cmd := exec.Command(
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf.pem")
cmd.Dir = td
output, err := cmd.Output()
t.Log(string(output))
assert.Nil(err)
}
// Test cross-signing.
func TestTestCAAndLeaf_xc(t *testing.T) {
if !hasOpenSSL {
t.Skip("openssl not found")
return
}
assert := assert.New(t)
// Create the certs
ca1 := TestCA(t, nil)
ca2 := TestCA(t, ca1)
leaf1, _ := TestLeaf(t, "web", ca1)
leaf2, _ := TestLeaf(t, "web", ca2)
// Create a temporary directory for storing the certs
td, err := ioutil.TempDir("", "consul")
assert.Nil(err)
defer os.RemoveAll(td)
// Write the cert
xcbundle := []byte(ca1.RootCert)
xcbundle = append(xcbundle, '\n')
xcbundle = append(xcbundle, []byte(ca2.SigningCert)...)
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), xcbundle, 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf1.pem"), []byte(leaf1), 0644))
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf2.pem"), []byte(leaf2), 0644))
// OpenSSL verify the cross-signed leaf (leaf2)
{
cmd := exec.Command(
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf2.pem")
cmd.Dir = td
output, err := cmd.Output()
t.Log(string(output))
assert.Nil(err)
}
// OpenSSL verify the old leaf (leaf1)
{
cmd := exec.Command(
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf1.pem")
cmd.Dir = td
output, err := cmd.Output()
t.Log(string(output))
assert.Nil(err)
}
}

View File

@ -0,0 +1,21 @@
package connect
import (
"github.com/mitchellh/go-testing-interface"
)
// TestSpiffeIDService returns a SPIFFE ID representing a service.
func TestSpiffeIDService(t testing.T, service string) *SpiffeIDService {
return TestSpiffeIDServiceWithHost(t, service, TestClusterID+".consul")
}
// TestSpiffeIDServiceWithHost returns a SPIFFE ID representing a service with
// the specified trust domain.
func TestSpiffeIDServiceWithHost(t testing.T, service, host string) *SpiffeIDService {
return &SpiffeIDService{
Host: host,
Namespace: "default",
Datacenter: "dc1",
Service: service,
}
}

91
agent/connect/uri.go Normal file
View File

@ -0,0 +1,91 @@
package connect
import (
"fmt"
"net/url"
"regexp"
"strings"
"github.com/hashicorp/consul/agent/structs"
)
// CertURI represents a Connect-valid URI value for a TLS certificate.
// The user should type switch on the various implementations in this
// package to determine the type of URI and the data encoded within it.
//
// Note that the current implementations of this are all also SPIFFE IDs.
// However, we anticipate that we may accept URIs that are also not SPIFFE
// compliant and therefore the interface is named as such.
type CertURI interface {
// Authorize tests the authorization for this URI as a client
// for the given intention. The return value `auth` is only valid if
// the second value `match` is true. If the second value `match` is
// false, then the intention doesn't match this client and any
// result should be ignored.
Authorize(*structs.Intention) (auth bool, match bool)
// URI is the valid URI value used in the cert.
URI() *url.URL
}
var (
spiffeIDServiceRegexp = regexp.MustCompile(
`^/ns/([^/]+)/dc/([^/]+)/svc/([^/]+)$`)
)
// ParseCertURI parses a the URI value from a TLS certificate.
func ParseCertURI(input *url.URL) (CertURI, error) {
if input.Scheme != "spiffe" {
return nil, fmt.Errorf("SPIFFE ID must have 'spiffe' scheme")
}
// Path is the raw value of the path without url decoding values.
// RawPath is empty if there were no encoded values so we must
// check both.
path := input.Path
if input.RawPath != "" {
path = input.RawPath
}
// Test for service IDs
if v := spiffeIDServiceRegexp.FindStringSubmatch(path); v != nil {
// Determine the values. We assume they're sane to save cycles,
// but if the raw path is not empty that means that something is
// URL encoded so we go to the slow path.
ns := v[1]
dc := v[2]
service := v[3]
if input.RawPath != "" {
var err error
if ns, err = url.PathUnescape(v[1]); err != nil {
return nil, fmt.Errorf("Invalid namespace: %s", err)
}
if dc, err = url.PathUnescape(v[2]); err != nil {
return nil, fmt.Errorf("Invalid datacenter: %s", err)
}
if service, err = url.PathUnescape(v[3]); err != nil {
return nil, fmt.Errorf("Invalid service: %s", err)
}
}
return &SpiffeIDService{
Host: input.Host,
Namespace: ns,
Datacenter: dc,
Service: service,
}, nil
}
// Test for signing ID
if input.Path == "" {
idx := strings.Index(input.Host, ".")
if idx > 0 {
return &SpiffeIDSigning{
ClusterID: input.Host[:idx],
Domain: input.Host[idx+1:],
}, nil
}
}
return nil, fmt.Errorf("SPIFFE ID is not in the expected format")
}

View File

@ -0,0 +1,42 @@
package connect
import (
"fmt"
"net/url"
"github.com/hashicorp/consul/agent/structs"
)
// SpiffeIDService is the structure to represent the SPIFFE ID for a service.
type SpiffeIDService struct {
Host string
Namespace string
Datacenter string
Service string
}
// URI returns the *url.URL for this SPIFFE ID.
func (id *SpiffeIDService) URI() *url.URL {
var result url.URL
result.Scheme = "spiffe"
result.Host = id.Host
result.Path = fmt.Sprintf("/ns/%s/dc/%s/svc/%s",
id.Namespace, id.Datacenter, id.Service)
return &result
}
// CertURI impl.
func (id *SpiffeIDService) Authorize(ixn *structs.Intention) (bool, bool) {
if ixn.SourceNS != structs.IntentionWildcard && ixn.SourceNS != id.Namespace {
// Non-matching namespace
return false, false
}
if ixn.SourceName != structs.IntentionWildcard && ixn.SourceName != id.Service {
// Non-matching name
return false, false
}
// Match, return allow value
return ixn.Action == structs.IntentionActionAllow, true
}

View File

@ -0,0 +1,104 @@
package connect
import (
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
)
func TestSpiffeIDServiceAuthorize(t *testing.T) {
ns := structs.IntentionDefaultNamespace
serviceWeb := &SpiffeIDService{
Host: "1234.consul",
Namespace: structs.IntentionDefaultNamespace,
Datacenter: "dc01",
Service: "web",
}
cases := []struct {
Name string
URI *SpiffeIDService
Ixn *structs.Intention
Auth bool
Match bool
}{
{
"exact source, not matching namespace",
serviceWeb,
&structs.Intention{
SourceNS: "different",
SourceName: "db",
},
false,
false,
},
{
"exact source, not matching name",
serviceWeb,
&structs.Intention{
SourceNS: ns,
SourceName: "db",
},
false,
false,
},
{
"exact source, allow",
serviceWeb,
&structs.Intention{
SourceNS: serviceWeb.Namespace,
SourceName: serviceWeb.Service,
Action: structs.IntentionActionAllow,
},
true,
true,
},
{
"exact source, deny",
serviceWeb,
&structs.Intention{
SourceNS: serviceWeb.Namespace,
SourceName: serviceWeb.Service,
Action: structs.IntentionActionDeny,
},
false,
true,
},
{
"exact namespace, wildcard service, deny",
serviceWeb,
&structs.Intention{
SourceNS: serviceWeb.Namespace,
SourceName: structs.IntentionWildcard,
Action: structs.IntentionActionDeny,
},
false,
true,
},
{
"exact namespace, wildcard service, allow",
serviceWeb,
&structs.Intention{
SourceNS: serviceWeb.Namespace,
SourceName: structs.IntentionWildcard,
Action: structs.IntentionActionAllow,
},
true,
true,
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
auth, match := tc.URI.Authorize(tc.Ixn)
assert.Equal(t, tc.Auth, auth)
assert.Equal(t, tc.Match, match)
})
}
}

View File

@ -0,0 +1,75 @@
package connect
import (
"fmt"
"net/url"
"strings"
"github.com/hashicorp/consul/agent/structs"
)
// SpiffeIDSigning is the structure to represent the SPIFFE ID for a
// signing certificate (not a leaf service).
type SpiffeIDSigning struct {
ClusterID string // Unique cluster ID
Domain string // The domain, usually "consul"
}
// URI returns the *url.URL for this SPIFFE ID.
func (id *SpiffeIDSigning) URI() *url.URL {
var result url.URL
result.Scheme = "spiffe"
result.Host = id.Host()
return &result
}
// Host is the canonical representation as a DNS-compatible hostname.
func (id *SpiffeIDSigning) Host() string {
return strings.ToLower(fmt.Sprintf("%s.%s", id.ClusterID, id.Domain))
}
// CertURI impl.
func (id *SpiffeIDSigning) Authorize(ixn *structs.Intention) (bool, bool) {
// Never authorize as a client.
return false, true
}
// CanSign takes any CertURI and returns whether or not this signing entity is
// allowed to sign CSRs for that entity (i.e. represents the trust domain for
// that entity).
//
// I choose to make this a fixed centralised method here for now rather than a
// method on CertURI interface since we don't intend this to be extensible
// outside and it's easier to reason about the security properties when they are
// all in one place with "whitelist" semantics.
func (id *SpiffeIDSigning) CanSign(cu CertURI) bool {
switch other := cu.(type) {
case *SpiffeIDSigning:
// We can only sign other CA certificates for the same trust domain. Note
// that we could open this up later for example to support external
// federation of roots and cross-signing external roots that have different
// URI structure but it's simpler to start off restrictive.
return id == other
case *SpiffeIDService:
// The host component of the service must be an exact match for now under
// ascii case folding (since hostnames are case-insensitive). Later we might
// worry about Unicode domains if we start allowing customisation beyond the
// built-in cluster ids.
return strings.ToLower(other.Host) == id.Host()
default:
return false
}
}
// SpiffeIDSigningForCluster returns the SPIFFE signing identifier (trust
// domain) representation of the given CA config. If config is nil this function
// will panic.
//
// NOTE(banks): we intentionally fix the tld `.consul` for now rather than tie
// this to the `domain` config used for DNS because changing DNS domain can't
// break all certificate validation. That does mean that DNS prefix might not
// match the identity URIs and so the trust domain might not actually resolve
// which we would like but don't actually need.
func SpiffeIDSigningForCluster(config *structs.CAConfiguration) *SpiffeIDSigning {
return &SpiffeIDSigning{ClusterID: config.ClusterID, Domain: "consul"}
}

View File

@ -0,0 +1,123 @@
package connect
import (
"net/url"
"strings"
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
)
// Signing ID should never authorize
func TestSpiffeIDSigningAuthorize(t *testing.T) {
var id SpiffeIDSigning
auth, ok := id.Authorize(nil)
assert.False(t, auth)
assert.True(t, ok)
}
func TestSpiffeIDSigningForCluster(t *testing.T) {
// For now it should just append .consul to the ID.
config := &structs.CAConfiguration{
ClusterID: TestClusterID,
}
id := SpiffeIDSigningForCluster(config)
assert.Equal(t, id.URI().String(), "spiffe://"+TestClusterID+".consul")
}
// fakeCertURI is a CertURI implementation that our implementation doesn't know
// about
type fakeCertURI string
func (f fakeCertURI) Authorize(*structs.Intention) (auth bool, match bool) {
return false, false
}
func (f fakeCertURI) URI() *url.URL {
u, _ := url.Parse(string(f))
return u
}
func TestSpiffeIDSigning_CanSign(t *testing.T) {
testSigning := &SpiffeIDSigning{
ClusterID: TestClusterID,
Domain: "consul",
}
tests := []struct {
name string
id *SpiffeIDSigning
input CertURI
want bool
}{
{
name: "same signing ID",
id: testSigning,
input: testSigning,
want: true,
},
{
name: "other signing ID",
id: testSigning,
input: &SpiffeIDSigning{
ClusterID: "fakedomain",
Domain: "consul",
},
want: false,
},
{
name: "different TLD signing ID",
id: testSigning,
input: &SpiffeIDSigning{
ClusterID: TestClusterID,
Domain: "evil",
},
want: false,
},
{
name: "nil",
id: testSigning,
input: nil,
want: false,
},
{
name: "unrecognised CertURI implementation",
id: testSigning,
input: fakeCertURI("spiffe://foo.bar/baz"),
want: false,
},
{
name: "service - good",
id: testSigning,
input: &SpiffeIDService{TestClusterID + ".consul", "default", "dc1", "web"},
want: true,
},
{
name: "service - good midex case",
id: testSigning,
input: &SpiffeIDService{strings.ToUpper(TestClusterID) + ".CONsuL", "defAUlt", "dc1", "WEB"},
want: true,
},
{
name: "service - different cluster",
id: testSigning,
input: &SpiffeIDService{"55555555-4444-3333-2222-111111111111.consul", "default", "dc1", "web"},
want: false,
},
{
name: "service - different TLD",
id: testSigning,
input: &SpiffeIDService{TestClusterID + ".fake", "default", "dc1", "web"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.id.CanSign(tt.input)
assert.Equal(t, tt.want, got)
})
}
}

82
agent/connect/uri_test.go Normal file
View File

@ -0,0 +1,82 @@
package connect
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
// testCertURICases contains the test cases for parsing and encoding
// the SPIFFE IDs. This is a global since it is used in multiple test functions.
var testCertURICases = []struct {
Name string
URI string
Struct interface{}
ParseError string
}{
{
"invalid scheme",
"http://google.com/",
nil,
"scheme",
},
{
"basic service ID",
"spiffe://1234.consul/ns/default/dc/dc01/svc/web",
&SpiffeIDService{
Host: "1234.consul",
Namespace: "default",
Datacenter: "dc01",
Service: "web",
},
"",
},
{
"service with URL-encoded values",
"spiffe://1234.consul/ns/foo%2Fbar/dc/bar%2Fbaz/svc/baz%2Fqux",
&SpiffeIDService{
Host: "1234.consul",
Namespace: "foo/bar",
Datacenter: "bar/baz",
Service: "baz/qux",
},
"",
},
{
"signing ID",
"spiffe://1234.consul",
&SpiffeIDSigning{
ClusterID: "1234",
Domain: "consul",
},
"",
},
}
func TestParseCertURI(t *testing.T) {
for _, tc := range testCertURICases {
t.Run(tc.Name, func(t *testing.T) {
assert := assert.New(t)
// Parse the URI, should always be valid
uri, err := url.Parse(tc.URI)
assert.Nil(err)
// Parse the ID and check the error/return value
actual, err := ParseCertURI(uri)
if err != nil {
t.Logf("parse error: %s", err.Error())
}
assert.Equal(tc.ParseError != "", err != nil, "error value")
if err != nil {
assert.Contains(err.Error(), tc.ParseError)
return
}
assert.Equal(tc.Struct, actual)
})
}
}

View File

@ -0,0 +1,98 @@
package agent
import (
"fmt"
"net/http"
"github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/structs"
)
// GET /v1/connect/ca/roots
func (s *HTTPServer) ConnectCARoots(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var args structs.DCSpecificRequest
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
var reply structs.IndexedCARoots
defer setMeta(resp, &reply.QueryMeta)
if err := s.agent.RPC("ConnectCA.Roots", &args, &reply); err != nil {
return nil, err
}
return reply, nil
}
// /v1/connect/ca/configuration
func (s *HTTPServer) ConnectCAConfiguration(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
switch req.Method {
case "GET":
return s.ConnectCAConfigurationGet(resp, req)
case "PUT":
return s.ConnectCAConfigurationSet(resp, req)
default:
return nil, MethodNotAllowedError{req.Method, []string{"GET", "POST"}}
}
}
// GEt /v1/connect/ca/configuration
func (s *HTTPServer) ConnectCAConfigurationGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in ConnectCAConfiguration
var args structs.DCSpecificRequest
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
var reply structs.CAConfiguration
err := s.agent.RPC("ConnectCA.ConfigurationGet", &args, &reply)
if err != nil {
return nil, err
}
fixupConfig(&reply)
return reply, nil
}
// PUT /v1/connect/ca/configuration
func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in ConnectCAConfiguration
var args structs.CARequest
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req, &args.Config, nil); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
}
var reply interface{}
err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply)
return nil, err
}
// A hack to fix up the config types inside of the map[string]interface{}
// so that they get formatted correctly during json.Marshal. Without this,
// string values that get converted to []uint8 end up getting output back
// to the user in base64-encoded form.
func fixupConfig(conf *structs.CAConfiguration) {
for k, v := range conf.Config {
if raw, ok := v.([]uint8); ok {
strVal := ca.Uint8ToString(raw)
conf.Config[k] = strVal
switch conf.Provider {
case structs.ConsulCAProvider:
if k == "PrivateKey" && strVal != "" {
conf.Config["PrivateKey"] = "hidden"
}
case structs.VaultCAProvider:
if k == "Token" && strVal != "" {
conf.Config["Token"] = "hidden"
}
}
}
}
}

View File

@ -0,0 +1,115 @@
package agent
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
)
func TestConnectCARoots_empty(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "connect { enabled = false }")
defer a.Shutdown()
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCARoots(resp, req)
assert.Nil(err)
value := obj.(structs.IndexedCARoots)
assert.Equal(value.ActiveRootID, "")
assert.Len(value.Roots, 0)
}
func TestConnectCARoots_list(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Set some CAs. Note that NewTestAgent already bootstraps one CA so this just
// adds a second and makes it active.
ca2 := connect.TestCAConfigSet(t, a, nil)
// List
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCARoots(resp, req)
assert.NoError(err)
value := obj.(structs.IndexedCARoots)
assert.Equal(value.ActiveRootID, ca2.ID)
assert.Len(value.Roots, 2)
// We should never have the secret information
for _, r := range value.Roots {
assert.Equal("", r.SigningCert)
assert.Equal("", r.SigningKey)
}
}
func TestConnectCAConfig(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
expected := &structs.ConsulCAProviderConfig{
RotationPeriod: 90 * 24 * time.Hour,
}
// Get the initial config.
{
req, _ := http.NewRequest("GET", "/v1/connect/ca/configuration", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCAConfiguration(resp, req)
assert.NoError(err)
value := obj.(structs.CAConfiguration)
parsed, err := ca.ParseConsulCAConfig(value.Config)
assert.NoError(err)
assert.Equal("consul", value.Provider)
assert.Equal(expected, parsed)
}
// Set the config.
{
body := bytes.NewBuffer([]byte(`
{
"Provider": "consul",
"Config": {
"RotationPeriod": 3600000000000
}
}`))
req, _ := http.NewRequest("PUT", "/v1/connect/ca/configuration", body)
resp := httptest.NewRecorder()
_, err := a.srv.ConnectCAConfiguration(resp, req)
assert.NoError(err)
}
// The config should be updated now.
{
expected.RotationPeriod = time.Hour
req, _ := http.NewRequest("GET", "/v1/connect/ca/configuration", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.ConnectCAConfiguration(resp, req)
assert.NoError(err)
value := obj.(structs.CAConfiguration)
parsed, err := ca.ParseConsulCAConfig(value.Config)
assert.NoError(err)
assert.Equal("consul", value.Provider)
assert.Equal(expected, parsed)
}
}

View File

@ -454,6 +454,33 @@ func (f *aclFilter) filterCoordinates(coords *structs.Coordinates) {
*coords = c
}
// filterIntentions is used to filter intentions based on ACL rules.
// We prune entries the user doesn't have access to, and we redact any tokens
// if the user doesn't have a management token.
func (f *aclFilter) filterIntentions(ixns *structs.Intentions) {
// Management tokens can see everything with no filtering.
if f.acl.ACLList() {
return
}
// Otherwise, we need to see what the token has access to.
ret := make(structs.Intentions, 0, len(*ixns))
for _, ixn := range *ixns {
// If no prefix ACL applies to this then filter it, since
// we know at this point the user doesn't have a management
// token, otherwise see what the policy says.
prefix, ok := ixn.GetACLPrefix()
if !ok || !f.acl.IntentionRead(prefix) {
f.logger.Printf("[DEBUG] consul: dropping intention %q from result due to ACLs", ixn.ID)
continue
}
ret = append(ret, ixn)
}
*ixns = ret
}
// filterNodeDump is used to filter through all parts of a node dump and
// remove elements the provided ACL token cannot access.
func (f *aclFilter) filterNodeDump(dump *structs.NodeDump) {
@ -598,6 +625,9 @@ func (s *Server) filterACL(token string, subj interface{}) error {
case *structs.IndexedHealthChecks:
filt.filterHealthChecks(&v.HealthChecks)
case *structs.IndexedIntentions:
filt.filterIntentions(&v.Intentions)
case *structs.IndexedNodeDump:
filt.filterNodeDump(&v.Dump)

View File

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/testutil/retry"
"github.com/stretchr/testify/assert"
)
var testACLPolicy = `
@ -847,6 +848,58 @@ node "node1" {
}
}
func TestACL_filterIntentions(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fill := func() structs.Intentions {
return structs.Intentions{
&structs.Intention{
ID: "f004177f-2c28-83b7-4229-eacc25fe55d1",
DestinationName: "bar",
},
&structs.Intention{
ID: "f004177f-2c28-83b7-4229-eacc25fe55d2",
DestinationName: "foo",
},
}
}
// Try permissive filtering.
{
ixns := fill()
filt := newACLFilter(acl.AllowAll(), nil, false)
filt.filterIntentions(&ixns)
assert.Len(ixns, 2)
}
// Try restrictive filtering.
{
ixns := fill()
filt := newACLFilter(acl.DenyAll(), nil, false)
filt.filterIntentions(&ixns)
assert.Len(ixns, 0)
}
// Policy to see one
policy, err := acl.Parse(`
service "foo" {
policy = "read"
}
`, nil)
assert.Nil(err)
perms, err := acl.New(acl.DenyAll(), policy, nil)
assert.Nil(err)
// Filter
{
ixns := fill()
filt := newACLFilter(perms, nil, false)
filt.filterIntentions(&ixns)
assert.Len(ixns, 1)
}
}
func TestACL_filterServices(t *testing.T) {
t.Parallel()
// Create some services

View File

@ -47,6 +47,13 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
// Handle a service registration.
if args.Service != nil {
// Validate the service. This is in addition to the below since
// the above just hasn't been moved over yet. We should move it over
// in time.
if err := args.Service.Validate(); err != nil {
return err
}
// If no service id, but service name, use default
if args.Service.ID == "" && args.Service.Service != "" {
args.Service.ID = args.Service.Service
@ -73,6 +80,13 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
return acl.ErrPermissionDenied
}
}
// Proxies must have write permission on their destination
if args.Service.Kind == structs.ServiceKindConnectProxy {
if rule != nil && !rule.ServiceWrite(args.Service.ProxyDestination, nil) {
return acl.ErrPermissionDenied
}
}
}
// Move the old format single check into the slice, and fixup IDs.
@ -244,24 +258,52 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
return fmt.Errorf("Must provide service name")
}
// Determine the function we'll call
var f func(memdb.WatchSet, *state.Store) (uint64, structs.ServiceNodes, error)
switch {
case args.Connect:
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
return s.ConnectServiceNodes(ws, args.ServiceName)
}
default:
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
if args.ServiceAddress != "" {
return s.ServiceAddressNodes(ws, args.ServiceAddress)
}
if args.TagFilter {
return s.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
}
return s.ServiceNodes(ws, args.ServiceName)
}
}
// If we're doing a connect query, we need read access to the service
// we're trying to find proxies for, so check that.
if args.Connect {
// Fetch the ACL token, if any.
rule, err := c.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.ServiceRead(args.ServiceName) {
// Just return nil, which will return an empty response (tested)
return nil
}
}
err := c.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
var index uint64
var services structs.ServiceNodes
var err error
if args.TagFilter {
index, services, err = state.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
} else {
index, services, err = state.ServiceNodes(ws, args.ServiceName)
}
if args.ServiceAddress != "" {
index, services, err = state.ServiceAddressNodes(ws, args.ServiceAddress)
}
index, services, err := f(ws, state)
if err != nil {
return err
}
reply.Index, reply.ServiceNodes = index, services
if len(args.NodeMetaFilters) > 0 {
var filtered structs.ServiceNodes
@ -280,17 +322,24 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
// Provide some metrics
if err == nil {
metrics.IncrCounterWithLabels([]string{"catalog", "service", "query"}, 1,
// For metrics, we separate Connect-based lookups from non-Connect
key := "service"
if args.Connect {
key = "connect"
}
metrics.IncrCounterWithLabels([]string{"catalog", key, "query"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
if args.ServiceTag != "" {
metrics.IncrCounterWithLabels([]string{"catalog", "service", "query-tag"}, 1,
metrics.IncrCounterWithLabels([]string{"catalog", key, "query-tag"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}})
}
if len(reply.ServiceNodes) == 0 {
metrics.IncrCounterWithLabels([]string{"catalog", "service", "not-found"}, 1,
metrics.IncrCounterWithLabels([]string{"catalog", key, "not-found"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
}
}
return err
}

View File

@ -16,6 +16,8 @@ import (
"github.com/hashicorp/consul/testutil/retry"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCatalog_Register(t *testing.T) {
@ -332,6 +334,147 @@ func TestCatalog_Register_ForwardDC(t *testing.T) {
}
}
func TestCatalog_Register_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
args := structs.TestRegisterRequestProxy(t)
// Register
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// List
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
}
// Test an invalid ConnectProxy. We don't need to exhaustively test because
// this is all tested in structs on the Validate method.
func TestCatalog_Register_ConnectProxy_invalid(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
args := structs.TestRegisterRequestProxy(t)
args.Service.ProxyDestination = ""
// Register
var out struct{}
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.NotNil(err)
assert.Contains(err.Error(), "ProxyDestination")
}
// Test that write is required for the proxy destination to register a proxy.
func TestCatalog_Register_ConnectProxy_ACLProxyDestination(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLMasterToken = "root"
c.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create the ACL.
arg := structs.ACLRequest{
Datacenter: "dc1",
Op: structs.ACLSet,
ACL: structs.ACL{
Name: "User token",
Type: structs.ACLTypeClient,
Rules: `
service "foo" {
policy = "write"
}
`,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var token string
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &token))
// Register should fail because we don't have permission on the destination
args := structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo"
args.Service.ProxyDestination = "bar"
args.WriteRequest.Token = token
var out struct{}
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.True(acl.IsErrPermissionDenied(err))
// Register should fail with the right destination but wrong name
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "bar"
args.Service.ProxyDestination = "foo"
args.WriteRequest.Token = token
err = msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
assert.True(acl.IsErrPermissionDenied(err))
// Register should work with the right destination
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo"
args.Service.ProxyDestination = "foo"
args.WriteRequest.Token = token
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
}
func TestCatalog_Register_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true
// Register
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// List
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindTypical, v.ServiceKind)
assert.True(v.ServiceConnect.Native)
}
func TestCatalog_Deregister(t *testing.T) {
t.Parallel()
dir1, s1 := testServer(t)
@ -1599,6 +1742,246 @@ func TestCatalog_ListServiceNodes_DistanceSort(t *testing.T) {
}
}
func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
TagFilter: false,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
}
func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the proxy service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// Register the service
{
dst := args.Service.ProxyDestination
args := structs.TestRegisterRequest(t)
args.Service.Service = dst
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
}
// List
req := structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: args.Service.ProxyDestination,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
// List by non-Connect
req = structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.ProxyDestination,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v = resp.ServiceNodes[0]
assert.Equal(args.Service.ProxyDestination, v.ServiceName)
assert.Equal("", v.ServiceProxyDestination)
}
// Test that calling ServiceNodes with Connect: true will return
// Connect native services.
func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the native service
args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true
var out struct{}
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
var resp structs.IndexedServiceNodes
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
require.Equal(args.Service.Service, v.ServiceName)
// List by non-Connect
req = structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
}
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
require.Len(resp.ServiceNodes, 1)
v = resp.ServiceNodes[0]
require.Equal(args.Service.Service, v.ServiceName)
}
func TestCatalog_ListServiceNodes_ConnectProxy_ACL(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLMasterToken = "root"
c.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create the ACL.
arg := structs.ACLRequest{
Datacenter: "dc1",
Op: structs.ACLSet,
ACL: structs.ACL{
Name: "User token",
Type: structs.ACLTypeClient,
Rules: `
service "foo" {
policy = "write"
}
`,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var token string
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &token))
{
// Register a proxy
args := structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo-proxy"
args.Service.ProxyDestination = "bar"
args.WriteRequest.Token = "root"
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a proxy
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "foo-proxy"
args.Service.ProxyDestination = "foo"
args.WriteRequest.Token = "root"
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a proxy
args = structs.TestRegisterRequestProxy(t)
args.Service.Service = "another-proxy"
args.Service.ProxyDestination = "foo"
args.WriteRequest.Token = "root"
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
}
// List w/ token. This should disallow because we don't have permission
// to read "bar"
req := structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: "bar",
QueryOptions: structs.QueryOptions{Token: token},
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 0)
// List w/ token. This should work since we're requesting "foo", but should
// also only contain the proxies with names that adhere to our ACL.
req = structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: "foo",
QueryOptions: structs.QueryOptions{Token: token},
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal("foo-proxy", v.ServiceName)
}
func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequest(t)
args.Service.Connect.Native = true
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.ServiceSpecificRequest{
Datacenter: "dc1",
ServiceName: args.Service.Service,
TagFilter: false,
}
var resp structs.IndexedServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
assert.Len(resp.ServiceNodes, 1)
v := resp.ServiceNodes[0]
assert.Equal(args.Service.Connect.Native, v.ServiceConnect.Native)
}
func TestCatalog_NodeServices(t *testing.T) {
t.Parallel()
dir1, s1 := testServer(t)
@ -1649,6 +2032,67 @@ func TestCatalog_NodeServices(t *testing.T) {
}
}
func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.NodeSpecificRequest{
Datacenter: "dc1",
Node: args.Node,
}
var resp structs.IndexedNodeServices
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Len(resp.NodeServices.Services, 1)
v := resp.NodeServices.Services[args.Service.Service]
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
assert.Equal(args.Service.ProxyDestination, v.ProxyDestination)
}
func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequest(t)
var out struct{}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
// List
req := structs.NodeSpecificRequest{
Datacenter: "dc1",
Node: args.Node,
}
var resp structs.IndexedNodeServices
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
assert.Len(resp.NodeServices.Services, 1)
v := resp.NodeServices.Services[args.Service.Service]
assert.Equal(args.Service.Connect.Native, v.Connect.Native)
}
// Used to check for a regression against a known bug
func TestCatalog_Register_FailedCase1(t *testing.T) {
t.Parallel()

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/consul/types"
@ -346,6 +347,13 @@ type Config struct {
// autopilot tasks, such as promoting eligible non-voters and removing
// dead servers.
AutopilotInterval time.Duration
// ConnectEnabled is whether to enable Connect features such as the CA.
ConnectEnabled bool
// CAConfig is used to apply the initial Connect CA configuration when
// bootstrapping.
CAConfig *structs.CAConfiguration
}
// CheckProtocolVersion validates the protocol version.
@ -425,6 +433,13 @@ func DefaultConfig() *Config {
ServerStabilizationTime: 10 * time.Second,
},
CAConfig: &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"RotationPeriod": "2160h",
},
},
ServerHealthInterval: 2 * time.Second,
AutopilotInterval: 10 * time.Second,
}

View File

@ -0,0 +1,393 @@
package consul
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
)
var ErrConnectNotEnabled = errors.New("Connect must be enabled in order to use this endpoint")
// ConnectCA manages the Connect CA.
type ConnectCA struct {
// srv is a pointer back to the server.
srv *Server
}
// ConfigurationGet returns the configuration for the CA.
func (s *ConnectCA) ConfigurationGet(
args *structs.DCSpecificRequest,
reply *structs.CAConfiguration) error {
// Exit early if Connect hasn't been enabled.
if !s.srv.config.ConnectEnabled {
return ErrConnectNotEnabled
}
if done, err := s.srv.forward("ConnectCA.ConfigurationGet", args, args, reply); done {
return err
}
// This action requires operator read access.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.OperatorRead() {
return acl.ErrPermissionDenied
}
state := s.srv.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return err
}
*reply = *config
return nil
}
// ConfigurationSet updates the configuration for the CA.
func (s *ConnectCA) ConfigurationSet(
args *structs.CARequest,
reply *interface{}) error {
// Exit early if Connect hasn't been enabled.
if !s.srv.config.ConnectEnabled {
return ErrConnectNotEnabled
}
if done, err := s.srv.forward("ConnectCA.ConfigurationSet", args, args, reply); done {
return err
}
// This action requires operator write access.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.OperatorWrite() {
return acl.ErrPermissionDenied
}
// Exit early if it's a no-op change
state := s.srv.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return err
}
args.Config.ClusterID = config.ClusterID
if args.Config.Provider == config.Provider && reflect.DeepEqual(args.Config.Config, config.Config) {
return nil
}
// Create a new instance of the provider described by the config
// and get the current active root CA. This acts as a good validation
// of the config and makes sure the provider is functioning correctly
// before we commit any changes to Raft.
newProvider, err := s.srv.createCAProvider(args.Config)
if err != nil {
return fmt.Errorf("could not initialize provider: %v", err)
}
newRootPEM, err := newProvider.ActiveRoot()
if err != nil {
return err
}
newActiveRoot, err := parseCARoot(newRootPEM, args.Config.Provider)
if err != nil {
return err
}
// Compare the new provider's root CA ID to the current one. If they
// match, just update the existing provider with the new config.
// If they don't match, begin the root rotation process.
_, root, err := state.CARootActive(nil)
if err != nil {
return err
}
if root != nil && root.ID == newActiveRoot.ID {
args.Op = structs.CAOpSetConfig
resp, err := s.srv.raftApply(structs.ConnectCARequestType, args)
if err != nil {
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
// If the config has been committed, update the local provider instance
s.srv.setCAProvider(newProvider, newActiveRoot)
s.srv.logger.Printf("[INFO] connect: CA provider config updated")
return nil
}
// At this point, we know the config change has trigged a root rotation,
// either by swapping the provider type or changing the provider's config
// to use a different root certificate.
// If it's a config change that would trigger a rotation (different provider/root):
// 1. Get the root from the new provider.
// 2. Call CrossSignCA on the old provider to sign the new root with the old one to
// get a cross-signed certificate.
// 3. Take the active root for the new provider and append the intermediate from step 2
// to its list of intermediates.
newRoot, err := connect.ParseCert(newRootPEM)
if err != nil {
return err
}
// Have the old provider cross-sign the new intermediate
oldProvider, _ := s.srv.getCAProvider()
if oldProvider == nil {
return fmt.Errorf("internal error: CA provider is nil")
}
xcCert, err := oldProvider.CrossSignCA(newRoot)
if err != nil {
return err
}
// Add the cross signed cert to the new root's intermediates.
newActiveRoot.IntermediateCerts = []string{xcCert}
intermediate, err := newProvider.GenerateIntermediate()
if err != nil {
return err
}
if intermediate != newRootPEM {
newActiveRoot.IntermediateCerts = append(newActiveRoot.IntermediateCerts, intermediate)
}
// Update the roots and CA config in the state store at the same time
idx, roots, err := state.CARoots(nil)
if err != nil {
return err
}
var newRoots structs.CARoots
for _, r := range roots {
newRoot := *r
if newRoot.Active {
newRoot.Active = false
}
newRoots = append(newRoots, &newRoot)
}
newRoots = append(newRoots, newActiveRoot)
args.Op = structs.CAOpSetRootsAndConfig
args.Index = idx
args.Roots = newRoots
resp, err := s.srv.raftApply(structs.ConnectCARequestType, args)
if err != nil {
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
// If the config has been committed, update the local provider instance
// and call teardown on the old provider
s.srv.setCAProvider(newProvider, newActiveRoot)
if err := oldProvider.Cleanup(); err != nil {
s.srv.logger.Printf("[WARN] connect: failed to clean up old provider %q", config.Provider)
}
s.srv.logger.Printf("[INFO] connect: CA rotated to new root under provider %q", args.Config.Provider)
return nil
}
// Roots returns the currently trusted root certificates.
func (s *ConnectCA) Roots(
args *structs.DCSpecificRequest,
reply *structs.IndexedCARoots) error {
// Forward if necessary
if done, err := s.srv.forward("ConnectCA.Roots", args, args, reply); done {
return err
}
// Load the ClusterID to generate TrustDomain. We do this outside the loop
// since by definition this value should be immutable once set for lifetime of
// the cluster so we don't need to look it up more than once. We also don't
// have to worry about non-atomicity between the config fetch transaction and
// the CARoots transaction below since this field must remain immutable. Do
// not re-use this state/config for other logic that might care about changes
// of config during the blocking query below.
{
state := s.srv.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return err
}
// Check CA is actually bootstrapped...
if config != nil {
// Build TrustDomain based on the ClusterID stored.
signingID := connect.SpiffeIDSigningForCluster(config)
if signingID == nil {
// If CA is bootstrapped at all then this should never happen but be
// defensive.
return errors.New("no cluster trust domain setup")
}
reply.TrustDomain = signingID.Host()
}
}
return s.srv.blockingQuery(
&args.QueryOptions, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, roots, err := state.CARoots(ws)
if err != nil {
return err
}
reply.Index, reply.Roots = index, roots
if reply.Roots == nil {
reply.Roots = make(structs.CARoots, 0)
}
// The API response must NEVER contain the secret information
// such as keys and so on. We use a whitelist below to copy the
// specific fields we want to expose.
for i, r := range reply.Roots {
// IMPORTANT: r must NEVER be modified, since it is a pointer
// directly to the structure in the memdb store.
reply.Roots[i] = &structs.CARoot{
ID: r.ID,
Name: r.Name,
SerialNumber: r.SerialNumber,
SigningKeyID: r.SigningKeyID,
NotBefore: r.NotBefore,
NotAfter: r.NotAfter,
RootCert: r.RootCert,
IntermediateCerts: r.IntermediateCerts,
RaftIndex: r.RaftIndex,
Active: r.Active,
}
if r.Active {
reply.ActiveRootID = r.ID
}
}
return nil
},
)
}
// Sign signs a certificate for a service.
func (s *ConnectCA) Sign(
args *structs.CASignRequest,
reply *structs.IssuedCert) error {
// Exit early if Connect hasn't been enabled.
if !s.srv.config.ConnectEnabled {
return ErrConnectNotEnabled
}
if done, err := s.srv.forward("ConnectCA.Sign", args, args, reply); done {
return err
}
// Parse the CSR
csr, err := connect.ParseCSR(args.CSR)
if err != nil {
return err
}
// Parse the SPIFFE ID
spiffeID, err := connect.ParseCertURI(csr.URIs[0])
if err != nil {
return err
}
serviceID, ok := spiffeID.(*connect.SpiffeIDService)
if !ok {
return fmt.Errorf("SPIFFE ID in CSR must be a service ID")
}
provider, caRoot := s.srv.getCAProvider()
if provider == nil {
return fmt.Errorf("internal error: CA provider is nil")
}
// Verify that the CSR entity is in the cluster's trust domain
state := s.srv.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return err
}
signingID := connect.SpiffeIDSigningForCluster(config)
if !signingID.CanSign(serviceID) {
return fmt.Errorf("SPIFFE ID in CSR from a different trust domain: %s, "+
"we are %s", serviceID.Host, signingID.Host())
}
// Verify that the ACL token provided has permission to act as this service
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.ServiceWrite(serviceID.Service, nil) {
return acl.ErrPermissionDenied
}
// Verify that the DC in the service URI matches us. We might relax this
// requirement later but being restrictive for now is safer.
if serviceID.Datacenter != s.srv.config.Datacenter {
return fmt.Errorf("SPIFFE ID in CSR from a different datacenter: %s, "+
"we are %s", serviceID.Datacenter, s.srv.config.Datacenter)
}
// All seems to be in order, actually sign it.
pem, err := provider.Sign(csr)
if err != nil {
return err
}
// Append any intermediates needed by this root.
for _, p := range caRoot.IntermediateCerts {
pem = strings.TrimSpace(pem) + "\n" + p
}
// TODO(banks): when we implement IssuedCerts table we can use the insert to
// that as the raft index to return in response. Right now we can rely on only
// the built-in provider being supported and the implementation detail that we
// have to write a SerialIndex update to the provider config table for every
// cert issued so in all cases this index will be higher than any previous
// sign response. This has to be reloaded after the provider.Sign call to
// observe the index update.
state = s.srv.fsm.State()
modIdx, _, err := state.CAConfig()
if err != nil {
return err
}
cert, err := connect.ParseCert(pem)
if err != nil {
return err
}
// Set the response
*reply = structs.IssuedCert{
SerialNumber: connect.HexString(cert.SerialNumber.Bytes()),
CertPEM: pem,
Service: serviceID.Service,
ServiceURI: cert.URIs[0].String(),
ValidAfter: cert.NotBefore,
ValidBefore: cert.NotAfter,
RaftIndex: structs.RaftIndex{
ModifyIndex: modIdx,
CreateIndex: modIdx,
},
}
return nil
}

View File

@ -0,0 +1,434 @@
package consul
import (
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/connect"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/stretchr/testify/assert"
)
func testParseCert(t *testing.T, pemValue string) *x509.Certificate {
cert, err := connect.ParseCert(pemValue)
if err != nil {
t.Fatal(err)
}
return cert
}
// Test listing root CAs.
func TestConnectCARoots(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Insert some CAs
state := s1.fsm.State()
ca1 := connect.TestCA(t, nil)
ca2 := connect.TestCA(t, nil)
ca2.Active = false
idx, _, err := state.CARoots(nil)
require.NoError(err)
ok, err := state.CARootSetCAS(idx, idx, []*structs.CARoot{ca1, ca2})
assert.True(ok)
require.NoError(err)
_, caCfg, err := state.CAConfig()
require.NoError(err)
// Request
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.IndexedCARoots
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
// Verify
assert.Equal(ca1.ID, reply.ActiveRootID)
assert.Len(reply.Roots, 2)
for _, r := range reply.Roots {
// These must never be set, for security
assert.Equal("", r.SigningCert)
assert.Equal("", r.SigningKey)
}
assert.Equal(fmt.Sprintf("%s.consul", caCfg.ClusterID), reply.TrustDomain)
}
func TestConnectCAConfig_GetSet(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Get the starting config
{
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.CAConfiguration
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
assert.NoError(err)
expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config)
assert.NoError(err)
assert.Equal(reply.Provider, s1.config.CAConfig.Provider)
assert.Equal(actual, expected)
}
// Update a config value
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": "",
"RootCert": "",
"RotationPeriod": 180 * 24 * time.Hour,
},
}
{
args := &structs.CARequest{
Datacenter: "dc1",
Config: newConfig,
}
var reply interface{}
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Verify the new config was set
{
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.CAConfiguration
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
assert.NoError(err)
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
assert.NoError(err)
assert.Equal(reply.Provider, newConfig.Provider)
assert.Equal(actual, expected)
}
}
func TestConnectCAConfig_TriggerRotation(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Store the current root
rootReq := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var rootList structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
assert.Len(rootList.Roots, 1)
oldRoot := rootList.Roots[0]
// Update the provider config to use a new private key, which should
// cause a rotation.
_, newKey, err := connect.GeneratePrivateKey()
assert.NoError(err)
newConfig := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": newKey,
"RootCert": "",
"RotationPeriod": 90 * 24 * time.Hour,
},
}
{
args := &structs.CARequest{
Datacenter: "dc1",
Config: newConfig,
}
var reply interface{}
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
}
// Make sure the new root has been added along with an intermediate
// cross-signed by the old root.
var newRootPEM string
{
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.IndexedCARoots
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
assert.Len(reply.Roots, 2)
for _, r := range reply.Roots {
if r.ID == oldRoot.ID {
// The old root should no longer be marked as the active root,
// and none of its other fields should have changed.
assert.False(r.Active)
assert.Equal(r.Name, oldRoot.Name)
assert.Equal(r.RootCert, oldRoot.RootCert)
assert.Equal(r.SigningCert, oldRoot.SigningCert)
assert.Equal(r.IntermediateCerts, oldRoot.IntermediateCerts)
} else {
newRootPEM = r.RootCert
// The new root should have a valid cross-signed cert from the old
// root as an intermediate.
assert.True(r.Active)
assert.Len(r.IntermediateCerts, 1)
xc := testParseCert(t, r.IntermediateCerts[0])
oldRootCert := testParseCert(t, oldRoot.RootCert)
newRootCert := testParseCert(t, r.RootCert)
// Should have the authority key ID and signature algo of the
// (old) signing CA.
assert.Equal(xc.AuthorityKeyId, oldRootCert.AuthorityKeyId)
assert.NotEqual(xc.SubjectKeyId, oldRootCert.SubjectKeyId)
assert.Equal(xc.SignatureAlgorithm, oldRootCert.SignatureAlgorithm)
// The common name and SAN should not have changed.
assert.Equal(xc.Subject.CommonName, newRootCert.Subject.CommonName)
assert.Equal(xc.URIs, newRootCert.URIs)
}
}
}
// Verify the new config was set.
{
args := &structs.DCSpecificRequest{
Datacenter: "dc1",
}
var reply structs.CAConfiguration
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
actual, err := ca.ParseConsulCAConfig(reply.Config)
require.NoError(err)
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
require.NoError(err)
assert.Equal(reply.Provider, newConfig.Provider)
assert.Equal(actual, expected)
}
// Verify that new leaf certs get the cross-signed intermediate bundled
{
// Generate a CSR and request signing
spiffeId := connect.TestSpiffeIDService(t, "web")
csr, _ := connect.TestCSR(t, spiffeId)
args := &structs.CASignRequest{
Datacenter: "dc1",
CSR: csr,
}
var reply structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
// Verify that the cert is signed by the new CA
{
roots := x509.NewCertPool()
require.True(roots.AppendCertsFromPEM([]byte(newRootPEM)))
leaf, err := connect.ParseCert(reply.CertPEM)
require.NoError(err)
_, err = leaf.Verify(x509.VerifyOptions{
Roots: roots,
})
require.NoError(err)
}
// And that it validates via the intermediate
{
roots := x509.NewCertPool()
assert.True(roots.AppendCertsFromPEM([]byte(oldRoot.RootCert)))
leaf, err := connect.ParseCert(reply.CertPEM)
require.NoError(err)
// Make sure the intermediate was returned as well as leaf
_, rest := pem.Decode([]byte(reply.CertPEM))
require.NotEmpty(rest)
intermediates := x509.NewCertPool()
require.True(intermediates.AppendCertsFromPEM(rest))
_, err = leaf.Verify(x509.VerifyOptions{
Roots: roots,
Intermediates: intermediates,
})
require.NoError(err)
}
// Verify other fields
assert.Equal("web", reply.Service)
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
}
}
// Test CA signing
func TestConnectCASign(t *testing.T) {
t.Parallel()
assert := assert.New(t)
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Generate a CSR and request signing
spiffeId := connect.TestSpiffeIDService(t, "web")
csr, _ := connect.TestCSR(t, spiffeId)
args := &structs.CASignRequest{
Datacenter: "dc1",
CSR: csr,
}
var reply structs.IssuedCert
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
// Get the current CA
state := s1.fsm.State()
_, ca, err := state.CARootActive(nil)
require.NoError(err)
// Verify that the cert is signed by the CA
roots := x509.NewCertPool()
assert.True(roots.AppendCertsFromPEM([]byte(ca.RootCert)))
leaf, err := connect.ParseCert(reply.CertPEM)
require.NoError(err)
_, err = leaf.Verify(x509.VerifyOptions{
Roots: roots,
})
require.NoError(err)
// Verify other fields
assert.Equal("web", reply.Service)
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
}
func TestConnectCASignValidation(t *testing.T) {
t.Parallel()
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLMasterToken = "root"
c.ACLDefaultPolicy = "deny"
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create an ACL token with service:write for web*
var webToken string
{
arg := structs.ACLRequest{
Datacenter: "dc1",
Op: structs.ACLSet,
ACL: structs.ACL{
Name: "User token",
Type: structs.ACLTypeClient,
Rules: `
service "web" {
policy = "write"
}`,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &webToken))
}
testWebID := connect.TestSpiffeIDService(t, "web")
tests := []struct {
name string
id connect.CertURI
wantErr string
}{
{
name: "different cluster",
id: &connect.SpiffeIDService{
Host: "55555555-4444-3333-2222-111111111111.consul",
Namespace: testWebID.Namespace,
Datacenter: testWebID.Datacenter,
Service: testWebID.Service,
},
wantErr: "different trust domain",
},
{
name: "same cluster should validate",
id: testWebID,
wantErr: "",
},
{
name: "same cluster, CSR for a different DC should NOT validate",
id: &connect.SpiffeIDService{
Host: testWebID.Host,
Namespace: testWebID.Namespace,
Datacenter: "dc2",
Service: testWebID.Service,
},
wantErr: "different datacenter",
},
{
name: "same cluster and DC, different service should not have perms",
id: &connect.SpiffeIDService{
Host: testWebID.Host,
Namespace: testWebID.Namespace,
Datacenter: testWebID.Datacenter,
Service: "db",
},
wantErr: "Permission denied",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
csr, _ := connect.TestCSR(t, tt.id)
args := &structs.CASignRequest{
Datacenter: "dc1",
CSR: csr,
WriteRequest: structs.WriteRequest{Token: webToken},
}
var reply structs.IssuedCert
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply)
if tt.wantErr == "" {
require.NoError(t, err)
// No other validation that is handled in different tests
} else {
require.Error(t, err)
require.Contains(t, err.Error(), tt.wantErr)
}
})
}
}

View File

@ -0,0 +1,28 @@
package consul
import (
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
)
// consulCADelegate providers callbacks for the Consul CA provider
// to use the state store for its operations.
type consulCADelegate struct {
srv *Server
}
func (c *consulCADelegate) State() *state.Store {
return c.srv.fsm.State()
}
func (c *consulCADelegate) ApplyCARequest(req *structs.CARequest) error {
resp, err := c.srv.raftApply(structs.ConnectCARequestType, req)
if err != nil {
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
return nil
}

View File

@ -20,6 +20,8 @@ func init() {
registerCommand(structs.PreparedQueryRequestType, (*FSM).applyPreparedQueryOperation)
registerCommand(structs.TxnRequestType, (*FSM).applyTxn)
registerCommand(structs.AutopilotRequestType, (*FSM).applyAutopilotUpdate)
registerCommand(structs.IntentionRequestType, (*FSM).applyIntentionOperation)
registerCommand(structs.ConnectCARequestType, (*FSM).applyConnectCAOperation)
}
func (c *FSM) applyRegister(buf []byte, index uint64) interface{} {
@ -246,3 +248,85 @@ func (c *FSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} {
}
return c.state.AutopilotSetConfig(index, &req.Config)
}
// applyIntentionOperation applies the given intention operation to the state store.
func (c *FSM) applyIntentionOperation(buf []byte, index uint64) interface{} {
var req structs.IntentionRequest
if err := structs.Decode(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "intention"}, time.Now(),
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
defer metrics.MeasureSinceWithLabels([]string{"fsm", "intention"}, time.Now(),
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
switch req.Op {
case structs.IntentionOpCreate, structs.IntentionOpUpdate:
return c.state.IntentionSet(index, req.Intention)
case structs.IntentionOpDelete:
return c.state.IntentionDelete(index, req.Intention.ID)
default:
c.logger.Printf("[WARN] consul.fsm: Invalid Intention operation '%s'", req.Op)
return fmt.Errorf("Invalid Intention operation '%s'", req.Op)
}
}
// applyConnectCAOperation applies the given CA operation to the state store.
func (c *FSM) applyConnectCAOperation(buf []byte, index uint64) interface{} {
var req structs.CARequest
if err := structs.Decode(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode request: %v", err))
}
defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "ca"}, time.Now(),
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
defer metrics.MeasureSinceWithLabels([]string{"fsm", "ca"}, time.Now(),
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
switch req.Op {
case structs.CAOpSetConfig:
if req.Config.ModifyIndex != 0 {
act, err := c.state.CACheckAndSetConfig(index, req.Config.ModifyIndex, req.Config)
if err != nil {
return err
}
return act
}
return c.state.CASetConfig(index, req.Config)
case structs.CAOpSetRoots:
act, err := c.state.CARootSetCAS(index, req.Index, req.Roots)
if err != nil {
return err
}
return act
case structs.CAOpSetProviderState:
act, err := c.state.CASetProviderState(index, req.ProviderState)
if err != nil {
return err
}
return act
case structs.CAOpDeleteProviderState:
if err := c.state.CADeleteProviderState(req.ProviderState.ID); err != nil {
return err
}
return true
case structs.CAOpSetRootsAndConfig:
act, err := c.state.CARootSetCAS(index, req.Index, req.Roots)
if err != nil {
return err
}
if err := c.state.CASetConfig(index+1, req.Config); err != nil {
return err
}
return act
default:
c.logger.Printf("[WARN] consul.fsm: Invalid CA operation '%s'", req.Op)
return fmt.Errorf("Invalid CA operation '%s'", req.Op)
}
}

View File

@ -8,13 +8,16 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/serf/coordinate"
"github.com/mitchellh/mapstructure"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func generateUUID() (ret string) {
@ -1148,3 +1151,209 @@ func TestFSM_Autopilot(t *testing.T) {
t.Fatalf("bad: %v", config.CleanupDeadServers)
}
}
func TestFSM_Intention_CRUD(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
assert.Nil(err)
// Create a new intention.
ixn := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
}
ixn.Intention.ID = generateUUID()
ixn.Intention.UpdatePrecedence()
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
}
// Verify it's in the state store.
{
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt
actual.UpdatedAt = ixn.Intention.UpdatedAt
assert.Equal(ixn.Intention, actual)
}
// Make an update
ixn.Op = structs.IntentionOpUpdate
ixn.Intention.SourceName = "api"
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
}
// Verify the update.
{
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
actual.CreateIndex, actual.ModifyIndex = 0, 0
actual.CreatedAt = ixn.Intention.CreatedAt
actual.UpdatedAt = ixn.Intention.UpdatedAt
assert.Equal(ixn.Intention, actual)
}
// Delete
ixn.Op = structs.IntentionOpDelete
{
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
assert.Nil(err)
assert.Nil(fsm.Apply(makeLog(buf)))
}
// Make sure it's gone.
{
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
assert.Nil(err)
assert.Nil(actual)
}
}
func TestFSM_CAConfig(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
assert.Nil(err)
// Set the autopilot config using a request.
req := structs.CARequest{
Op: structs.CAOpSetConfig,
Config: &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": "asdf",
"RootCert": "qwer",
"RotationPeriod": 90 * 24 * time.Hour,
},
},
}
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
resp := fsm.Apply(makeLog(buf))
if _, ok := resp.(error); ok {
t.Fatalf("bad: %v", resp)
}
// Verify key is set directly in the state store.
_, config, err := fsm.state.CAConfig()
if err != nil {
t.Fatalf("err: %v", err)
}
var conf *structs.ConsulCAProviderConfig
if err := mapstructure.WeakDecode(config.Config, &conf); err != nil {
t.Fatalf("error decoding config: %s, %v", err, config.Config)
}
if got, want := config.Provider, req.Config.Provider; got != want {
t.Fatalf("got %v, want %v", got, want)
}
if got, want := conf.PrivateKey, "asdf"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
if got, want := conf.RootCert, "qwer"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
if got, want := conf.RotationPeriod, 90*24*time.Hour; got != want {
t.Fatalf("got %v, want %v", got, want)
}
// Now use CAS and provide an old index
req.Config.Provider = "static"
req.Config.ModifyIndex = config.ModifyIndex - 1
buf, err = structs.Encode(structs.ConnectCARequestType, req)
if err != nil {
t.Fatalf("err: %v", err)
}
resp = fsm.Apply(makeLog(buf))
if _, ok := resp.(error); ok {
t.Fatalf("bad: %v", resp)
}
_, config, err = fsm.state.CAConfig()
assert.Nil(err)
if config.Provider != "static" {
t.Fatalf("bad: %v", config.Provider)
}
}
func TestFSM_CARoots(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
assert.Nil(err)
// Roots
ca1 := connect.TestCA(t, nil)
ca2 := connect.TestCA(t, nil)
ca2.Active = false
// Create a new request.
req := structs.CARequest{
Op: structs.CAOpSetRoots,
Roots: []*structs.CARoot{ca1, ca2},
}
{
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
assert.True(fsm.Apply(makeLog(buf)).(bool))
}
// Verify it's in the state store.
{
_, roots, err := fsm.state.CARoots(nil)
assert.Nil(err)
assert.Len(roots, 2)
}
}
func TestFSM_CABuiltinProvider(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
assert.Nil(err)
// Provider state.
expected := &structs.CAConsulProviderState{
ID: "foo",
PrivateKey: "a",
RootCert: "b",
RaftIndex: structs.RaftIndex{
CreateIndex: 1,
ModifyIndex: 1,
},
}
// Create a new request.
req := structs.CARequest{
Op: structs.CAOpSetProviderState,
ProviderState: expected,
}
{
buf, err := structs.Encode(structs.ConnectCARequestType, req)
assert.Nil(err)
assert.True(fsm.Apply(makeLog(buf)).(bool))
}
// Verify it's in the state store.
{
_, state, err := fsm.state.CAProviderState("foo")
assert.Nil(err)
assert.Equal(expected, state)
}
}

View File

@ -20,6 +20,8 @@ func init() {
registerRestorer(structs.CoordinateBatchUpdateType, restoreCoordinates)
registerRestorer(structs.PreparedQueryRequestType, restorePreparedQuery)
registerRestorer(structs.AutopilotRequestType, restoreAutopilot)
registerRestorer(structs.IntentionRequestType, restoreIntention)
registerRestorer(structs.ConnectCARequestType, restoreConnectCA)
}
func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error {
@ -44,6 +46,12 @@ func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) err
if err := s.persistAutopilot(sink, encoder); err != nil {
return err
}
if err := s.persistIntentions(sink, encoder); err != nil {
return err
}
if err := s.persistConnectCA(sink, encoder); err != nil {
return err
}
return nil
}
@ -258,6 +266,42 @@ func (s *snapshot) persistAutopilot(sink raft.SnapshotSink,
return nil
}
func (s *snapshot) persistConnectCA(sink raft.SnapshotSink,
encoder *codec.Encoder) error {
roots, err := s.state.CARoots()
if err != nil {
return err
}
for _, r := range roots {
if _, err := sink.Write([]byte{byte(structs.ConnectCARequestType)}); err != nil {
return err
}
if err := encoder.Encode(r); err != nil {
return err
}
}
return nil
}
func (s *snapshot) persistIntentions(sink raft.SnapshotSink,
encoder *codec.Encoder) error {
ixns, err := s.state.Intentions()
if err != nil {
return err
}
for _, ixn := range ixns {
if _, err := sink.Write([]byte{byte(structs.IntentionRequestType)}); err != nil {
return err
}
if err := encoder.Encode(ixn); err != nil {
return err
}
}
return nil
}
func restoreRegistration(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req structs.RegisterRequest
if err := decoder.Decode(&req); err != nil {
@ -364,3 +408,25 @@ func restoreAutopilot(header *snapshotHeader, restore *state.Restore, decoder *c
}
return nil
}
func restoreIntention(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req structs.Intention
if err := decoder.Decode(&req); err != nil {
return err
}
if err := restore.Intention(&req); err != nil {
return err
}
return nil
}
func restoreConnectCA(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
var req structs.CARoot
if err := decoder.Decode(&req); err != nil {
return err
}
if err := restore.CARoot(&req); err != nil {
return err
}
return nil
}

View File

@ -7,16 +7,20 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func TestFSM_SnapshotRestore_OSS(t *testing.T) {
t.Parallel()
assert := assert.New(t)
fsm, err := New(nil, os.Stderr)
if err != nil {
t.Fatalf("err: %v", err)
@ -98,6 +102,27 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
t.Fatalf("err: %s", err)
}
// Intentions
ixn := structs.TestIntention(t)
ixn.ID = generateUUID()
ixn.RaftIndex = structs.RaftIndex{
CreateIndex: 14,
ModifyIndex: 14,
}
assert.Nil(fsm.state.IntentionSet(14, ixn))
// CA Roots
roots := []*structs.CARoot{
connect.TestCA(t, nil),
connect.TestCA(t, nil),
}
for _, r := range roots[1:] {
r.Active = false
}
ok, err := fsm.state.CARootSetCAS(15, 0, roots)
assert.Nil(err)
assert.True(ok)
// Snapshot
snap, err := fsm.Snapshot()
if err != nil {
@ -260,6 +285,17 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
t.Fatalf("bad: %#v, %#v", restoredConf, autopilotConf)
}
// Verify intentions are restored.
_, ixns, err := fsm2.state.Intentions(nil)
assert.Nil(err)
assert.Len(ixns, 1)
assert.Equal(ixn, ixns[0])
// Verify CA roots are restored.
_, roots, err = fsm2.state.CARoots(nil)
assert.Nil(err)
assert.Len(roots, 2)
// Snapshot
snap, err = fsm2.Snapshot()
if err != nil {

View File

@ -111,18 +111,37 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
return fmt.Errorf("Must provide service name")
}
// Determine the function we'll call
var f func(memdb.WatchSet, *state.Store, *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error)
switch {
case args.Connect:
f = h.serviceNodesConnect
case args.TagFilter:
f = h.serviceNodesTagFilter
default:
f = h.serviceNodesDefault
}
// If we're doing a connect query, we need read access to the service
// we're trying to find proxies for, so check that.
if args.Connect {
// Fetch the ACL token, if any.
rule, err := h.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil && !rule.ServiceRead(args.ServiceName) {
// Just return nil, which will return an empty response (tested)
return nil
}
}
err := h.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
var index uint64
var nodes structs.CheckServiceNodes
var err error
if args.TagFilter {
index, nodes, err = state.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
} else {
index, nodes, err = state.CheckServiceNodes(ws, args.ServiceName)
}
index, nodes, err := f(ws, state, args)
if err != nil {
return err
}
@ -139,16 +158,37 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
// Provide some metrics
if err == nil {
metrics.IncrCounterWithLabels([]string{"health", "service", "query"}, 1,
// For metrics, we separate Connect-based lookups from non-Connect
key := "service"
if args.Connect {
key = "connect"
}
metrics.IncrCounterWithLabels([]string{"health", key, "query"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
if args.ServiceTag != "" {
metrics.IncrCounterWithLabels([]string{"health", "service", "query-tag"}, 1,
metrics.IncrCounterWithLabels([]string{"health", key, "query-tag"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}})
}
if len(reply.Nodes) == 0 {
metrics.IncrCounterWithLabels([]string{"health", "service", "not-found"}, 1,
metrics.IncrCounterWithLabels([]string{"health", key, "not-found"}, 1,
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
}
}
return err
}
// The serviceNodes* functions below are the various lookup methods that
// can be used by the ServiceNodes endpoint.
func (h *Health) serviceNodesConnect(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
return s.CheckConnectServiceNodes(ws, args.ServiceName)
}
func (h *Health) serviceNodesTagFilter(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
return s.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
}
func (h *Health) serviceNodesDefault(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
return s.CheckServiceNodes(ws, args.ServiceName)
}

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/stretchr/testify/assert"
)
func TestHealth_ChecksInState(t *testing.T) {
@ -821,6 +822,106 @@ func TestHealth_ServiceNodes_DistanceSort(t *testing.T) {
}
}
func TestHealth_ServiceNodes_ConnectProxy_ACL(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLMasterToken = "root"
c.ACLDefaultPolicy = "deny"
c.ACLEnforceVersion8 = false
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Create the ACL.
arg := structs.ACLRequest{
Datacenter: "dc1",
Op: structs.ACLSet,
ACL: structs.ACL{
Name: "User token",
Type: structs.ACLTypeClient,
Rules: `
service "foo" {
policy = "write"
}
`,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var token string
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", arg, &token))
{
var out struct{}
// Register a service
args := structs.TestRegisterRequestProxy(t)
args.WriteRequest.Token = "root"
args.Service.ID = "foo-proxy-0"
args.Service.Service = "foo-proxy"
args.Service.ProxyDestination = "bar"
args.Check = &structs.HealthCheck{
Name: "proxy",
Status: api.HealthPassing,
ServiceID: args.Service.ID,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a service
args = structs.TestRegisterRequestProxy(t)
args.WriteRequest.Token = "root"
args.Service.Service = "foo-proxy"
args.Service.ProxyDestination = "foo"
args.Check = &structs.HealthCheck{
Name: "proxy",
Status: api.HealthPassing,
ServiceID: args.Service.Service,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
// Register a service
args = structs.TestRegisterRequestProxy(t)
args.WriteRequest.Token = "root"
args.Service.Service = "another-proxy"
args.Service.ProxyDestination = "foo"
args.Check = &structs.HealthCheck{
Name: "proxy",
Status: api.HealthPassing,
ServiceID: args.Service.Service,
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
}
// List w/ token. This should disallow because we don't have permission
// to read "bar"
req := structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: "bar",
QueryOptions: structs.QueryOptions{Token: token},
}
var resp structs.IndexedCheckServiceNodes
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(resp.Nodes, 0)
// List w/ token. This should work since we're requesting "foo", but should
// also only contain the proxies with names that adhere to our ACL.
req = structs.ServiceSpecificRequest{
Connect: true,
Datacenter: "dc1",
ServiceName: "foo",
QueryOptions: structs.QueryOptions{Token: token},
}
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
assert.Len(resp.Nodes, 1)
}
func TestHealth_NodeChecks_FilterACL(t *testing.T) {
t.Parallel()
dir, token, srv, codec := testACLFilterServer(t)

View File

@ -0,0 +1,358 @@
package consul
import (
"errors"
"fmt"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-uuid"
)
var (
// ErrIntentionNotFound is returned if the intention lookup failed.
ErrIntentionNotFound = errors.New("Intention not found")
)
// Intention manages the Connect intentions.
type Intention struct {
// srv is a pointer back to the server.
srv *Server
}
// Apply creates or updates an intention in the data store.
func (s *Intention) Apply(
args *structs.IntentionRequest,
reply *string) error {
if done, err := s.srv.forward("Intention.Apply", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"consul", "intention", "apply"}, time.Now())
defer metrics.MeasureSince([]string{"intention", "apply"}, time.Now())
// Always set a non-nil intention to avoid nil-access below
if args.Intention == nil {
args.Intention = &structs.Intention{}
}
// If no ID is provided, generate a new ID. This must be done prior to
// appending to the Raft log, because the ID is not deterministic. Once
// the entry is in the log, the state update MUST be deterministic or
// the followers will not converge.
if args.Op == structs.IntentionOpCreate {
if args.Intention.ID != "" {
return fmt.Errorf("ID must be empty when creating a new intention")
}
state := s.srv.fsm.State()
for {
var err error
args.Intention.ID, err = uuid.GenerateUUID()
if err != nil {
s.srv.logger.Printf("[ERR] consul.intention: UUID generation failed: %v", err)
return err
}
_, ixn, err := state.IntentionGet(nil, args.Intention.ID)
if err != nil {
s.srv.logger.Printf("[ERR] consul.intention: intention lookup failed: %v", err)
return err
}
if ixn == nil {
break
}
}
// Set the created at
args.Intention.CreatedAt = time.Now().UTC()
}
*reply = args.Intention.ID
// Get the ACL token for the request for the checks below.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
// Perform the ACL check
if prefix, ok := args.Intention.GetACLPrefix(); ok {
if rule != nil && !rule.IntentionWrite(prefix) {
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention '%s' denied due to ACLs", args.Intention.ID)
return acl.ErrPermissionDenied
}
}
// If this is not a create, then we have to verify the ID.
if args.Op != structs.IntentionOpCreate {
state := s.srv.fsm.State()
_, ixn, err := state.IntentionGet(nil, args.Intention.ID)
if err != nil {
return fmt.Errorf("Intention lookup failed: %v", err)
}
if ixn == nil {
return fmt.Errorf("Cannot modify non-existent intention: '%s'", args.Intention.ID)
}
// Perform the ACL check that we have write to the old prefix too,
// which must be true to perform any rename.
if prefix, ok := ixn.GetACLPrefix(); ok {
if rule != nil && !rule.IntentionWrite(prefix) {
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention '%s' denied due to ACLs", args.Intention.ID)
return acl.ErrPermissionDenied
}
}
}
// We always update the updatedat field. This has no effect for deletion.
args.Intention.UpdatedAt = time.Now().UTC()
// Default source type
if args.Intention.SourceType == "" {
args.Intention.SourceType = structs.IntentionSourceConsul
}
// Until we support namespaces, we force all namespaces to be default
if args.Intention.SourceNS == "" {
args.Intention.SourceNS = structs.IntentionDefaultNamespace
}
if args.Intention.DestinationNS == "" {
args.Intention.DestinationNS = structs.IntentionDefaultNamespace
}
// Validate. We do not validate on delete since it is valid to only
// send an ID in that case.
if args.Op != structs.IntentionOpDelete {
// Set the precedence
args.Intention.UpdatePrecedence()
if err := args.Intention.Validate(); err != nil {
return err
}
}
// Commit
resp, err := s.srv.raftApply(structs.IntentionRequestType, args)
if err != nil {
s.srv.logger.Printf("[ERR] consul.intention: Apply failed %v", err)
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
return nil
}
// Get returns a single intention by ID.
func (s *Intention) Get(
args *structs.IntentionQueryRequest,
reply *structs.IndexedIntentions) error {
// Forward if necessary
if done, err := s.srv.forward("Intention.Get", args, args, reply); done {
return err
}
return s.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, ixn, err := state.IntentionGet(ws, args.IntentionID)
if err != nil {
return err
}
if ixn == nil {
return ErrIntentionNotFound
}
reply.Index = index
reply.Intentions = structs.Intentions{ixn}
// Filter
if err := s.srv.filterACL(args.Token, reply); err != nil {
return err
}
// If ACLs prevented any responses, error
if len(reply.Intentions) == 0 {
s.srv.logger.Printf("[WARN] consul.intention: Request to get intention '%s' denied due to ACLs", args.IntentionID)
return acl.ErrPermissionDenied
}
return nil
},
)
}
// List returns all the intentions.
func (s *Intention) List(
args *structs.DCSpecificRequest,
reply *structs.IndexedIntentions) error {
// Forward if necessary
if done, err := s.srv.forward("Intention.List", args, args, reply); done {
return err
}
return s.srv.blockingQuery(
&args.QueryOptions, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, ixns, err := state.Intentions(ws)
if err != nil {
return err
}
reply.Index, reply.Intentions = index, ixns
if reply.Intentions == nil {
reply.Intentions = make(structs.Intentions, 0)
}
return s.srv.filterACL(args.Token, reply)
},
)
}
// Match returns the set of intentions that match the given source/destination.
func (s *Intention) Match(
args *structs.IntentionQueryRequest,
reply *structs.IndexedIntentionMatches) error {
// Forward if necessary
if done, err := s.srv.forward("Intention.Match", args, args, reply); done {
return err
}
// Get the ACL token for the request for the checks below.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
if rule != nil {
// We go through each entry and test the destination to check if it
// matches.
for _, entry := range args.Match.Entries {
if prefix := entry.Name; prefix != "" && !rule.IntentionRead(prefix) {
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention prefix '%s' denied due to ACLs", prefix)
return acl.ErrPermissionDenied
}
}
}
return s.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, matches, err := state.IntentionMatch(ws, args.Match)
if err != nil {
return err
}
reply.Index = index
reply.Matches = matches
return nil
},
)
}
// Check tests a source/destination and returns whether it would be allowed
// or denied based on the current ACL configuration.
//
// Note: Whenever the logic for this method is changed, you should take
// a look at the agent authorize endpoint (agent/agent_endpoint.go) since
// the logic there is similar.
func (s *Intention) Check(
args *structs.IntentionQueryRequest,
reply *structs.IntentionQueryCheckResponse) error {
// Forward maybe
if done, err := s.srv.forward("Intention.Check", args, args, reply); done {
return err
}
// Get the test args, and defensively guard against nil
query := args.Check
if query == nil {
return errors.New("Check must be specified on args")
}
// Build the URI
var uri connect.CertURI
switch query.SourceType {
case structs.IntentionSourceConsul:
uri = &connect.SpiffeIDService{
Namespace: query.SourceNS,
Service: query.SourceName,
}
default:
return fmt.Errorf("unsupported SourceType: %q", query.SourceType)
}
// Get the ACL token for the request for the checks below.
rule, err := s.srv.resolveToken(args.Token)
if err != nil {
return err
}
// Perform the ACL check. For Check we only require ServiceRead and
// NOT IntentionRead because the Check API only returns pass/fail and
// returns no other information about the intentions used.
if prefix, ok := query.GetACLPrefix(); ok {
if rule != nil && !rule.ServiceRead(prefix) {
s.srv.logger.Printf("[WARN] consul.intention: test on intention '%s' denied due to ACLs", prefix)
return acl.ErrPermissionDenied
}
}
// Get the matches for this destination
state := s.srv.fsm.State()
_, matches, err := state.IntentionMatch(nil, &structs.IntentionQueryMatch{
Type: structs.IntentionMatchDestination,
Entries: []structs.IntentionMatchEntry{
structs.IntentionMatchEntry{
Namespace: query.DestinationNS,
Name: query.DestinationName,
},
},
})
if err != nil {
return err
}
if len(matches) != 1 {
// This should never happen since the documented behavior of the
// Match call is that it'll always return exactly the number of results
// as entries passed in. But we guard against misbehavior.
return errors.New("internal error loading matches")
}
// Check the authorization for each match
for _, ixn := range matches[0] {
if auth, ok := uri.Authorize(ixn); ok {
reply.Allowed = auth
return nil
}
}
// No match, we need to determine the default behavior. We do this by
// specifying the anonymous token token, which will get that behavior.
// The default behavior if ACLs are disabled is to allow connections
// to mimic the behavior of Consul itself: everything is allowed if
// ACLs are disabled.
//
// NOTE(mitchellh): This is the same behavior as the agent authorize
// endpoint. If this behavior is incorrect, we should also change it there
// which is much more important.
rule, err = s.srv.resolveToken("")
if err != nil {
return err
}
reply.Allowed = true
if rule != nil {
reply.Allowed = rule.IntentionDefaultAllow()
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -4,16 +4,20 @@ import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/connect"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/types"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-version"
"github.com/hashicorp/raft"
"github.com/hashicorp/serf/serf"
@ -210,6 +214,12 @@ func (s *Server) establishLeadership() error {
s.getOrCreateAutopilotConfig()
s.autopilot.Start()
// todo(kyhavlov): start a goroutine here for handling periodic CA rotation
if err := s.initializeCA(); err != nil {
return err
}
s.setConsistentReadReady()
return nil
}
@ -226,6 +236,8 @@ func (s *Server) revokeLeadership() error {
return err
}
s.setCAProvider(nil, nil)
s.resetConsistentReadReady()
s.autopilot.Stop()
return nil
@ -359,6 +371,185 @@ func (s *Server) getOrCreateAutopilotConfig() *autopilot.Config {
return config
}
// initializeCAConfig is used to initialize the CA config if necessary
// when setting up the CA during establishLeadership
func (s *Server) initializeCAConfig() (*structs.CAConfiguration, error) {
state := s.fsm.State()
_, config, err := state.CAConfig()
if err != nil {
return nil, err
}
if config != nil {
return config, nil
}
config = s.config.CAConfig
if config.ClusterID == "" {
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
config.ClusterID = id
}
req := structs.CARequest{
Op: structs.CAOpSetConfig,
Config: config,
}
if _, err = s.raftApply(structs.ConnectCARequestType, req); err != nil {
return nil, err
}
return config, nil
}
// initializeCA sets up the CA provider when gaining leadership, bootstrapping
// the root in the state store if necessary.
func (s *Server) initializeCA() error {
// Bail if connect isn't enabled.
if !s.config.ConnectEnabled {
return nil
}
conf, err := s.initializeCAConfig()
if err != nil {
return err
}
// Initialize the right provider based on the config
provider, err := s.createCAProvider(conf)
if err != nil {
return err
}
// Get the active root cert from the CA
rootPEM, err := provider.ActiveRoot()
if err != nil {
return fmt.Errorf("error getting root cert: %v", err)
}
rootCA, err := parseCARoot(rootPEM, conf.Provider)
if err != nil {
return err
}
// TODO(banks): in the case that we've just gained leadership in an already
// configured cluster. We really need to fetch RootCA from state to provide it
// in setCAProvider. This matters because if the current active root has
// intermediates, parsing the rootCA from only the root cert PEM above will
// not include them and so leafs we sign will not bundle the intermediates.
s.setCAProvider(provider, rootCA)
// Check if the CA root is already initialized and exit if it is.
// Every change to the CA after this initial bootstrapping should
// be done through the rotation process.
state := s.fsm.State()
_, activeRoot, err := state.CARootActive(nil)
if err != nil {
return err
}
if activeRoot != nil {
if activeRoot.ID != rootCA.ID {
// TODO(banks): this seems like a pretty catastrophic state to get into.
// Shouldn't we do something stronger than warn and continue signing with
// a key that's not the active CA according to the state?
s.logger.Printf("[WARN] connect: CA root %q is not the active root (%q)", rootCA.ID, activeRoot.ID)
}
return nil
}
// Get the highest index
idx, _, err := state.CARoots(nil)
if err != nil {
return err
}
// Store the root cert in raft
resp, err := s.raftApply(structs.ConnectCARequestType, &structs.CARequest{
Op: structs.CAOpSetRoots,
Index: idx,
Roots: []*structs.CARoot{rootCA},
})
if err != nil {
s.logger.Printf("[ERR] connect: Apply failed %v", err)
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
s.logger.Printf("[INFO] connect: initialized CA with provider %q", conf.Provider)
return nil
}
// parseCARoot returns a filled-in structs.CARoot from a raw PEM value.
func parseCARoot(pemValue, provider string) (*structs.CARoot, error) {
id, err := connect.CalculateCertFingerprint(pemValue)
if err != nil {
return nil, fmt.Errorf("error parsing root fingerprint: %v", err)
}
rootCert, err := connect.ParseCert(pemValue)
if err != nil {
return nil, fmt.Errorf("error parsing root cert: %v", err)
}
return &structs.CARoot{
ID: id,
Name: fmt.Sprintf("%s CA Root Cert", strings.Title(provider)),
SerialNumber: rootCert.SerialNumber.Uint64(),
SigningKeyID: connect.HexString(rootCert.AuthorityKeyId),
NotBefore: rootCert.NotBefore,
NotAfter: rootCert.NotAfter,
RootCert: pemValue,
Active: true,
}, nil
}
// createProvider returns a connect CA provider from the given config.
func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, error) {
switch conf.Provider {
case structs.ConsulCAProvider:
return ca.NewConsulProvider(conf.Config, &consulCADelegate{s})
case structs.VaultCAProvider:
return ca.NewVaultProvider(conf.Config, conf.ClusterID)
default:
return nil, fmt.Errorf("unknown CA provider %q", conf.Provider)
}
}
func (s *Server) getCAProvider() (ca.Provider, *structs.CARoot) {
retries := 0
var result ca.Provider
var resultRoot *structs.CARoot
for result == nil {
s.caProviderLock.RLock()
result = s.caProvider
resultRoot = s.caProviderRoot
s.caProviderLock.RUnlock()
// In cases where an agent is started with managed proxies, we may ask
// for the provider before establishLeadership completes. If we're the
// leader, then wait and get the provider again
if result == nil && s.IsLeader() && retries < 10 {
retries++
time.Sleep(50 * time.Millisecond)
continue
}
break
}
return result, resultRoot
}
func (s *Server) setCAProvider(newProvider ca.Provider, root *structs.CARoot) {
s.caProviderLock.Lock()
defer s.caProviderLock.Unlock()
s.caProvider = newProvider
s.caProviderRoot = root
}
// reconcileReaped is used to reconcile nodes that have failed and been reaped
// from Serf but remain in the catalog. This is done by looking for unknown nodes with serfHealth checks registered.
// We generate a "reap" event to cause the node to be cleaned up.

View File

@ -354,7 +354,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
}
// Execute the query for the local DC.
if err := p.execute(query, reply); err != nil {
if err := p.execute(query, reply, args.Connect); err != nil {
return err
}
@ -450,7 +450,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
// by the query setup.
if len(reply.Nodes) == 0 {
wrapper := &queryServerWrapper{p.srv}
if err := queryFailover(wrapper, query, args.Limit, args.QueryOptions, reply); err != nil {
if err := queryFailover(wrapper, query, args, reply); err != nil {
return err
}
}
@ -479,7 +479,7 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe
}
// Run the query locally to see what we can find.
if err := p.execute(&args.Query, reply); err != nil {
if err := p.execute(&args.Query, reply, args.Connect); err != nil {
return err
}
@ -509,9 +509,18 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe
// execute runs a prepared query in the local DC without any failover. We don't
// apply any sorting options or ACL checks at this level - it should be done up above.
func (p *PreparedQuery) execute(query *structs.PreparedQuery,
reply *structs.PreparedQueryExecuteResponse) error {
reply *structs.PreparedQueryExecuteResponse,
forceConnect bool) error {
state := p.srv.fsm.State()
_, nodes, err := state.CheckServiceNodes(nil, query.Service.Service)
// If we're requesting Connect-capable services, then switch the
// lookup to be the Connect function.
f := state.CheckServiceNodes
if query.Service.Connect || forceConnect {
f = state.CheckConnectServiceNodes
}
_, nodes, err := f(nil, query.Service.Service)
if err != nil {
return err
}
@ -651,7 +660,7 @@ func (q *queryServerWrapper) ForwardDC(method, dc string, args interface{}, repl
// queryFailover runs an algorithm to determine which DCs to try and then calls
// them to try to locate alternative services.
func queryFailover(q queryServer, query *structs.PreparedQuery,
limit int, options structs.QueryOptions,
args *structs.PreparedQueryExecuteRequest,
reply *structs.PreparedQueryExecuteResponse) error {
// Pull the list of other DCs. This is sorted by RTT in case the user
@ -719,8 +728,9 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
remote := &structs.PreparedQueryExecuteRemoteRequest{
Datacenter: dc,
Query: *query,
Limit: limit,
QueryOptions: options,
Limit: args.Limit,
QueryOptions: args.QueryOptions,
Connect: args.Connect,
}
if err := q.ForwardDC("PreparedQuery.ExecuteRemote", dc, remote, reply); err != nil {
q.GetLogger().Printf("[WARN] consul.prepared_query: Failed querying for service '%s' in datacenter '%s': %s", query.Service.Service, dc, err)

View File

@ -20,6 +20,7 @@ import (
"github.com/hashicorp/consul/types"
"github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/coordinate"
"github.com/stretchr/testify/require"
)
func TestPreparedQuery_Apply(t *testing.T) {
@ -2617,6 +2618,159 @@ func TestPreparedQuery_Execute_ForwardLeader(t *testing.T) {
}
}
func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
t.Parallel()
require := require.New(t)
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
// Setup 3 services on 3 nodes: one is non-Connect, one is Connect native,
// and one is a proxy to the non-Connect one.
for i := 0; i < 3; i++ {
req := structs.RegisterRequest{
Datacenter: "dc1",
Node: fmt.Sprintf("node%d", i+1),
Address: fmt.Sprintf("127.0.0.%d", i+1),
Service: &structs.NodeService{
Service: "foo",
Port: 8000,
},
}
switch i {
case 0:
// Default do nothing
case 1:
// Connect native
req.Service.Connect.Native = true
case 2:
// Connect proxy
req.Service.Kind = structs.ServiceKindConnectProxy
req.Service.ProxyDestination = req.Service.Service
req.Service.Service = "proxy"
}
var reply struct{}
require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply))
}
// The query, start with connect disabled
query := structs.PreparedQueryRequest{
Datacenter: "dc1",
Op: structs.PreparedQueryCreate,
Query: &structs.PreparedQuery{
Name: "test",
Service: structs.ServiceQuery{
Service: "foo",
},
DNS: structs.QueryDNSOptions{
TTL: "10s",
},
},
}
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
// In the future we'll run updates
query.Op = structs.PreparedQueryUpdate
// Run the registered query.
{
req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1",
QueryIDOrName: query.Query.ID,
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because it omits the proxy whose name
// doesn't match the query.
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
}
// Run with the Connect setting specified on the request
{
req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1",
QueryIDOrName: query.Query.ID,
Connect: true,
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because we should get the native AND
// the proxy (since the destination matches our service name).
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
// Make sure the native is the first one
if !reply.Nodes[0].Service.Connect.Native {
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
}
require.True(reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination)
}
// Update the query
query.Query.Service.Connect = true
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
// Run the registered query.
{
req := structs.PreparedQueryExecuteRequest{
Datacenter: "dc1",
QueryIDOrName: query.Query.ID,
}
var reply structs.PreparedQueryExecuteResponse
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Execute", &req, &reply))
// Result should have two because we should get the native AND
// the proxy (since the destination matches our service name).
require.Len(reply.Nodes, 2)
require.Equal(query.Query.Service.Service, reply.Service)
require.Equal(query.Query.DNS, reply.DNS)
require.True(reply.QueryMeta.KnownLeader, "queried leader")
// Make sure the native is the first one
if !reply.Nodes[0].Service.Connect.Native {
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
}
require.True(reply.Nodes[0].Service.Connect.Native, "native")
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination)
}
// Unset the query
query.Query.Service.Connect = false
require.NoError(msgpackrpc.CallWithCodec(
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
}
func TestPreparedQuery_tagFilter(t *testing.T) {
t.Parallel()
testNodes := func() structs.CheckServiceNodes {
@ -2820,7 +2974,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 {
@ -2836,7 +2990,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply)
err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply)
if err == nil || !strings.Contains(err.Error(), "XXX") {
t.Fatalf("bad: %v", err)
}
@ -2853,7 +3007,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 {
@ -2876,7 +3030,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -2904,7 +3058,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -2925,7 +3079,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 0 ||
@ -2954,7 +3108,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -2983,7 +3137,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -3012,7 +3166,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -3047,7 +3201,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -3079,7 +3233,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||
@ -3115,7 +3269,10 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
}
var reply structs.PreparedQueryExecuteResponse
if err := queryFailover(mock, query, 5, structs.QueryOptions{RequireConsistent: true}, &reply); err != nil {
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{
Limit: 5,
QueryOptions: structs.QueryOptions{RequireConsistent: true},
}, &reply); err != nil {
t.Fatalf("err: %v", err)
}
if len(reply.Nodes) != 3 ||

View File

@ -18,6 +18,7 @@ import (
"time"
"github.com/hashicorp/consul/acl"
ca "github.com/hashicorp/consul/agent/connect/ca"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/consul/agent/consul/fsm"
"github.com/hashicorp/consul/agent/consul/state"
@ -96,6 +97,16 @@ type Server struct {
// autopilotWaitGroup is used to block until Autopilot shuts down.
autopilotWaitGroup sync.WaitGroup
// caProvider is the current CA provider in use for Connect. This is
// only non-nil when we are the leader.
caProvider ca.Provider
// caProviderRoot is the CARoot that was stored along with the ca.Provider
// active. It's only updated in lock-step with the caProvider. This prevents
// races between state updates to active roots and the fetch of the provider
// instance.
caProviderRoot *structs.CARoot
caProviderLock sync.RWMutex
// Consul configuration
config *Config

View File

@ -4,7 +4,9 @@ func init() {
registerEndpoint(func(s *Server) interface{} { return &ACL{s} })
registerEndpoint(func(s *Server) interface{} { return &Catalog{s} })
registerEndpoint(func(s *Server) interface{} { return NewCoordinate(s) })
registerEndpoint(func(s *Server) interface{} { return &ConnectCA{s} })
registerEndpoint(func(s *Server) interface{} { return &Health{s} })
registerEndpoint(func(s *Server) interface{} { return &Intention{s} })
registerEndpoint(func(s *Server) interface{} { return &Internal{s} })
registerEndpoint(func(s *Server) interface{} { return &KVS{s} })
registerEndpoint(func(s *Server) interface{} { return &Operator{s} })

View File

@ -10,7 +10,9 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/lib/freeport"
"github.com/hashicorp/consul/testrpc"
@ -91,6 +93,17 @@ func testServerConfig(t *testing.T) (string, *Config) {
// looks like several depend on it.
config.RPCHoldTimeout = 5 * time.Second
config.ConnectEnabled = true
config.CAConfig = &structs.CAConfiguration{
ClusterID: connect.TestClusterID,
Provider: structs.ConsulCAProvider,
Config: map[string]interface{}{
"PrivateKey": "",
"RootCert": "",
"RotationPeriod": 90 * 24 * time.Hour,
},
}
return dir, config
}

View File

@ -10,6 +10,10 @@ import (
"github.com/hashicorp/go-memdb"
)
const (
servicesTableName = "services"
)
// nodesTableSchema returns a new table schema used for storing node
// information.
func nodesTableSchema() *memdb.TableSchema {
@ -87,6 +91,12 @@ func servicesTableSchema() *memdb.TableSchema {
Lowercase: true,
},
},
"connect": &memdb.IndexSchema{
Name: "connect",
AllowMissing: true,
Unique: false,
Indexer: &IndexConnectService{},
},
},
}
}
@ -779,15 +789,39 @@ func maxIndexForService(tx *memdb.Txn, serviceName string, checks bool) uint64 {
return maxIndexTxn(tx, "nodes", "services")
}
// ConnectServiceNodes returns the nodes associated with a Connect
// compatible destination for the given service name. This will include
// both proxies and native integrations.
func (s *Store) ConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) {
return s.serviceNodes(ws, serviceName, true)
}
// ServiceNodes returns the nodes associated with a given service name.
func (s *Store) ServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) {
return s.serviceNodes(ws, serviceName, false)
}
func (s *Store) serviceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.ServiceNodes, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexForService(tx, serviceName, false)
// Function for lookup
var f func() (memdb.ResultIterator, error)
if !connect {
f = func() (memdb.ResultIterator, error) {
return tx.Get("services", "service", serviceName)
}
} else {
f = func() (memdb.ResultIterator, error) {
return tx.Get("services", "connect", serviceName)
}
}
// List all the services.
services, err := tx.Get("services", "service", serviceName)
services, err := f()
if err != nil {
return 0, nil, fmt.Errorf("failed service lookup: %s", err)
}
@ -1479,14 +1513,36 @@ func (s *Store) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID t
// CheckServiceNodes is used to query all nodes and checks for a given service.
func (s *Store) CheckServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) {
return s.checkServiceNodes(ws, serviceName, false)
}
// CheckConnectServiceNodes is used to query all nodes and checks for Connect
// compatible endpoints for a given service.
func (s *Store) CheckConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) {
return s.checkServiceNodes(ws, serviceName, true)
}
func (s *Store) checkServiceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.CheckServiceNodes, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexForService(tx, serviceName, true)
// Function for lookup
var f func() (memdb.ResultIterator, error)
if !connect {
f = func() (memdb.ResultIterator, error) {
return tx.Get("services", "service", serviceName)
}
} else {
f = func() (memdb.ResultIterator, error) {
return tx.Get("services", "connect", serviceName)
}
}
// Query the state store for the service.
iter, err := tx.Get("services", "service", serviceName)
iter, err := f()
if err != nil {
return 0, nil, fmt.Errorf("failed service lookup: %s", err)
}

View File

@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-memdb"
uuid "github.com/hashicorp/go-uuid"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func makeRandomNodeID(t *testing.T) types.NodeID {
@ -981,6 +982,35 @@ func TestStateStore_EnsureService(t *testing.T) {
}
}
func TestStateStore_EnsureService_connectProxy(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create the service registration.
ns1 := &structs.NodeService{
Kind: structs.ServiceKindConnectProxy,
ID: "connect-proxy",
Service: "connect-proxy",
Address: "1.1.1.1",
Port: 1111,
ProxyDestination: "foo",
}
// Service successfully registers into the state store.
testRegisterNode(t, s, 0, "node1")
assert.Nil(s.EnsureService(10, "node1", ns1))
// Retrieve and verify
_, out, err := s.NodeServices(nil, "node1")
assert.Nil(err)
assert.NotNil(out)
assert.Len(out.Services, 1)
expect1 := *ns1
expect1.CreateIndex, expect1.ModifyIndex = 10, 10
assert.Equal(&expect1, out.Services["connect-proxy"])
}
func TestStateStore_Services(t *testing.T) {
s := testStateStore(t)
@ -1542,6 +1572,51 @@ func TestStateStore_DeleteService(t *testing.T) {
}
}
func TestStateStore_ConnectServiceNodes(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.ConnectServiceNodes(ws, "db")
assert.Nil(err)
assert.Equal(idx, uint64(0))
assert.Len(nodes, 0)
// Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "native-db", Service: "db", Connect: structs.ServiceConnect{Native: true}}))
assert.Nil(s.EnsureService(17, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}))
assert.True(watchFired(ws))
// Read everything back.
ws = memdb.NewWatchSet()
idx, nodes, err = s.ConnectServiceNodes(ws, "db")
assert.Nil(err)
assert.Equal(idx, uint64(idx))
assert.Len(nodes, 3)
for _, n := range nodes {
assert.True(
n.ServiceKind == structs.ServiceKindConnectProxy ||
n.ServiceConnect.Native,
"either proxy or connect native")
}
// Registering some unrelated node should not fire the watch.
testRegisterNode(t, s, 17, "nope")
assert.False(watchFired(ws))
// But removing a node with the "db" service should fire the watch.
assert.Nil(s.DeleteNode(18, "bar"))
assert.True(watchFired(ws))
}
func TestStateStore_Service_Snapshot(t *testing.T) {
s := testStateStore(t)
@ -2457,6 +2532,48 @@ func TestStateStore_CheckServiceNodes(t *testing.T) {
}
}
func TestStateStore_CheckConnectServiceNodes(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Listing with no results returns an empty list.
ws := memdb.NewWatchSet()
idx, nodes, err := s.CheckConnectServiceNodes(ws, "db")
assert.Nil(err)
assert.Equal(idx, uint64(0))
assert.Len(nodes, 0)
// Create some nodes and services.
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}))
assert.True(watchFired(ws))
// Register node checks
testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing)
testRegisterCheck(t, s, 18, "bar", "", "check2", api.HealthPassing)
// Register checks against the services.
testRegisterCheck(t, s, 19, "foo", "db", "check3", api.HealthPassing)
testRegisterCheck(t, s, 20, "bar", "proxy", "check4", api.HealthPassing)
// Read everything back.
ws = memdb.NewWatchSet()
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db")
assert.Nil(err)
assert.Equal(idx, uint64(idx))
assert.Len(nodes, 2)
for _, n := range nodes {
assert.Equal(structs.ServiceKindConnectProxy, n.Service.Kind)
assert.Equal("db", n.Service.ProxyDestination)
}
}
func BenchmarkCheckServiceNodes(b *testing.B) {
s, err := NewStateStore(nil)
if err != nil {

View File

@ -0,0 +1,435 @@
package state
import (
"fmt"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
)
const (
caBuiltinProviderTableName = "connect-ca-builtin"
caConfigTableName = "connect-ca-config"
caRootTableName = "connect-ca-roots"
)
// caBuiltinProviderTableSchema returns a new table schema used for storing
// the built-in CA provider's state for connect. This is only used by
// the internal Consul CA provider.
func caBuiltinProviderTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: caBuiltinProviderTableName,
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
},
}
}
// caConfigTableSchema returns a new table schema used for storing
// the CA config for Connect.
func caConfigTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: caConfigTableName,
Indexes: map[string]*memdb.IndexSchema{
// This table only stores one row, so this just ignores the ID field
// and always overwrites the same config object.
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: true,
Unique: true,
Indexer: &memdb.ConditionalIndex{
Conditional: func(obj interface{}) (bool, error) { return true, nil },
},
},
},
}
}
// caRootTableSchema returns a new table schema used for storing
// CA roots for Connect.
func caRootTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: caRootTableName,
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
},
}
}
func init() {
registerSchema(caBuiltinProviderTableSchema)
registerSchema(caConfigTableSchema)
registerSchema(caRootTableSchema)
}
// CAConfig is used to pull the CA config from the snapshot.
func (s *Snapshot) CAConfig() (*structs.CAConfiguration, error) {
c, err := s.tx.First(caConfigTableName, "id")
if err != nil {
return nil, err
}
config, ok := c.(*structs.CAConfiguration)
if !ok {
return nil, nil
}
return config, nil
}
// CAConfig is used when restoring from a snapshot.
func (s *Restore) CAConfig(config *structs.CAConfiguration) error {
if err := s.tx.Insert(caConfigTableName, config); err != nil {
return fmt.Errorf("failed restoring CA config: %s", err)
}
return nil
}
// CAConfig is used to get the current CA configuration.
func (s *Store) CAConfig() (uint64, *structs.CAConfiguration, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the CA config
c, err := tx.First(caConfigTableName, "id")
if err != nil {
return 0, nil, fmt.Errorf("failed CA config lookup: %s", err)
}
config, ok := c.(*structs.CAConfiguration)
if !ok {
return 0, nil, nil
}
return config.ModifyIndex, config, nil
}
// CASetConfig is used to set the current CA configuration.
func (s *Store) CASetConfig(idx uint64, config *structs.CAConfiguration) error {
tx := s.db.Txn(true)
defer tx.Abort()
if err := s.caSetConfigTxn(idx, tx, config); err != nil {
return err
}
tx.Commit()
return nil
}
// CACheckAndSetConfig is used to try updating the CA configuration with a
// given Raft index. If the CAS index specified is not equal to the last observed index
// for the config, then the call is a noop,
func (s *Store) CACheckAndSetConfig(idx, cidx uint64, config *structs.CAConfiguration) (bool, error) {
tx := s.db.Txn(true)
defer tx.Abort()
// Check for an existing config
existing, err := tx.First(caConfigTableName, "id")
if err != nil {
return false, fmt.Errorf("failed CA config lookup: %s", err)
}
// If the existing index does not match the provided CAS
// index arg, then we shouldn't update anything and can safely
// return early here.
e, ok := existing.(*structs.CAConfiguration)
if !ok || e.ModifyIndex != cidx {
return false, nil
}
if err := s.caSetConfigTxn(idx, tx, config); err != nil {
return false, err
}
tx.Commit()
return true, nil
}
func (s *Store) caSetConfigTxn(idx uint64, tx *memdb.Txn, config *structs.CAConfiguration) error {
// Check for an existing config
prev, err := tx.First(caConfigTableName, "id")
if err != nil {
return fmt.Errorf("failed CA config lookup: %s", err)
}
// Set the indexes, prevent the cluster ID from changing.
if prev != nil {
existing := prev.(*structs.CAConfiguration)
config.CreateIndex = existing.CreateIndex
config.ClusterID = existing.ClusterID
} else {
config.CreateIndex = idx
}
config.ModifyIndex = idx
if err := tx.Insert(caConfigTableName, config); err != nil {
return fmt.Errorf("failed updating CA config: %s", err)
}
return nil
}
// CARoots is used to pull all the CA roots for the snapshot.
func (s *Snapshot) CARoots() (structs.CARoots, error) {
ixns, err := s.tx.Get(caRootTableName, "id")
if err != nil {
return nil, err
}
var ret structs.CARoots
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
ret = append(ret, wrapped.(*structs.CARoot))
}
return ret, nil
}
// CARoots is used when restoring from a snapshot.
func (s *Restore) CARoot(r *structs.CARoot) error {
// Insert
if err := s.tx.Insert(caRootTableName, r); err != nil {
return fmt.Errorf("failed restoring CA root: %s", err)
}
if err := indexUpdateMaxTxn(s.tx, r.ModifyIndex, caRootTableName); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// CARoots returns the list of all CA roots.
func (s *Store) CARoots(ws memdb.WatchSet) (uint64, structs.CARoots, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the index
idx := maxIndexTxn(tx, caRootTableName)
// Get all
iter, err := tx.Get(caRootTableName, "id")
if err != nil {
return 0, nil, fmt.Errorf("failed CA root lookup: %s", err)
}
ws.Add(iter.WatchCh())
var results structs.CARoots
for v := iter.Next(); v != nil; v = iter.Next() {
results = append(results, v.(*structs.CARoot))
}
return idx, results, nil
}
// CARootActive returns the currently active CARoot.
func (s *Store) CARootActive(ws memdb.WatchSet) (uint64, *structs.CARoot, error) {
// Get all the roots since there should never be that many and just
// do the filtering in this method.
var result *structs.CARoot
idx, roots, err := s.CARoots(ws)
if err == nil {
for _, r := range roots {
if r.Active {
result = r
break
}
}
}
return idx, result, err
}
// CARootSetCAS sets the current CA root state using a check-and-set operation.
// On success, this will replace the previous set of CARoots completely with
// the given set of roots.
//
// The first boolean result returns whether the transaction succeeded or not.
func (s *Store) CARootSetCAS(idx, cidx uint64, rs []*structs.CARoot) (bool, error) {
tx := s.db.Txn(true)
defer tx.Abort()
// There must be exactly one active CA root.
activeCount := 0
for _, r := range rs {
if r.Active {
activeCount++
}
}
if activeCount != 1 {
return false, fmt.Errorf("there must be exactly one active CA")
}
// Get the current max index
if midx := maxIndexTxn(tx, caRootTableName); midx != cidx {
return false, nil
}
// Go through and find any existing matching CAs so we can preserve and
// update their Create/ModifyIndex values.
for _, r := range rs {
if r.ID == "" {
return false, ErrMissingCARootID
}
existing, err := tx.First(caRootTableName, "id", r.ID)
if err != nil {
return false, fmt.Errorf("failed CA root lookup: %s", err)
}
if existing != nil {
r.CreateIndex = existing.(*structs.CARoot).CreateIndex
} else {
r.CreateIndex = idx
}
r.ModifyIndex = idx
}
// Delete all
_, err := tx.DeleteAll(caRootTableName, "id")
if err != nil {
return false, err
}
// Insert all
for _, r := range rs {
if err := tx.Insert(caRootTableName, r); err != nil {
return false, err
}
}
// Update the index
if err := tx.Insert("index", &IndexEntry{caRootTableName, idx}); err != nil {
return false, fmt.Errorf("failed updating index: %s", err)
}
tx.Commit()
return true, nil
}
// CAProviderState is used to pull the built-in provider states from the snapshot.
func (s *Snapshot) CAProviderState() ([]*structs.CAConsulProviderState, error) {
ixns, err := s.tx.Get(caBuiltinProviderTableName, "id")
if err != nil {
return nil, err
}
var ret []*structs.CAConsulProviderState
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
ret = append(ret, wrapped.(*structs.CAConsulProviderState))
}
return ret, nil
}
// CAProviderState is used when restoring from a snapshot.
func (s *Restore) CAProviderState(state *structs.CAConsulProviderState) error {
if err := s.tx.Insert(caBuiltinProviderTableName, state); err != nil {
return fmt.Errorf("failed restoring built-in CA state: %s", err)
}
if err := indexUpdateMaxTxn(s.tx, state.ModifyIndex, caBuiltinProviderTableName); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// CAProviderState is used to get the Consul CA provider state for the given ID.
func (s *Store) CAProviderState(id string) (uint64, *structs.CAConsulProviderState, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the index
idx := maxIndexTxn(tx, caBuiltinProviderTableName)
// Get the provider config
c, err := tx.First(caBuiltinProviderTableName, "id", id)
if err != nil {
return 0, nil, fmt.Errorf("failed built-in CA state lookup: %s", err)
}
state, ok := c.(*structs.CAConsulProviderState)
if !ok {
return 0, nil, nil
}
return idx, state, nil
}
// CASetProviderState is used to set the current built-in CA provider state.
func (s *Store) CASetProviderState(idx uint64, state *structs.CAConsulProviderState) (bool, error) {
tx := s.db.Txn(true)
defer tx.Abort()
// Check for an existing config
existing, err := tx.First(caBuiltinProviderTableName, "id", state.ID)
if err != nil {
return false, fmt.Errorf("failed built-in CA state lookup: %s", err)
}
// Set the indexes.
if existing != nil {
state.CreateIndex = existing.(*structs.CAConsulProviderState).CreateIndex
} else {
state.CreateIndex = idx
}
state.ModifyIndex = idx
if err := tx.Insert(caBuiltinProviderTableName, state); err != nil {
return false, fmt.Errorf("failed updating built-in CA state: %s", err)
}
// Update the index
if err := tx.Insert("index", &IndexEntry{caBuiltinProviderTableName, idx}); err != nil {
return false, fmt.Errorf("failed updating index: %s", err)
}
tx.Commit()
return true, nil
}
// CADeleteProviderState is used to remove the built-in Consul CA provider
// state for the given ID.
func (s *Store) CADeleteProviderState(id string) error {
tx := s.db.Txn(true)
defer tx.Abort()
// Get the index
idx := maxIndexTxn(tx, caBuiltinProviderTableName)
// Check for an existing config
existing, err := tx.First(caBuiltinProviderTableName, "id", id)
if err != nil {
return fmt.Errorf("failed built-in CA state lookup: %s", err)
}
if existing == nil {
return nil
}
providerState := existing.(*structs.CAConsulProviderState)
// Do the delete and update the index
if err := tx.Delete(caBuiltinProviderTableName, providerState); err != nil {
return err
}
if err := tx.Insert("index", &IndexEntry{caBuiltinProviderTableName, idx}); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
tx.Commit()
return nil
}

View File

@ -0,0 +1,449 @@
package state
import (
"reflect"
"testing"
"time"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func TestStore_CAConfig(t *testing.T) {
s := testStateStore(t)
expected := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": "asdf",
"RootCert": "qwer",
"RotationPeriod": 90 * 24 * time.Hour,
},
}
if err := s.CASetConfig(0, expected); err != nil {
t.Fatal(err)
}
idx, config, err := s.CAConfig()
if err != nil {
t.Fatal(err)
}
if idx != 0 {
t.Fatalf("bad: %d", idx)
}
if !reflect.DeepEqual(expected, config) {
t.Fatalf("bad: %#v, %#v", expected, config)
}
}
func TestStore_CAConfigCAS(t *testing.T) {
s := testStateStore(t)
expected := &structs.CAConfiguration{
Provider: "consul",
}
if err := s.CASetConfig(0, expected); err != nil {
t.Fatal(err)
}
// Do an extra operation to move the index up by 1 for the
// check-and-set operation after this
if err := s.CASetConfig(1, expected); err != nil {
t.Fatal(err)
}
// Do a CAS with an index lower than the entry
ok, err := s.CACheckAndSetConfig(2, 0, &structs.CAConfiguration{
Provider: "static",
})
if ok || err != nil {
t.Fatalf("expected (false, nil), got: (%v, %#v)", ok, err)
}
// Check that the index is untouched and the entry
// has not been updated.
idx, config, err := s.CAConfig()
if err != nil {
t.Fatal(err)
}
if idx != 1 {
t.Fatalf("bad: %d", idx)
}
if config.Provider != "consul" {
t.Fatalf("bad: %#v", config)
}
// Do another CAS, this time with the correct index
ok, err = s.CACheckAndSetConfig(2, 1, &structs.CAConfiguration{
Provider: "static",
})
if !ok || err != nil {
t.Fatalf("expected (true, nil), got: (%v, %#v)", ok, err)
}
// Make sure the config was updated
idx, config, err = s.CAConfig()
if err != nil {
t.Fatal(err)
}
if idx != 2 {
t.Fatalf("bad: %d", idx)
}
if config.Provider != "static" {
t.Fatalf("bad: %#v", config)
}
}
func TestStore_CAConfig_Snapshot_Restore(t *testing.T) {
s := testStateStore(t)
before := &structs.CAConfiguration{
Provider: "consul",
Config: map[string]interface{}{
"PrivateKey": "asdf",
"RootCert": "qwer",
"RotationPeriod": 90 * 24 * time.Hour,
},
}
if err := s.CASetConfig(99, before); err != nil {
t.Fatal(err)
}
snap := s.Snapshot()
defer snap.Close()
after := &structs.CAConfiguration{
Provider: "static",
Config: map[string]interface{}{},
}
if err := s.CASetConfig(100, after); err != nil {
t.Fatal(err)
}
snapped, err := snap.CAConfig()
if err != nil {
t.Fatalf("err: %s", err)
}
verify.Values(t, "", before, snapped)
s2 := testStateStore(t)
restore := s2.Restore()
if err := restore.CAConfig(snapped); err != nil {
t.Fatalf("err: %s", err)
}
restore.Commit()
idx, res, err := s2.CAConfig()
if err != nil {
t.Fatalf("err: %s", err)
}
if idx != 99 {
t.Fatalf("bad index: %d", idx)
}
verify.Values(t, "", before, res)
}
func TestStore_CARootSetList(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
assert.Nil(err)
assert.True(ok)
// Make sure the index got updated.
assert.Equal(s.maxIndex(caRootTableName), uint64(1))
assert.True(watchFired(ws), "watch fired")
// Read it back out and verify it.
expected := *ca1
expected.RaftIndex = structs.RaftIndex{
CreateIndex: 1,
ModifyIndex: 1,
}
ws = memdb.NewWatchSet()
_, roots, err := s.CARoots(ws)
assert.Nil(err)
assert.Len(roots, 1)
actual := roots[0]
assert.Equal(&expected, actual)
}
func TestStore_CARootSet_emptyID(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
ca1.ID = ""
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
assert.NotNil(err)
assert.Contains(err.Error(), ErrMissingCARootID.Error())
assert.False(ok)
// Make sure the index got updated.
assert.Equal(s.maxIndex(caRootTableName), uint64(0))
assert.False(watchFired(ws), "watch fired")
// Read it back out and verify it.
ws = memdb.NewWatchSet()
_, roots, err := s.CARoots(ws)
assert.Nil(err)
assert.Len(roots, 0)
}
func TestStore_CARootSet_noActive(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
ca1.Active = false
ca2 := connect.TestCA(t, nil)
ca2.Active = false
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
assert.NotNil(err)
assert.Contains(err.Error(), "exactly one active")
assert.False(ok)
}
func TestStore_CARootSet_multipleActive(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call list to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.CARoots(ws)
assert.Nil(err)
// Build a valid value
ca1 := connect.TestCA(t, nil)
ca2 := connect.TestCA(t, nil)
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
assert.NotNil(err)
assert.Contains(err.Error(), "exactly one active")
assert.False(ok)
}
func TestStore_CARootActive_valid(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid value
ca1 := connect.TestCA(t, nil)
ca1.Active = false
ca2 := connect.TestCA(t, nil)
ca3 := connect.TestCA(t, nil)
ca3.Active = false
// Set
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2, ca3})
assert.Nil(err)
assert.True(ok)
// Query
ws := memdb.NewWatchSet()
idx, res, err := s.CARootActive(ws)
assert.Equal(idx, uint64(1))
assert.Nil(err)
assert.NotNil(res)
assert.Equal(ca2.ID, res.ID)
}
// Test that querying the active CA returns the correct value.
func TestStore_CARootActive_none(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Querying with no results returns nil.
ws := memdb.NewWatchSet()
idx, res, err := s.CARootActive(ws)
assert.Equal(idx, uint64(0))
assert.Nil(res)
assert.Nil(err)
}
func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create some intentions.
roots := structs.CARoots{
connect.TestCA(t, nil),
connect.TestCA(t, nil),
connect.TestCA(t, nil),
}
for _, r := range roots[1:] {
r.Active = false
}
// Force the sort order of the UUIDs before we create them so the
// order is deterministic.
id := testUUID()
roots[0].ID = "a" + id[1:]
roots[1].ID = "b" + id[1:]
roots[2].ID = "c" + id[1:]
// Now create
ok, err := s.CARootSetCAS(1, 0, roots)
assert.Nil(err)
assert.True(ok)
// Snapshot the queries.
snap := s.Snapshot()
defer snap.Close()
// Alter the real state store.
ok, err = s.CARootSetCAS(2, 1, roots[:1])
assert.Nil(err)
assert.True(ok)
// Verify the snapshot.
assert.Equal(snap.LastIndex(), uint64(1))
dump, err := snap.CARoots()
assert.Nil(err)
assert.Equal(roots, dump)
// Restore the values into a new state store.
func() {
s := testStateStore(t)
restore := s.Restore()
for _, r := range dump {
assert.Nil(restore.CARoot(r))
}
restore.Commit()
// Read the restored values back out and verify that they match.
idx, actual, err := s.CARoots(nil)
assert.Nil(err)
assert.Equal(idx, uint64(2))
assert.Equal(roots, actual)
}()
}
func TestStore_CABuiltinProvider(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
{
expected := &structs.CAConsulProviderState{
ID: "foo",
PrivateKey: "a",
RootCert: "b",
}
ok, err := s.CASetProviderState(0, expected)
assert.NoError(err)
assert.True(ok)
idx, state, err := s.CAProviderState(expected.ID)
assert.NoError(err)
assert.Equal(idx, uint64(0))
assert.Equal(expected, state)
}
{
expected := &structs.CAConsulProviderState{
ID: "bar",
PrivateKey: "c",
RootCert: "d",
}
ok, err := s.CASetProviderState(1, expected)
assert.NoError(err)
assert.True(ok)
idx, state, err := s.CAProviderState(expected.ID)
assert.NoError(err)
assert.Equal(idx, uint64(1))
assert.Equal(expected, state)
}
}
func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create multiple state entries.
before := []*structs.CAConsulProviderState{
{
ID: "bar",
PrivateKey: "y",
RootCert: "z",
},
{
ID: "foo",
PrivateKey: "a",
RootCert: "b",
},
}
for i, state := range before {
ok, err := s.CASetProviderState(uint64(98+i), state)
assert.NoError(err)
assert.True(ok)
}
// Take a snapshot.
snap := s.Snapshot()
defer snap.Close()
// Modify the state store.
after := &structs.CAConsulProviderState{
ID: "foo",
PrivateKey: "c",
RootCert: "d",
}
ok, err := s.CASetProviderState(100, after)
assert.NoError(err)
assert.True(ok)
snapped, err := snap.CAProviderState()
assert.NoError(err)
assert.Equal(before, snapped)
// Restore onto a new state store.
s2 := testStateStore(t)
restore := s2.Restore()
for _, entry := range snapped {
assert.NoError(restore.CAProviderState(entry))
}
restore.Commit()
// Verify the restored values match those from before the snapshot.
for _, state := range before {
idx, res, err := s2.CAProviderState(state.ID)
assert.NoError(err)
assert.Equal(idx, uint64(99))
assert.Equal(state, res)
}
}

View File

@ -0,0 +1,54 @@
package state
import (
"fmt"
"strings"
"github.com/hashicorp/consul/agent/structs"
)
// IndexConnectService indexes a *struct.ServiceNode for querying by
// services that support Connect to some target service. This will
// properly index the proxy destination for proxies and the service name
// for native services.
type IndexConnectService struct{}
func (idx *IndexConnectService) FromObject(obj interface{}) (bool, []byte, error) {
sn, ok := obj.(*structs.ServiceNode)
if !ok {
return false, nil, fmt.Errorf("Object must be ServiceNode, got %T", obj)
}
var result []byte
switch {
case sn.ServiceKind == structs.ServiceKindConnectProxy:
// For proxies, this service supports Connect for the destination
result = []byte(strings.ToLower(sn.ServiceProxyDestination))
case sn.ServiceConnect.Native:
// For native, this service supports Connect directly
result = []byte(strings.ToLower(sn.ServiceName))
default:
// Doesn't support Connect at all
return false, nil, nil
}
// Return the result with the null terminator appended so we can
// differentiate prefix vs. non-prefix matches.
return true, append(result, '\x00'), nil
}
func (idx *IndexConnectService) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
arg, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
}
// Add the null character as a terminator
return append([]byte(strings.ToLower(arg)), '\x00'), nil
}

View File

@ -0,0 +1,122 @@
package state
import (
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/require"
)
func TestIndexConnectService_FromObject(t *testing.T) {
cases := []struct {
Name string
Input interface{}
ExpectMatch bool
ExpectVal []byte
ExpectErr string
}{
{
"not a ServiceNode",
42,
false,
nil,
"ServiceNode",
},
{
"typical service, not native",
&structs.ServiceNode{
ServiceName: "db",
},
false,
nil,
"",
},
{
"typical service, is native",
&structs.ServiceNode{
ServiceName: "dB",
ServiceConnect: structs.ServiceConnect{Native: true},
},
true,
[]byte("db\x00"),
"",
},
{
"proxy service",
&structs.ServiceNode{
ServiceKind: structs.ServiceKindConnectProxy,
ServiceName: "db",
ServiceProxyDestination: "fOo",
},
true,
[]byte("foo\x00"),
"",
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
var idx IndexConnectService
match, val, err := idx.FromObject(tc.Input)
if tc.ExpectErr != "" {
require.Error(err)
require.Contains(err.Error(), tc.ExpectErr)
return
}
require.NoError(err)
require.Equal(tc.ExpectMatch, match)
require.Equal(tc.ExpectVal, val)
})
}
}
func TestIndexConnectService_FromArgs(t *testing.T) {
cases := []struct {
Name string
Args []interface{}
ExpectVal []byte
ExpectErr string
}{
{
"multiple arguments",
[]interface{}{"foo", "bar"},
nil,
"single",
},
{
"not a string",
[]interface{}{42},
nil,
"must be a string",
},
{
"string",
[]interface{}{"fOO"},
[]byte("foo\x00"),
"",
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
var idx IndexConnectService
val, err := idx.FromArgs(tc.Args...)
if tc.ExpectErr != "" {
require.Error(err)
require.Contains(err.Error(), tc.ExpectErr)
return
}
require.NoError(err)
require.Equal(tc.ExpectVal, val)
})
}
}

View File

@ -0,0 +1,366 @@
package state
import (
"fmt"
"sort"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
)
const (
intentionsTableName = "connect-intentions"
)
// intentionsTableSchema returns a new table schema used for storing
// intentions for Connect.
func intentionsTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: intentionsTableName,
Indexes: map[string]*memdb.IndexSchema{
"id": &memdb.IndexSchema{
Name: "id",
AllowMissing: false,
Unique: true,
Indexer: &memdb.UUIDFieldIndex{
Field: "ID",
},
},
"destination": &memdb.IndexSchema{
Name: "destination",
AllowMissing: true,
// This index is not unique since we need uniqueness across the whole
// 4-tuple.
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "DestinationNS",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "DestinationName",
Lowercase: true,
},
},
},
},
"source": &memdb.IndexSchema{
Name: "source",
AllowMissing: true,
// This index is not unique since we need uniqueness across the whole
// 4-tuple.
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "SourceNS",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "SourceName",
Lowercase: true,
},
},
},
},
"source_destination": &memdb.IndexSchema{
Name: "source_destination",
AllowMissing: true,
Unique: true,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "SourceNS",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "SourceName",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "DestinationNS",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "DestinationName",
Lowercase: true,
},
},
},
},
},
}
}
func init() {
registerSchema(intentionsTableSchema)
}
// Intentions is used to pull all the intentions from the snapshot.
func (s *Snapshot) Intentions() (structs.Intentions, error) {
ixns, err := s.tx.Get(intentionsTableName, "id")
if err != nil {
return nil, err
}
var ret structs.Intentions
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
ret = append(ret, wrapped.(*structs.Intention))
}
return ret, nil
}
// Intention is used when restoring from a snapshot.
func (s *Restore) Intention(ixn *structs.Intention) error {
// Insert the intention
if err := s.tx.Insert(intentionsTableName, ixn); err != nil {
return fmt.Errorf("failed restoring intention: %s", err)
}
if err := indexUpdateMaxTxn(s.tx, ixn.ModifyIndex, intentionsTableName); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// Intentions returns the list of all intentions.
func (s *Store) Intentions(ws memdb.WatchSet) (uint64, structs.Intentions, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the index
idx := maxIndexTxn(tx, intentionsTableName)
if idx < 1 {
idx = 1
}
// Get all intentions
iter, err := tx.Get(intentionsTableName, "id")
if err != nil {
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
}
ws.Add(iter.WatchCh())
var results structs.Intentions
for ixn := iter.Next(); ixn != nil; ixn = iter.Next() {
results = append(results, ixn.(*structs.Intention))
}
// Sort by precedence just because that's nicer and probably what most clients
// want for presentation.
sort.Sort(structs.IntentionPrecedenceSorter(results))
return idx, results, nil
}
// IntentionSet creates or updates an intention.
func (s *Store) IntentionSet(idx uint64, ixn *structs.Intention) error {
tx := s.db.Txn(true)
defer tx.Abort()
if err := s.intentionSetTxn(tx, idx, ixn); err != nil {
return err
}
tx.Commit()
return nil
}
// intentionSetTxn is the inner method used to insert an intention with
// the proper indexes into the state store.
func (s *Store) intentionSetTxn(tx *memdb.Txn, idx uint64, ixn *structs.Intention) error {
// ID is required
if ixn.ID == "" {
return ErrMissingIntentionID
}
// Ensure Precedence is populated correctly on "write"
ixn.UpdatePrecedence()
// Check for an existing intention
existing, err := tx.First(intentionsTableName, "id", ixn.ID)
if err != nil {
return fmt.Errorf("failed intention lookup: %s", err)
}
if existing != nil {
oldIxn := existing.(*structs.Intention)
ixn.CreateIndex = oldIxn.CreateIndex
ixn.CreatedAt = oldIxn.CreatedAt
} else {
ixn.CreateIndex = idx
}
ixn.ModifyIndex = idx
// Check for duplicates on the 4-tuple.
duplicate, err := tx.First(intentionsTableName, "source_destination",
ixn.SourceNS, ixn.SourceName, ixn.DestinationNS, ixn.DestinationName)
if err != nil {
return fmt.Errorf("failed intention lookup: %s", err)
}
if duplicate != nil {
dupIxn := duplicate.(*structs.Intention)
// Same ID is OK - this is an update
if dupIxn.ID != ixn.ID {
return fmt.Errorf("duplicate intention found: %s", dupIxn.String())
}
}
// We always force meta to be non-nil so that we its an empty map.
// This makes it easy for API responses to not nil-check this everywhere.
if ixn.Meta == nil {
ixn.Meta = make(map[string]string)
}
// Insert
if err := tx.Insert(intentionsTableName, ixn); err != nil {
return err
}
if err := tx.Insert("index", &IndexEntry{intentionsTableName, idx}); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// IntentionGet returns the given intention by ID.
func (s *Store) IntentionGet(ws memdb.WatchSet, id string) (uint64, *structs.Intention, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, intentionsTableName)
if idx < 1 {
idx = 1
}
// Look up by its ID.
watchCh, intention, err := tx.FirstWatch(intentionsTableName, "id", id)
if err != nil {
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
}
ws.Add(watchCh)
// Convert the interface{} if it is non-nil
var result *structs.Intention
if intention != nil {
result = intention.(*structs.Intention)
}
return idx, result, nil
}
// IntentionDelete deletes the given intention by ID.
func (s *Store) IntentionDelete(idx uint64, id string) error {
tx := s.db.Txn(true)
defer tx.Abort()
if err := s.intentionDeleteTxn(tx, idx, id); err != nil {
return fmt.Errorf("failed intention delete: %s", err)
}
tx.Commit()
return nil
}
// intentionDeleteTxn is the inner method used to delete a intention
// with the proper indexes into the state store.
func (s *Store) intentionDeleteTxn(tx *memdb.Txn, idx uint64, queryID string) error {
// Pull the query.
wrapped, err := tx.First(intentionsTableName, "id", queryID)
if err != nil {
return fmt.Errorf("failed intention lookup: %s", err)
}
if wrapped == nil {
return nil
}
// Delete the query and update the index.
if err := tx.Delete(intentionsTableName, wrapped); err != nil {
return fmt.Errorf("failed intention delete: %s", err)
}
if err := tx.Insert("index", &IndexEntry{intentionsTableName, idx}); err != nil {
return fmt.Errorf("failed updating index: %s", err)
}
return nil
}
// IntentionMatch returns the list of intentions that match the namespace and
// name for either a source or destination. This applies the resolution rules
// so wildcards will match any value.
//
// The returned value is the list of intentions in the same order as the
// entries in args. The intentions themselves are sorted based on the
// intention precedence rules. i.e. result[0][0] is the highest precedent
// rule to match for the first entry.
func (s *Store) IntentionMatch(ws memdb.WatchSet, args *structs.IntentionQueryMatch) (uint64, []structs.Intentions, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, intentionsTableName)
if idx < 1 {
idx = 1
}
// Make all the calls and accumulate the results
results := make([]structs.Intentions, len(args.Entries))
for i, entry := range args.Entries {
// Each search entry may require multiple queries to memdb, so this
// returns the arguments for each necessary Get. Note on performance:
// this is not the most optimal set of queries since we repeat some
// many times (such as */*). We can work on improving that in the
// future, the test cases shouldn't have to change for that.
getParams, err := s.intentionMatchGetParams(entry)
if err != nil {
return 0, nil, err
}
// Perform each call and accumulate the result.
var ixns structs.Intentions
for _, params := range getParams {
iter, err := tx.Get(intentionsTableName, string(args.Type), params...)
if err != nil {
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
}
ws.Add(iter.WatchCh())
for ixn := iter.Next(); ixn != nil; ixn = iter.Next() {
ixns = append(ixns, ixn.(*structs.Intention))
}
}
// Sort the results by precedence
sort.Sort(structs.IntentionPrecedenceSorter(ixns))
// Store the result
results[i] = ixns
}
return idx, results, nil
}
// intentionMatchGetParams returns the tx.Get parameters to find all the
// intentions for a certain entry.
func (s *Store) intentionMatchGetParams(entry structs.IntentionMatchEntry) ([][]interface{}, error) {
// We always query for "*/*" so include that. If the namespace is a
// wildcard, then we're actually done.
result := make([][]interface{}, 0, 3)
result = append(result, []interface{}{"*", "*"})
if entry.Namespace == structs.IntentionWildcard {
return result, nil
}
// Search for NS/* intentions. If we have a wildcard name, then we're done.
result = append(result, []interface{}{entry.Namespace, "*"})
if entry.Name == structs.IntentionWildcard {
return result, nil
}
// Search for the exact NS/N value.
result = append(result, []interface{}{entry.Namespace, entry.Name})
return result, nil
}

View File

@ -0,0 +1,559 @@
package state
import (
"testing"
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-memdb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStore_IntentionGet_none(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Querying with no results returns nil.
ws := memdb.NewWatchSet()
idx, res, err := s.IntentionGet(ws, testUUID())
assert.Equal(uint64(1), idx)
assert.Nil(res)
assert.Nil(err)
}
func TestStore_IntentionSetGet_basic(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call Get to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.IntentionGet(ws, testUUID())
assert.Nil(err)
// Build a valid intention
ixn := &structs.Intention{
ID: testUUID(),
SourceNS: "default",
SourceName: "*",
DestinationNS: "default",
DestinationName: "web",
Meta: map[string]string{},
}
// Inserting a with empty ID is disallowed.
assert.NoError(s.IntentionSet(1, ixn))
// Make sure the index got updated.
assert.Equal(uint64(1), s.maxIndex(intentionsTableName))
assert.True(watchFired(ws), "watch fired")
// Read it back out and verify it.
expected := &structs.Intention{
ID: ixn.ID,
SourceNS: "default",
SourceName: "*",
DestinationNS: "default",
DestinationName: "web",
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 1,
ModifyIndex: 1,
},
}
expected.UpdatePrecedence()
ws = memdb.NewWatchSet()
idx, actual, err := s.IntentionGet(ws, ixn.ID)
assert.NoError(err)
assert.Equal(expected.CreateIndex, idx)
assert.Equal(expected, actual)
// Change a value and test updating
ixn.SourceNS = "foo"
assert.NoError(s.IntentionSet(2, ixn))
// Change a value that isn't in the unique 4 tuple and check we don't
// incorrectly consider this a duplicate when updating.
ixn.Action = structs.IntentionActionDeny
assert.NoError(s.IntentionSet(2, ixn))
// Make sure the index got updated.
assert.Equal(uint64(2), s.maxIndex(intentionsTableName))
assert.True(watchFired(ws), "watch fired")
// Read it back and verify the data was updated
expected.SourceNS = ixn.SourceNS
expected.Action = structs.IntentionActionDeny
expected.ModifyIndex = 2
ws = memdb.NewWatchSet()
idx, actual, err = s.IntentionGet(ws, ixn.ID)
assert.NoError(err)
assert.Equal(expected.ModifyIndex, idx)
assert.Equal(expected, actual)
// Attempt to insert another intention with duplicate 4-tuple
ixn = &structs.Intention{
ID: testUUID(),
SourceNS: "default",
SourceName: "*",
DestinationNS: "default",
DestinationName: "web",
Meta: map[string]string{},
}
// Duplicate 4-tuple should cause an error
ws = memdb.NewWatchSet()
assert.Error(s.IntentionSet(3, ixn))
// Make sure the index did NOT get updated.
assert.Equal(uint64(2), s.maxIndex(intentionsTableName))
assert.False(watchFired(ws), "watch not fired")
}
func TestStore_IntentionSet_emptyId(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
ws := memdb.NewWatchSet()
_, _, err := s.IntentionGet(ws, testUUID())
assert.NoError(err)
// Inserting a with empty ID is disallowed.
err = s.IntentionSet(1, &structs.Intention{})
assert.Error(err)
assert.Contains(err.Error(), ErrMissingIntentionID.Error())
// Index is not updated if nothing is saved.
assert.Equal(s.maxIndex(intentionsTableName), uint64(0))
assert.False(watchFired(ws), "watch fired")
}
func TestStore_IntentionSet_updateCreatedAt(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid intention
now := time.Now()
ixn := structs.Intention{
ID: testUUID(),
CreatedAt: now,
}
// Insert
assert.NoError(s.IntentionSet(1, &ixn))
// Change a value and test updating
ixnUpdate := ixn
ixnUpdate.CreatedAt = now.Add(10 * time.Second)
assert.NoError(s.IntentionSet(2, &ixnUpdate))
// Read it back and verify
_, actual, err := s.IntentionGet(nil, ixn.ID)
assert.NoError(err)
assert.Equal(now, actual.CreatedAt)
}
func TestStore_IntentionSet_metaNil(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid intention
ixn := structs.Intention{
ID: testUUID(),
}
// Insert
assert.NoError(s.IntentionSet(1, &ixn))
// Read it back and verify
_, actual, err := s.IntentionGet(nil, ixn.ID)
assert.NoError(err)
assert.NotNil(actual.Meta)
}
func TestStore_IntentionSet_metaSet(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Build a valid intention
ixn := structs.Intention{
ID: testUUID(),
Meta: map[string]string{"foo": "bar"},
}
// Insert
assert.NoError(s.IntentionSet(1, &ixn))
// Read it back and verify
_, actual, err := s.IntentionGet(nil, ixn.ID)
assert.NoError(err)
assert.Equal(ixn.Meta, actual.Meta)
}
func TestStore_IntentionDelete(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Call Get to populate the watch set
ws := memdb.NewWatchSet()
_, _, err := s.IntentionGet(ws, testUUID())
assert.NoError(err)
// Create
ixn := &structs.Intention{ID: testUUID()}
assert.NoError(s.IntentionSet(1, ixn))
// Make sure the index got updated.
assert.Equal(s.maxIndex(intentionsTableName), uint64(1))
assert.True(watchFired(ws), "watch fired")
// Delete
assert.NoError(s.IntentionDelete(2, ixn.ID))
// Make sure the index got updated.
assert.Equal(s.maxIndex(intentionsTableName), uint64(2))
assert.True(watchFired(ws), "watch fired")
// Sanity check to make sure it's not there.
idx, actual, err := s.IntentionGet(nil, ixn.ID)
assert.NoError(err)
assert.Equal(idx, uint64(2))
assert.Nil(actual)
}
func TestStore_IntentionsList(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Querying with no results returns nil.
ws := memdb.NewWatchSet()
idx, res, err := s.Intentions(ws)
assert.NoError(err)
assert.Nil(res)
assert.Equal(uint64(1), idx)
// Create some intentions
ixns := structs.Intentions{
&structs.Intention{
ID: testUUID(),
Meta: map[string]string{},
},
&structs.Intention{
ID: testUUID(),
Meta: map[string]string{},
},
}
// Force deterministic sort order
ixns[0].ID = "a" + ixns[0].ID[1:]
ixns[1].ID = "b" + ixns[1].ID[1:]
// Create
for i, ixn := range ixns {
assert.NoError(s.IntentionSet(uint64(1+i), ixn))
}
assert.True(watchFired(ws), "watch fired")
// Read it back and verify.
expected := structs.Intentions{
&structs.Intention{
ID: ixns[0].ID,
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 1,
ModifyIndex: 1,
},
},
&structs.Intention{
ID: ixns[1].ID,
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 2,
ModifyIndex: 2,
},
},
}
for i := range expected {
expected[i].UpdatePrecedence() // to match what is returned...
}
idx, actual, err := s.Intentions(nil)
assert.NoError(err)
assert.Equal(idx, uint64(2))
assert.Equal(expected, actual)
}
// Test the matrix of match logic.
//
// Note that this doesn't need to test the intention sort logic exhaustively
// since this is tested in their sort implementation in the structs.
func TestStore_IntentionMatch_table(t *testing.T) {
type testCase struct {
Name string
Insert [][]string // List of intentions to insert
Query [][]string // List of intentions to match
Expected [][][]string // List of matches, where each match is a list of intentions
}
cases := []testCase{
{
"single exact namespace/name",
[][]string{
{"foo", "*"},
{"foo", "bar"},
{"foo", "baz"}, // shouldn't match
{"bar", "bar"}, // shouldn't match
{"bar", "*"}, // shouldn't match
{"*", "*"},
},
[][]string{
{"foo", "bar"},
},
[][][]string{
{
{"foo", "bar"},
{"foo", "*"},
{"*", "*"},
},
},
},
{
"multiple exact namespace/name",
[][]string{
{"foo", "*"},
{"foo", "bar"},
{"foo", "baz"}, // shouldn't match
{"bar", "bar"},
{"bar", "*"},
},
[][]string{
{"foo", "bar"},
{"bar", "bar"},
},
[][][]string{
{
{"foo", "bar"},
{"foo", "*"},
},
{
{"bar", "bar"},
{"bar", "*"},
},
},
},
{
"single exact namespace/name with duplicate destinations",
[][]string{
// 4-tuple specifies src and destination to test duplicate destinations
// with different sources. We flip them around to test in both
// directions. The first pair are the ones searched on in both cases so
// the duplicates need to be there.
{"foo", "bar", "foo", "*"},
{"foo", "bar", "bar", "*"},
{"*", "*", "*", "*"},
},
[][]string{
{"foo", "bar"},
},
[][][]string{
{
// Note the first two have the same precedence so we rely on arbitrary
// lexicographical tie-break behaviour.
{"foo", "bar", "bar", "*"},
{"foo", "bar", "foo", "*"},
{"*", "*", "*", "*"},
},
},
},
}
// testRunner implements the test for a single case, but can be
// parameterized to run for both source and destination so we can
// test both cases.
testRunner := func(t *testing.T, tc testCase, typ structs.IntentionMatchType) {
// Insert the set
assert := assert.New(t)
s := testStateStore(t)
var idx uint64 = 1
for _, v := range tc.Insert {
ixn := &structs.Intention{ID: testUUID()}
switch typ {
case structs.IntentionMatchDestination:
ixn.DestinationNS = v[0]
ixn.DestinationName = v[1]
if len(v) == 4 {
ixn.SourceNS = v[2]
ixn.SourceName = v[3]
}
case structs.IntentionMatchSource:
ixn.SourceNS = v[0]
ixn.SourceName = v[1]
if len(v) == 4 {
ixn.DestinationNS = v[2]
ixn.DestinationName = v[3]
}
}
assert.NoError(s.IntentionSet(idx, ixn))
idx++
}
// Build the arguments
args := &structs.IntentionQueryMatch{Type: typ}
for _, q := range tc.Query {
args.Entries = append(args.Entries, structs.IntentionMatchEntry{
Namespace: q[0],
Name: q[1],
})
}
// Match
_, matches, err := s.IntentionMatch(nil, args)
assert.NoError(err)
// Should have equal lengths
require.Len(t, matches, len(tc.Expected))
// Verify matches
for i, expected := range tc.Expected {
var actual [][]string
for _, ixn := range matches[i] {
switch typ {
case structs.IntentionMatchDestination:
if len(expected) > 1 && len(expected[0]) == 4 {
actual = append(actual, []string{
ixn.DestinationNS,
ixn.DestinationName,
ixn.SourceNS,
ixn.SourceName,
})
} else {
actual = append(actual, []string{ixn.DestinationNS, ixn.DestinationName})
}
case structs.IntentionMatchSource:
if len(expected) > 1 && len(expected[0]) == 4 {
actual = append(actual, []string{
ixn.SourceNS,
ixn.SourceName,
ixn.DestinationNS,
ixn.DestinationName,
})
} else {
actual = append(actual, []string{ixn.SourceNS, ixn.SourceName})
}
}
}
assert.Equal(expected, actual)
}
}
for _, tc := range cases {
t.Run(tc.Name+" (destination)", func(t *testing.T) {
testRunner(t, tc, structs.IntentionMatchDestination)
})
t.Run(tc.Name+" (source)", func(t *testing.T) {
testRunner(t, tc, structs.IntentionMatchSource)
})
}
}
func TestStore_Intention_Snapshot_Restore(t *testing.T) {
assert := assert.New(t)
s := testStateStore(t)
// Create some intentions.
ixns := structs.Intentions{
&structs.Intention{
DestinationName: "foo",
},
&structs.Intention{
DestinationName: "bar",
},
&structs.Intention{
DestinationName: "baz",
},
}
// Force the sort order of the UUIDs before we create them so the
// order is deterministic.
id := testUUID()
ixns[0].ID = "a" + id[1:]
ixns[1].ID = "b" + id[1:]
ixns[2].ID = "c" + id[1:]
// Now create
for i, ixn := range ixns {
assert.NoError(s.IntentionSet(uint64(4+i), ixn))
}
// Snapshot the queries.
snap := s.Snapshot()
defer snap.Close()
// Alter the real state store.
assert.NoError(s.IntentionDelete(7, ixns[0].ID))
// Verify the snapshot.
assert.Equal(snap.LastIndex(), uint64(6))
// Expect them sorted in insertion order
expected := structs.Intentions{
&structs.Intention{
ID: ixns[0].ID,
DestinationName: "foo",
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 4,
ModifyIndex: 4,
},
},
&structs.Intention{
ID: ixns[1].ID,
DestinationName: "bar",
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 5,
ModifyIndex: 5,
},
},
&structs.Intention{
ID: ixns[2].ID,
DestinationName: "baz",
Meta: map[string]string{},
RaftIndex: structs.RaftIndex{
CreateIndex: 6,
ModifyIndex: 6,
},
},
}
for i := range expected {
expected[i].UpdatePrecedence() // to match what is returned...
}
dump, err := snap.Intentions()
assert.NoError(err)
assert.Equal(expected, dump)
// Restore the values into a new state store.
func() {
s := testStateStore(t)
restore := s.Restore()
for _, ixn := range dump {
assert.NoError(restore.Intention(ixn))
}
restore.Commit()
// Read the restored values back out and verify that they match. Note that
// Intentions are returned precedence sorted unlike the snapshot so we need
// to rearrange the expected slice some.
expected[0], expected[1], expected[2] = expected[1], expected[2], expected[0]
idx, actual, err := s.Intentions(nil)
assert.NoError(err)
assert.Equal(idx, uint64(6))
assert.Equal(expected, actual)
}()
}

View File

@ -28,6 +28,14 @@ var (
// ErrMissingQueryID is returned when a Query set is called on
// a Query with an empty ID.
ErrMissingQueryID = errors.New("Missing Query ID")
// ErrMissingCARootID is returned when an CARoot set is called
// with an CARoot with an empty ID.
ErrMissingCARootID = errors.New("Missing CA Root ID")
// ErrMissingIntentionID is returned when an Intention set is called
// with an Intention with an empty ID.
ErrMissingIntentionID = errors.New("Missing Intention ID")
)
const (

View File

@ -339,7 +339,7 @@ func (d *DNSServer) addSOA(msg *dns.Msg) {
// nameservers returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "")
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false)
if err != nil {
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
return nil, nil
@ -417,7 +417,7 @@ PARSE:
n = n + 1
}
switch labels[n-1] {
switch kind := labels[n-1]; kind {
case "service":
if n == 1 {
goto INVALID
@ -435,7 +435,7 @@ PARSE:
}
// _name._tag.service.consul
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, req, resp)
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp)
// Consul 0.3 and prior format for SRV queries
} else {
@ -447,9 +447,17 @@ PARSE:
}
// tag[.tag].name.service.consul
d.serviceLookup(network, datacenter, labels[n-2], tag, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp)
}
case "connect":
if n == 1 {
goto INVALID
}
// name.connect.consul
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp)
case "node":
if n == 1 {
goto INVALID
@ -913,8 +921,9 @@ func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed
}
// lookupServiceNodes returns nodes with a given service.
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs.IndexedCheckServiceNodes, error) {
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool) (structs.IndexedCheckServiceNodes, error) {
args := structs.ServiceSpecificRequest{
Connect: connect,
Datacenter: datacenter,
ServiceName: service,
ServiceTag: tag,
@ -950,8 +959,8 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs
}
// serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) {
out, err := d.lookupServiceNodes(datacenter, service, tag)
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect)
if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)

View File

@ -17,6 +17,7 @@ import (
"github.com/hashicorp/serf/coordinate"
"github.com/miekg/dns"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -1125,6 +1126,51 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) {
verify.Values(t, "extra", in.Extra, wantExtra)
}
func TestDNS_ConnectServiceLookup(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
{
args := structs.TestRegisterRequestProxy(t)
args.Address = "127.0.0.55"
args.Service.ProxyDestination = "db"
args.Service.Address = ""
args.Service.Port = 12345
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
}
// Look up the service
questions := []string{
"db.connect.consul.",
}
for _, question := range questions {
m := new(dns.Msg)
m.SetQuestion(question, dns.TypeSRV)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
assert.Nil(err)
assert.Len(in.Answer, 1)
srvRec, ok := in.Answer[0].(*dns.SRV)
assert.True(ok)
assert.Equal(uint16(12345), srvRec.Port)
assert.Equal("foo.node.dc1.consul.", srvRec.Target)
assert.Equal(uint32(0), srvRec.Hdr.Ttl)
cnameRec, ok := in.Extra[0].(*dns.A)
assert.True(ok)
assert.Equal("foo.node.dc1.consul.", cnameRec.Hdr.Name)
assert.Equal(uint32(0), srvRec.Hdr.Ttl)
assert.Equal("127.0.0.55", cnameRec.A.String())
}
}
func TestDNS_ExternalServiceLookup(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")

View File

@ -143,9 +143,17 @@ RETRY_ONCE:
return out.HealthChecks, nil
}
func (s *HTTPServer) HealthConnectServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.healthServiceNodes(resp, req, true)
}
func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
return s.healthServiceNodes(resp, req, false)
}
func (s *HTTPServer) healthServiceNodes(resp http.ResponseWriter, req *http.Request, connect bool) (interface{}, error) {
// Set default DC
args := structs.ServiceSpecificRequest{}
args := structs.ServiceSpecificRequest{Connect: connect}
s.parseSource(req, &args.Source)
args.NodeMetaFilters = s.parseMetaFilter(req)
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
@ -159,8 +167,14 @@ func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Requ
args.TagFilter = true
}
// Determine the prefix
prefix := "/v1/health/service/"
if connect {
prefix = "/v1/health/connect/"
}
// Pull out the service name
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/health/service/")
args.ServiceName = strings.TrimPrefix(req.URL.Path, prefix)
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")

View File

@ -13,6 +13,7 @@ import (
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/testutil/retry"
"github.com/hashicorp/serf/coordinate"
"github.com/stretchr/testify/assert"
)
func TestHealthChecksInState(t *testing.T) {
@ -770,6 +771,105 @@ func TestHealthServiceNodes_WanTranslation(t *testing.T) {
}
}
func TestHealthConnectServiceNodes(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
var out struct{}
assert.Nil(a.RPC("Catalog.Register", args, &out))
// Request
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?dc=dc1", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
// Should be a non-nil empty list for checks
nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 1)
assert.Len(nodes[0].Checks, 0)
}
func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Register
args := structs.TestRegisterRequestProxy(t)
args.Check = &structs.HealthCheck{
Node: args.Node,
Name: "check",
ServiceID: args.Service.Service,
Status: api.HealthCritical,
}
var out struct{}
assert.Nil(t, a.RPC("Catalog.Register", args, &out))
t.Run("bc_no_query_value", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
// Should be 0 health check for consul
nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 0)
})
t.Run("passing_true", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=true", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
// Should be 0 health check for consul
nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 0)
})
t.Run("passing_false", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=false", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.Nil(err)
assertIndex(t, resp)
// Should be 1
nodes := obj.(structs.CheckServiceNodes)
assert.Len(nodes, 1)
})
t.Run("passing_bad", func(t *testing.T) {
assert := assert.New(t)
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=nope-nope", args.Service.ProxyDestination), nil)
resp := httptest.NewRecorder()
a.srv.HealthConnectServiceNodes(resp, req)
assert.Equal(400, resp.Code)
body, err := ioutil.ReadAll(resp.Body)
assert.Nil(err)
assert.True(bytes.Contains(body, []byte("Invalid value for ?passing")))
})
}
func TestFilterNonPassing(t *testing.T) {
t.Parallel()
nodes := structs.CheckServiceNodes{

View File

@ -3,6 +3,7 @@ package agent
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/pprof"
@ -16,6 +17,7 @@ import (
"github.com/NYTimes/gziphandler"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-cleanhttp"
"github.com/mitchellh/mapstructure"
@ -384,6 +386,13 @@ func (s *HTTPServer) Index(resp http.ResponseWriter, req *http.Request) {
// decodeBody is used to decode a JSON request body
func decodeBody(req *http.Request, out interface{}, cb func(interface{}) error) error {
// This generally only happens in tests since real HTTP requests set
// a non-nil body with no content. We guard against it anyways to prevent
// a panic. The EOF response is the same behavior as an empty reader.
if req.Body == nil {
return io.EOF
}
var raw interface{}
dec := json.NewDecoder(req.Body)
if err := dec.Decode(&raw); err != nil {
@ -409,6 +418,14 @@ func setTranslateAddr(resp http.ResponseWriter, active bool) {
// setIndex is used to set the index response header
func setIndex(resp http.ResponseWriter, index uint64) {
// If we ever return X-Consul-Index of 0 blocking clients will go into a busy
// loop and hammer us since ?index=0 will never block. It's always safe to
// return index=1 since the very first Raft write is always an internal one
// writing the raft config for the cluster so no user-facing blocking query
// will ever legitimately have an X-Consul-Index of 1.
if index == 0 {
index = 1
}
resp.Header().Set("X-Consul-Index", strconv.FormatUint(index, 10))
}
@ -444,6 +461,15 @@ func setMeta(resp http.ResponseWriter, m *structs.QueryMeta) {
setConsistency(resp, m.ConsistencyLevel)
}
// setCacheMeta sets http response headers to indicate cache status.
func setCacheMeta(resp http.ResponseWriter, m *cache.ResultMeta) {
str := "MISS"
if m != nil && m.Hit {
str = "HIT"
}
resp.Header().Set("X-Cache", str)
}
// setHeaders is used to set canonical response header fields
func setHeaders(resp http.ResponseWriter, headers map[string]string) {
for field, value := range headers {

View File

@ -29,16 +29,27 @@ func init() {
registerEndpoint("/v1/agent/check/warn/", []string{"PUT"}, (*HTTPServer).AgentCheckWarn)
registerEndpoint("/v1/agent/check/fail/", []string{"PUT"}, (*HTTPServer).AgentCheckFail)
registerEndpoint("/v1/agent/check/update/", []string{"PUT"}, (*HTTPServer).AgentCheckUpdate)
registerEndpoint("/v1/agent/connect/authorize", []string{"POST"}, (*HTTPServer).AgentConnectAuthorize)
registerEndpoint("/v1/agent/connect/ca/roots", []string{"GET"}, (*HTTPServer).AgentConnectCARoots)
registerEndpoint("/v1/agent/connect/ca/leaf/", []string{"GET"}, (*HTTPServer).AgentConnectCALeafCert)
registerEndpoint("/v1/agent/connect/proxy/", []string{"GET"}, (*HTTPServer).AgentConnectProxyConfig)
registerEndpoint("/v1/agent/service/register", []string{"PUT"}, (*HTTPServer).AgentRegisterService)
registerEndpoint("/v1/agent/service/deregister/", []string{"PUT"}, (*HTTPServer).AgentDeregisterService)
registerEndpoint("/v1/agent/service/maintenance/", []string{"PUT"}, (*HTTPServer).AgentServiceMaintenance)
registerEndpoint("/v1/catalog/register", []string{"PUT"}, (*HTTPServer).CatalogRegister)
registerEndpoint("/v1/catalog/connect/", []string{"GET"}, (*HTTPServer).CatalogConnectServiceNodes)
registerEndpoint("/v1/catalog/deregister", []string{"PUT"}, (*HTTPServer).CatalogDeregister)
registerEndpoint("/v1/catalog/datacenters", []string{"GET"}, (*HTTPServer).CatalogDatacenters)
registerEndpoint("/v1/catalog/nodes", []string{"GET"}, (*HTTPServer).CatalogNodes)
registerEndpoint("/v1/catalog/services", []string{"GET"}, (*HTTPServer).CatalogServices)
registerEndpoint("/v1/catalog/service/", []string{"GET"}, (*HTTPServer).CatalogServiceNodes)
registerEndpoint("/v1/catalog/node/", []string{"GET"}, (*HTTPServer).CatalogNodeServices)
registerEndpoint("/v1/connect/ca/configuration", []string{"GET", "PUT"}, (*HTTPServer).ConnectCAConfiguration)
registerEndpoint("/v1/connect/ca/roots", []string{"GET"}, (*HTTPServer).ConnectCARoots)
registerEndpoint("/v1/connect/intentions", []string{"GET", "POST"}, (*HTTPServer).IntentionEndpoint)
registerEndpoint("/v1/connect/intentions/match", []string{"GET"}, (*HTTPServer).IntentionMatch)
registerEndpoint("/v1/connect/intentions/check", []string{"GET"}, (*HTTPServer).IntentionCheck)
registerEndpoint("/v1/connect/intentions/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).IntentionSpecific)
registerEndpoint("/v1/coordinate/datacenters", []string{"GET"}, (*HTTPServer).CoordinateDatacenters)
registerEndpoint("/v1/coordinate/nodes", []string{"GET"}, (*HTTPServer).CoordinateNodes)
registerEndpoint("/v1/coordinate/node/", []string{"GET"}, (*HTTPServer).CoordinateNode)
@ -49,6 +60,7 @@ func init() {
registerEndpoint("/v1/health/checks/", []string{"GET"}, (*HTTPServer).HealthServiceChecks)
registerEndpoint("/v1/health/state/", []string{"GET"}, (*HTTPServer).HealthChecksInState)
registerEndpoint("/v1/health/service/", []string{"GET"}, (*HTTPServer).HealthServiceNodes)
registerEndpoint("/v1/health/connect/", []string{"GET"}, (*HTTPServer).HealthConnectServiceNodes)
registerEndpoint("/v1/internal/ui/nodes", []string{"GET"}, (*HTTPServer).UINodes)
registerEndpoint("/v1/internal/ui/node/", []string{"GET"}, (*HTTPServer).UINodeInfo)
registerEndpoint("/v1/internal/ui/services", []string{"GET"}, (*HTTPServer).UIServices)

View File

@ -0,0 +1,302 @@
package agent
import (
"fmt"
"net/http"
"strings"
"github.com/hashicorp/consul/agent/consul"
"github.com/hashicorp/consul/agent/structs"
)
// /v1/connection/intentions
func (s *HTTPServer) IntentionEndpoint(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
switch req.Method {
case "GET":
return s.IntentionList(resp, req)
case "POST":
return s.IntentionCreate(resp, req)
default:
return nil, MethodNotAllowedError{req.Method, []string{"GET", "POST"}}
}
}
// GET /v1/connect/intentions
func (s *HTTPServer) IntentionList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
var args structs.DCSpecificRequest
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
var reply structs.IndexedIntentions
if err := s.agent.RPC("Intention.List", &args, &reply); err != nil {
return nil, err
}
return reply.Intentions, nil
}
// POST /v1/connect/intentions
func (s *HTTPServer) IntentionCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
args := structs.IntentionRequest{
Op: structs.IntentionOpCreate,
}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req, &args.Intention, nil); err != nil {
return nil, fmt.Errorf("Failed to decode request body: %s", err)
}
var reply string
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
return nil, err
}
return intentionCreateResponse{reply}, nil
}
// GET /v1/connect/intentions/match
func (s *HTTPServer) IntentionMatch(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Prepare args
args := &structs.IntentionQueryRequest{Match: &structs.IntentionQueryMatch{}}
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
q := req.URL.Query()
// Extract the "by" query parameter
if by, ok := q["by"]; !ok || len(by) != 1 {
return nil, fmt.Errorf("required query parameter 'by' not set")
} else {
switch v := structs.IntentionMatchType(by[0]); v {
case structs.IntentionMatchSource, structs.IntentionMatchDestination:
args.Match.Type = v
default:
return nil, fmt.Errorf("'by' parameter must be one of 'source' or 'destination'")
}
}
// Extract all the match names
names, ok := q["name"]
if !ok || len(names) == 0 {
return nil, fmt.Errorf("required query parameter 'name' not set")
}
// Build the entries in order. The order matters since that is the
// order of the returned responses.
args.Match.Entries = make([]structs.IntentionMatchEntry, len(names))
for i, n := range names {
entry, err := parseIntentionMatchEntry(n)
if err != nil {
return nil, fmt.Errorf("name %q is invalid: %s", n, err)
}
args.Match.Entries[i] = entry
}
var reply structs.IndexedIntentionMatches
if err := s.agent.RPC("Intention.Match", args, &reply); err != nil {
return nil, err
}
// We must have an identical count of matches
if len(reply.Matches) != len(names) {
return nil, fmt.Errorf("internal error: match response count didn't match input count")
}
// Use empty list instead of nil.
response := make(map[string]structs.Intentions)
for i, ixns := range reply.Matches {
response[names[i]] = ixns
}
return response, nil
}
// GET /v1/connect/intentions/check
func (s *HTTPServer) IntentionCheck(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Prepare args
args := &structs.IntentionQueryRequest{Check: &structs.IntentionQueryCheck{}}
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
q := req.URL.Query()
// Set the source type if set
args.Check.SourceType = structs.IntentionSourceConsul
if sourceType, ok := q["source-type"]; ok && len(sourceType) > 0 {
args.Check.SourceType = structs.IntentionSourceType(sourceType[0])
}
// Extract the source/destination
source, ok := q["source"]
if !ok || len(source) != 1 {
return nil, fmt.Errorf("required query parameter 'source' not set")
}
destination, ok := q["destination"]
if !ok || len(destination) != 1 {
return nil, fmt.Errorf("required query parameter 'destination' not set")
}
// We parse them the same way as matches to extract namespace/name
args.Check.SourceName = source[0]
if args.Check.SourceType == structs.IntentionSourceConsul {
entry, err := parseIntentionMatchEntry(source[0])
if err != nil {
return nil, fmt.Errorf("source %q is invalid: %s", source[0], err)
}
args.Check.SourceNS = entry.Namespace
args.Check.SourceName = entry.Name
}
// The destination is always in the Consul format
entry, err := parseIntentionMatchEntry(destination[0])
if err != nil {
return nil, fmt.Errorf("destination %q is invalid: %s", destination[0], err)
}
args.Check.DestinationNS = entry.Namespace
args.Check.DestinationName = entry.Name
var reply structs.IntentionQueryCheckResponse
if err := s.agent.RPC("Intention.Check", args, &reply); err != nil {
return nil, err
}
return &reply, nil
}
// IntentionSpecific handles the endpoint for /v1/connection/intentions/:id
func (s *HTTPServer) IntentionSpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
id := strings.TrimPrefix(req.URL.Path, "/v1/connect/intentions/")
switch req.Method {
case "GET":
return s.IntentionSpecificGet(id, resp, req)
case "PUT":
return s.IntentionSpecificUpdate(id, resp, req)
case "DELETE":
return s.IntentionSpecificDelete(id, resp, req)
default:
return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}}
}
}
// GET /v1/connect/intentions/:id
func (s *HTTPServer) IntentionSpecificGet(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
args := structs.IntentionQueryRequest{
IntentionID: id,
}
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
return nil, nil
}
var reply structs.IndexedIntentions
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds the error type
if err.Error() == consul.ErrIntentionNotFound.Error() {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
}
return nil, err
}
// This shouldn't happen since the called API documents it shouldn't,
// but we check since the alternative if it happens is a panic.
if len(reply.Intentions) == 0 {
resp.WriteHeader(http.StatusNotFound)
return nil, nil
}
return reply.Intentions[0], nil
}
// PUT /v1/connect/intentions/:id
func (s *HTTPServer) IntentionSpecificUpdate(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
args := structs.IntentionRequest{
Op: structs.IntentionOpUpdate,
}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req, &args.Intention, nil); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
}
// Use the ID from the URL
args.Intention.ID = id
var reply string
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
return nil, err
}
// Update uses the same create response
return intentionCreateResponse{reply}, nil
}
// DELETE /v1/connect/intentions/:id
func (s *HTTPServer) IntentionSpecificDelete(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Method is tested in IntentionEndpoint
args := structs.IntentionRequest{
Op: structs.IntentionOpDelete,
Intention: &structs.Intention{ID: id},
}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
var reply string
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
return nil, err
}
return true, nil
}
// intentionCreateResponse is the response structure for creating an intention.
type intentionCreateResponse struct{ ID string }
// parseIntentionMatchEntry parses the query parameter for an intention
// match query entry.
func parseIntentionMatchEntry(input string) (structs.IntentionMatchEntry, error) {
var result structs.IntentionMatchEntry
result.Namespace = structs.IntentionDefaultNamespace
// TODO(mitchellh): when namespaces are introduced, set the default
// namespace to be the namespace of the requestor.
// Get the index to the '/'. If it doesn't exist, we have just a name
// so just set that and return.
idx := strings.IndexByte(input, '/')
if idx == -1 {
result.Name = input
return result, nil
}
result.Namespace = input[:idx]
result.Name = input[idx+1:]
if strings.IndexByte(result.Name, '/') != -1 {
return result, fmt.Errorf("input can contain at most one '/'")
}
return result, nil
}

View File

@ -0,0 +1,502 @@
package agent
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIntentionsList_empty(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Make sure an empty list is non-nil.
req, _ := http.NewRequest("GET", "/v1/connect/intentions", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionList(resp, req)
assert.Nil(err)
value := obj.(structs.Intentions)
assert.NotNil(value)
assert.Len(value, 0)
}
func TestIntentionsList_values(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Create some intentions, note we create the lowest precedence first to test
// sorting.
for _, v := range []string{"*", "foo", "bar"} {
req := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
}
req.Intention.SourceName = v
var reply string
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
}
// Request
req, _ := http.NewRequest("GET", "/v1/connect/intentions", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionList(resp, req)
assert.NoError(err)
value := obj.(structs.Intentions)
assert.Len(value, 3)
expected := []string{"bar", "foo", "*"}
actual := []string{
value[0].SourceName,
value[1].SourceName,
value[2].SourceName,
}
assert.Equal(expected, actual)
}
func TestIntentionsMatch_basic(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Create some intentions
{
insert := [][]string{
{"foo", "*", "foo", "*"},
{"foo", "*", "foo", "bar"},
{"foo", "*", "foo", "baz"}, // shouldn't match
{"foo", "*", "bar", "bar"}, // shouldn't match
{"foo", "*", "bar", "*"}, // shouldn't match
{"foo", "*", "*", "*"},
{"bar", "*", "foo", "bar"}, // duplicate destination different source
}
for _, v := range insert {
ixn := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
}
ixn.Intention.SourceNS = v[0]
ixn.Intention.SourceName = v[1]
ixn.Intention.DestinationNS = v[2]
ixn.Intention.DestinationName = v[3]
// Create
var reply string
assert.Nil(a.RPC("Intention.Apply", &ixn, &reply))
}
}
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/match?by=destination&name=foo/bar", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionMatch(resp, req)
assert.Nil(err)
value := obj.(map[string]structs.Intentions)
assert.Len(value, 1)
var actual [][]string
expected := [][]string{
{"bar", "*", "foo", "bar"},
{"foo", "*", "foo", "bar"},
{"foo", "*", "foo", "*"},
{"foo", "*", "*", "*"},
}
for _, ixn := range value["foo/bar"] {
actual = append(actual, []string{
ixn.SourceNS,
ixn.SourceName,
ixn.DestinationNS,
ixn.DestinationName,
})
}
assert.Equal(expected, actual)
}
func TestIntentionsMatch_noBy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/match?name=foo/bar", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionMatch(resp, req)
assert.NotNil(err)
assert.Contains(err.Error(), "by")
assert.Nil(obj)
}
func TestIntentionsMatch_byInvalid(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/match?by=datacenter", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionMatch(resp, req)
assert.NotNil(err)
assert.Contains(err.Error(), "'by' parameter")
assert.Nil(obj)
}
func TestIntentionsMatch_noName(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/match?by=source", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionMatch(resp, req)
assert.NotNil(err)
assert.Contains(err.Error(), "'name' not set")
assert.Nil(obj)
}
func TestIntentionsCheck_basic(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Create some intentions
{
insert := [][]string{
{"foo", "*", "foo", "*"},
{"foo", "*", "foo", "bar"},
{"bar", "*", "foo", "bar"},
}
for _, v := range insert {
ixn := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: structs.TestIntention(t),
}
ixn.Intention.SourceNS = v[0]
ixn.Intention.SourceName = v[1]
ixn.Intention.DestinationNS = v[2]
ixn.Intention.DestinationName = v[3]
ixn.Intention.Action = structs.IntentionActionDeny
// Create
var reply string
require.Nil(a.RPC("Intention.Apply", &ixn, &reply))
}
}
// Request matching intention
{
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/test?source=foo/bar&destination=foo/baz", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCheck(resp, req)
require.Nil(err)
value := obj.(*structs.IntentionQueryCheckResponse)
require.False(value.Allowed)
}
// Request non-matching intention
{
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/test?source=foo/bar&destination=bar/qux", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCheck(resp, req)
require.Nil(err)
value := obj.(*structs.IntentionQueryCheckResponse)
require.True(value.Allowed)
}
}
func TestIntentionsCheck_noSource(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/test?destination=B", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCheck(resp, req)
require.NotNil(err)
require.Contains(err.Error(), "'source' not set")
require.Nil(obj)
}
func TestIntentionsCheck_noDestination(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Request
req, _ := http.NewRequest("GET",
"/v1/connect/intentions/test?source=B", nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCheck(resp, req)
require.NotNil(err)
require.Contains(err.Error(), "'destination' not set")
require.Nil(obj)
}
func TestIntentionsCreate_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Make sure an empty list is non-nil.
args := structs.TestIntention(t)
args.SourceName = "foo"
req, _ := http.NewRequest("POST", "/v1/connect/intentions", jsonReader(args))
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionCreate(resp, req)
assert.Nil(err)
value := obj.(intentionCreateResponse)
assert.NotEqual("", value.ID)
// Read the value
{
req := &structs.IntentionQueryRequest{
Datacenter: "dc1",
IntentionID: value.ID,
}
var resp structs.IndexedIntentions
assert.Nil(a.RPC("Intention.Get", req, &resp))
assert.Len(resp.Intentions, 1)
actual := resp.Intentions[0]
assert.Equal("foo", actual.SourceName)
}
}
func TestIntentionsCreate_noBody(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// Create with no body
req, _ := http.NewRequest("POST", "/v1/connect/intentions", nil)
resp := httptest.NewRecorder()
_, err := a.srv.IntentionCreate(resp, req)
require.Error(t, err)
}
func TestIntentionsSpecificGet_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// The intention
ixn := structs.TestIntention(t)
// Create an intention directly
var reply string
{
req := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: ixn,
}
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
}
// Get the value
req, _ := http.NewRequest("GET", fmt.Sprintf("/v1/connect/intentions/%s", reply), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionSpecific(resp, req)
assert.Nil(err)
value := obj.(*structs.Intention)
assert.Equal(reply, value.ID)
ixn.ID = value.ID
ixn.RaftIndex = value.RaftIndex
ixn.CreatedAt, ixn.UpdatedAt = value.CreatedAt, value.UpdatedAt
assert.Equal(ixn, value)
}
func TestIntentionsSpecificUpdate_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// The intention
ixn := structs.TestIntention(t)
// Create an intention directly
var reply string
{
req := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: ixn,
}
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
}
// Update the intention
ixn.ID = "bogus"
ixn.SourceName = "bar"
req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/connect/intentions/%s", reply), jsonReader(ixn))
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionSpecific(resp, req)
assert.Nil(err)
value := obj.(intentionCreateResponse)
assert.Equal(reply, value.ID)
// Read the value
{
req := &structs.IntentionQueryRequest{
Datacenter: "dc1",
IntentionID: reply,
}
var resp structs.IndexedIntentions
assert.Nil(a.RPC("Intention.Get", req, &resp))
assert.Len(resp.Intentions, 1)
actual := resp.Intentions[0]
assert.Equal("bar", actual.SourceName)
}
}
func TestIntentionsSpecificDelete_good(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := NewTestAgent(t.Name(), "")
defer a.Shutdown()
// The intention
ixn := structs.TestIntention(t)
ixn.SourceName = "foo"
// Create an intention directly
var reply string
{
req := structs.IntentionRequest{
Datacenter: "dc1",
Op: structs.IntentionOpCreate,
Intention: ixn,
}
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
}
// Sanity check that the intention exists
{
req := &structs.IntentionQueryRequest{
Datacenter: "dc1",
IntentionID: reply,
}
var resp structs.IndexedIntentions
assert.Nil(a.RPC("Intention.Get", req, &resp))
assert.Len(resp.Intentions, 1)
actual := resp.Intentions[0]
assert.Equal("foo", actual.SourceName)
}
// Delete the intention
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/v1/connect/intentions/%s", reply), nil)
resp := httptest.NewRecorder()
obj, err := a.srv.IntentionSpecific(resp, req)
assert.Nil(err)
assert.Equal(true, obj)
// Verify the intention is gone
{
req := &structs.IntentionQueryRequest{
Datacenter: "dc1",
IntentionID: reply,
}
var resp structs.IndexedIntentions
err := a.RPC("Intention.Get", req, &resp)
assert.NotNil(err)
assert.Contains(err.Error(), "not found")
}
}
func TestParseIntentionMatchEntry(t *testing.T) {
cases := []struct {
Input string
Expected structs.IntentionMatchEntry
Err bool
}{
{
"foo",
structs.IntentionMatchEntry{
Namespace: structs.IntentionDefaultNamespace,
Name: "foo",
},
false,
},
{
"foo/bar",
structs.IntentionMatchEntry{
Namespace: "foo",
Name: "bar",
},
false,
},
{
"foo/bar/baz",
structs.IntentionMatchEntry{},
true,
},
}
for _, tc := range cases {
t.Run(tc.Input, func(t *testing.T) {
assert := assert.New(t)
actual, err := parseIntentionMatchEntry(tc.Input)
assert.Equal(err != nil, tc.Err, err)
if err != nil {
return
}
assert.Equal(tc.Expected, actual)
})
}
}

View File

@ -3,6 +3,7 @@ package local
import (
"fmt"
"log"
"math/rand"
"reflect"
"strconv"
"strings"
@ -10,6 +11,8 @@ import (
"sync/atomic"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/agent/token"
@ -27,6 +30,8 @@ type Config struct {
NodeID types.NodeID
NodeName string
TaggedAddresses map[string]string
ProxyBindMinPort int
ProxyBindMaxPort int
}
// ServiceState describes the state of a service record.
@ -107,6 +112,32 @@ type rpc interface {
RPC(method string, args interface{}, reply interface{}) error
}
// ManagedProxy represents the local state for a registered proxy instance.
type ManagedProxy struct {
Proxy *structs.ConnectManagedProxy
// ProxyToken is a special local-only security token that grants the bearer
// access to the proxy's config as well as allowing it to request certificates
// on behalf of the target service. Certain connect endpoints will validate
// against this token and if it matches will then use the target service's
// registration token to actually authenticate the upstream RPC on behalf of
// the service. This token is passed securely to the proxy process via ENV
// vars and should never be exposed any other way. Unmanaged proxies will
// never see this and need to use service-scoped ACL tokens distributed
// externally. It is persisted in the local state to allow authenticating
// running proxies after the agent restarts.
//
// TODO(banks): In theory we only need to persist this at all to _validate_
// which means we could keep only a hash in memory and on disk and only pass
// the actual token to the process on startup. That would require a bit of
// refactoring though to have the required interaction with the proxy manager.
ProxyToken string
// WatchCh is a close-only chan that is closed when the proxy is removed or
// updated.
WatchCh chan struct{}
}
// State is used to represent the node's services,
// and checks. We use it to perform anti-entropy with the
// catalog representation
@ -150,9 +181,23 @@ type State struct {
// tokens contains the ACL tokens
tokens *token.Store
// managedProxies is a map of all managed connect proxies registered locally on
// this agent. This is NOT kept in sync with servers since it's agent-local
// config only. Proxy instances have separate service registrations in the
// services map above which are kept in sync via anti-entropy. Un-managed
// proxies (that registered themselves separately from the service
// registration) do not appear here as the agent doesn't need to manage their
// process nor config. The _do_ still exist in services above though as
// services with Kind == connect-proxy.
//
// managedProxyHandlers is a map of registered channel listeners that
// are sent a message each time a proxy changes via Add or RemoveProxy.
managedProxies map[string]*ManagedProxy
managedProxyHandlers map[chan<- struct{}]struct{}
}
// NewLocalState creates a new local state for the agent.
// NewState creates a new local state for the agent.
func NewState(c Config, lg *log.Logger, tokens *token.Store) *State {
l := &State{
config: c,
@ -161,6 +206,8 @@ func NewState(c Config, lg *log.Logger, tokens *token.Store) *State {
checks: make(map[types.CheckID]*CheckState),
metadata: make(map[string]string),
tokens: tokens,
managedProxies: make(map[string]*ManagedProxy),
managedProxyHandlers: make(map[chan<- struct{}]struct{}),
}
l.SetDiscardCheckOutput(c.DiscardCheckOutput)
return l
@ -529,6 +576,204 @@ func (l *State) CriticalCheckStates() map[types.CheckID]*CheckState {
return m
}
// AddProxy is used to add a connect proxy entry to the local state. This
// assumes the proxy's NodeService is already registered via Agent.AddService
// (since that has to do other book keeping). The token passed here is the ACL
// token the service used to register itself so must have write on service
// record. AddProxy returns the newly added proxy and an error.
//
// The restoredProxyToken argument should only be used when restoring proxy
// definitions from disk; new proxies must leave it blank to get a new token
// assigned. We need to restore from disk to enable to continue authenticating
// running proxies that already had that credential injected.
func (l *State) AddProxy(proxy *structs.ConnectManagedProxy, token,
restoredProxyToken string) (*ManagedProxy, error) {
if proxy == nil {
return nil, fmt.Errorf("no proxy")
}
// Lookup the local service
target := l.Service(proxy.TargetServiceID)
if target == nil {
return nil, fmt.Errorf("target service ID %s not registered",
proxy.TargetServiceID)
}
// Get bind info from config
cfg, err := proxy.ParseConfig()
if err != nil {
return nil, err
}
// Construct almost all of the NodeService that needs to be registered by the
// caller outside of the lock.
svc := &structs.NodeService{
Kind: structs.ServiceKindConnectProxy,
ID: target.ID + "-proxy",
Service: target.ID + "-proxy",
ProxyDestination: target.Service,
Address: cfg.BindAddress,
Port: cfg.BindPort,
}
// Lock now. We can't lock earlier as l.Service would deadlock and shouldn't
// anyway to minimise the critical section.
l.Lock()
defer l.Unlock()
pToken := restoredProxyToken
// Does this proxy instance allready exist?
if existing, ok := l.managedProxies[svc.ID]; ok {
// Keep the existing proxy token so we don't have to restart proxy to
// re-inject token.
pToken = existing.ProxyToken
// If the user didn't explicitly change the port, use the old one instead of
// assigning new.
if svc.Port < 1 {
svc.Port = existing.Proxy.ProxyService.Port
}
} else if proxyService, ok := l.services[svc.ID]; ok {
// The proxy-service already exists so keep the port that got assigned. This
// happens on reload from disk since service definitions are reloaded first.
svc.Port = proxyService.Service.Port
}
// If this is a new instance, generate a token
if pToken == "" {
pToken, err = uuid.GenerateUUID()
if err != nil {
return nil, err
}
}
// Allocate port if needed (min and max inclusive).
rangeLen := l.config.ProxyBindMaxPort - l.config.ProxyBindMinPort + 1
if svc.Port < 1 && l.config.ProxyBindMinPort > 0 && rangeLen > 0 {
// This should be a really short list so don't bother optimising lookup yet.
OUTER:
for _, offset := range rand.Perm(rangeLen) {
p := l.config.ProxyBindMinPort + offset
// See if this port was already allocated to another proxy
for _, other := range l.managedProxies {
if other.Proxy.ProxyService.Port == p {
// allready taken, skip to next random pick in the range
continue OUTER
}
}
// We made it through all existing proxies without a match so claim this one
svc.Port = p
break
}
}
// If no ports left (or auto ports disabled) fail
if svc.Port < 1 {
return nil, fmt.Errorf("no port provided for proxy bind_port and none "+
" left in the allocated range [%d, %d]", l.config.ProxyBindMinPort,
l.config.ProxyBindMaxPort)
}
proxy.ProxyService = svc
// All set, add the proxy and return the service
if old, ok := l.managedProxies[svc.ID]; ok {
// Notify watchers of the existing proxy config that it's changing. Note
// this is safe here even before the map is updated since we still hold the
// state lock and the watcher can't re-read the new config until we return
// anyway.
close(old.WatchCh)
}
l.managedProxies[svc.ID] = &ManagedProxy{
Proxy: proxy,
ProxyToken: pToken,
WatchCh: make(chan struct{}),
}
// Notify
for ch := range l.managedProxyHandlers {
// Do not block
select {
case ch <- struct{}{}:
default:
}
}
// No need to trigger sync as proxy state is local only.
return l.managedProxies[svc.ID], nil
}
// RemoveProxy is used to remove a proxy entry from the local state.
// This returns the proxy that was removed.
func (l *State) RemoveProxy(id string) (*ManagedProxy, error) {
l.Lock()
defer l.Unlock()
p := l.managedProxies[id]
if p == nil {
return nil, fmt.Errorf("Proxy %s does not exist", id)
}
delete(l.managedProxies, id)
// Notify watchers of the existing proxy config that it's changed.
close(p.WatchCh)
// Notify
for ch := range l.managedProxyHandlers {
// Do not block
select {
case ch <- struct{}{}:
default:
}
}
// No need to trigger sync as proxy state is local only.
return p, nil
}
// Proxy returns the local proxy state.
func (l *State) Proxy(id string) *ManagedProxy {
l.RLock()
defer l.RUnlock()
return l.managedProxies[id]
}
// Proxies returns the locally registered proxies.
func (l *State) Proxies() map[string]*ManagedProxy {
l.RLock()
defer l.RUnlock()
m := make(map[string]*ManagedProxy)
for id, p := range l.managedProxies {
m[id] = p
}
return m
}
// NotifyProxy will register a channel to receive messages when the
// configuration or set of proxies changes. This will not block on
// channel send so ensure the channel has a buffer. Note that any buffer
// size is generally fine since actual data is not sent over the channel,
// so a dropped send due to a full buffer does not result in any loss of
// data. The fact that a buffer already contains a notification means that
// the receiver will still be notified that changes occurred.
//
// NOTE(mitchellh): This could be more generalized but for my use case I
// only needed proxy events. In the future if it were to be generalized I
// would add a new Notify method and remove the proxy-specific ones.
func (l *State) NotifyProxy(ch chan<- struct{}) {
l.Lock()
defer l.Unlock()
l.managedProxyHandlers[ch] = struct{}{}
}
// StopNotifyProxy will deregister a channel receiving proxy notifications.
// Pair this with all calls to NotifyProxy to clean up state.
func (l *State) StopNotifyProxy(ch chan<- struct{}) {
l.Lock()
defer l.Unlock()
delete(l.managedProxyHandlers, ch)
}
// Metadata returns the local node metadata fields that the
// agent is aware of and are being kept in sync with the server
func (l *State) Metadata() map[string]string {

View File

@ -3,10 +3,16 @@ package local_test
import (
"errors"
"fmt"
"log"
"os"
"reflect"
"testing"
"time"
"github.com/hashicorp/go-memdb"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/agent/config"
"github.com/hashicorp/consul/agent/local"
@ -16,6 +22,7 @@ import (
"github.com/hashicorp/consul/testutil/retry"
"github.com/hashicorp/consul/types"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/assert"
)
func TestAgentAntiEntropy_Services(t *testing.T) {
@ -224,6 +231,145 @@ func TestAgentAntiEntropy_Services(t *testing.T) {
}
}
func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
t.Parallel()
assert := assert.New(t)
a := &agent.TestAgent{Name: t.Name()}
a.Start()
defer a.Shutdown()
// Register node info
var out struct{}
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: a.Config.NodeName,
Address: "127.0.0.1",
}
// Exists both same (noop)
srv1 := &structs.NodeService{
Kind: structs.ServiceKindConnectProxy,
ID: "mysql-proxy",
Service: "mysql-proxy",
Port: 5000,
ProxyDestination: "db",
}
a.State.AddService(srv1, "")
args.Service = srv1
assert.Nil(a.RPC("Catalog.Register", args, &out))
// Exists both, different (update)
srv2 := &structs.NodeService{
ID: "redis-proxy",
Service: "redis-proxy",
Port: 8000,
Kind: structs.ServiceKindConnectProxy,
ProxyDestination: "redis",
}
a.State.AddService(srv2, "")
srv2_mod := new(structs.NodeService)
*srv2_mod = *srv2
srv2_mod.Port = 9000
args.Service = srv2_mod
assert.Nil(a.RPC("Catalog.Register", args, &out))
// Exists local (create)
srv3 := &structs.NodeService{
ID: "web-proxy",
Service: "web-proxy",
Port: 80,
Kind: structs.ServiceKindConnectProxy,
ProxyDestination: "web",
}
a.State.AddService(srv3, "")
// Exists remote (delete)
srv4 := &structs.NodeService{
ID: "lb-proxy",
Service: "lb-proxy",
Port: 443,
Kind: structs.ServiceKindConnectProxy,
ProxyDestination: "lb",
}
args.Service = srv4
assert.Nil(a.RPC("Catalog.Register", args, &out))
// Exists local, in sync, remote missing (create)
srv5 := &structs.NodeService{
ID: "cache-proxy",
Service: "cache-proxy",
Port: 11211,
Kind: structs.ServiceKindConnectProxy,
ProxyDestination: "cache-proxy",
}
a.State.SetServiceState(&local.ServiceState{
Service: srv5,
InSync: true,
})
assert.Nil(a.State.SyncFull())
var services structs.IndexedNodeServices
req := structs.NodeSpecificRequest{
Datacenter: "dc1",
Node: a.Config.NodeName,
}
assert.Nil(a.RPC("Catalog.NodeServices", &req, &services))
// We should have 5 services (consul included)
assert.Len(services.NodeServices.Services, 5)
// All the services should match
for id, serv := range services.NodeServices.Services {
serv.CreateIndex, serv.ModifyIndex = 0, 0
switch id {
case "mysql-proxy":
assert.Equal(srv1, serv)
case "redis-proxy":
assert.Equal(srv2, serv)
case "web-proxy":
assert.Equal(srv3, serv)
case "cache-proxy":
assert.Equal(srv5, serv)
case structs.ConsulServiceID:
// ignore
default:
t.Fatalf("unexpected service: %v", id)
}
}
assert.Nil(servicesInSync(a.State, 4))
// Remove one of the services
a.State.RemoveService("cache-proxy")
assert.Nil(a.State.SyncFull())
assert.Nil(a.RPC("Catalog.NodeServices", &req, &services))
// We should have 4 services (consul included)
assert.Len(services.NodeServices.Services, 4)
// All the services should match
for id, serv := range services.NodeServices.Services {
serv.CreateIndex, serv.ModifyIndex = 0, 0
switch id {
case "mysql-proxy":
assert.Equal(srv1, serv)
case "redis-proxy":
assert.Equal(srv2, serv)
case "web-proxy":
assert.Equal(srv3, serv)
case structs.ConsulServiceID:
// ignore
default:
t.Fatalf("unexpected service: %v", id)
}
}
assert.Nil(servicesInSync(a.State, 3))
}
func TestAgentAntiEntropy_EnableTagOverride(t *testing.T) {
t.Parallel()
a := &agent.TestAgent{Name: t.Name()}
@ -1524,3 +1670,263 @@ func checksInSync(state *local.State, wantChecks int) error {
}
return nil
}
func TestStateProxyManagement(t *testing.T) {
t.Parallel()
state := local.NewState(local.Config{
ProxyBindMinPort: 20000,
ProxyBindMaxPort: 20001,
}, log.New(os.Stderr, "", log.LstdFlags), &token.Store{})
// Stub state syncing
state.TriggerSyncChanges = func() {}
p1 := structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
TargetServiceID: "web",
}
require := require.New(t)
assert := assert.New(t)
_, err := state.AddProxy(&p1, "fake-token", "")
require.Error(err, "should fail as the target service isn't registered")
// Sanity check done, lets add a couple of target services to the state
err = state.AddService(&structs.NodeService{
Service: "web",
}, "fake-token-web")
require.NoError(err)
err = state.AddService(&structs.NodeService{
Service: "cache",
}, "fake-token-cache")
require.NoError(err)
require.NoError(err)
err = state.AddService(&structs.NodeService{
Service: "db",
}, "fake-token-db")
require.NoError(err)
// Should work now
pstate, err := state.AddProxy(&p1, "fake-token", "")
require.NoError(err)
svc := pstate.Proxy.ProxyService
assert.Equal("web-proxy", svc.ID)
assert.Equal("web-proxy", svc.Service)
assert.Equal(structs.ServiceKindConnectProxy, svc.Kind)
assert.Equal("web", svc.ProxyDestination)
assert.Equal("", svc.Address, "should have empty address by default")
// Port is non-deterministic but could be either of 20000 or 20001
assert.Contains([]int{20000, 20001}, svc.Port)
{
// Re-registering same proxy again should not pick a random port but re-use
// the assigned one. It should also keep the same proxy token since we don't
// want to force restart for config change.
pstateDup, err := state.AddProxy(&p1, "fake-token", "")
require.NoError(err)
svcDup := pstateDup.Proxy.ProxyService
assert.Equal("web-proxy", svcDup.ID)
assert.Equal("web-proxy", svcDup.Service)
assert.Equal(structs.ServiceKindConnectProxy, svcDup.Kind)
assert.Equal("web", svcDup.ProxyDestination)
assert.Equal("", svcDup.Address, "should have empty address by default")
// Port must be same as before
assert.Equal(svc.Port, svcDup.Port)
// Same ProxyToken
assert.Equal(pstate.ProxyToken, pstateDup.ProxyToken)
}
// Let's register a notifier now
notifyCh := make(chan struct{}, 1)
state.NotifyProxy(notifyCh)
defer state.StopNotifyProxy(notifyCh)
assert.Empty(notifyCh)
drainCh(notifyCh)
// Second proxy should claim other port
p2 := p1
p2.TargetServiceID = "cache"
pstate2, err := state.AddProxy(&p2, "fake-token", "")
require.NoError(err)
svc2 := pstate2.Proxy.ProxyService
assert.Contains([]int{20000, 20001}, svc2.Port)
assert.NotEqual(svc.Port, svc2.Port)
// Should have a notification
assert.NotEmpty(notifyCh)
drainCh(notifyCh)
// Store this for later
p2token := state.Proxy(svc2.ID).ProxyToken
// Third proxy should fail as all ports are used
p3 := p1
p3.TargetServiceID = "db"
_, err = state.AddProxy(&p3, "fake-token", "")
require.Error(err)
// Should have a notification but we'll do nothing so that the next
// receive should block (we set cap == 1 above)
// But if we set a port explicitly it should be OK
p3.Config = map[string]interface{}{
"bind_port": 1234,
"bind_address": "0.0.0.0",
}
pstate3, err := state.AddProxy(&p3, "fake-token", "")
require.NoError(err)
svc3 := pstate3.Proxy.ProxyService
require.Equal("0.0.0.0", svc3.Address)
require.Equal(1234, svc3.Port)
// Should have a notification
assert.NotEmpty(notifyCh)
drainCh(notifyCh)
// Update config of an already registered proxy should work
p3updated := p3
p3updated.Config["foo"] = "bar"
// Setup multiple watchers who should all witness the change
gotP3 := state.Proxy(svc3.ID)
require.NotNil(gotP3)
var ws memdb.WatchSet
ws.Add(gotP3.WatchCh)
pstate3, err = state.AddProxy(&p3updated, "fake-token", "")
require.NoError(err)
svc3 = pstate3.Proxy.ProxyService
require.Equal("0.0.0.0", svc3.Address)
require.Equal(1234, svc3.Port)
gotProxy3 := state.Proxy(svc3.ID)
require.NotNil(gotProxy3)
require.Equal(p3updated.Config, gotProxy3.Proxy.Config)
assert.False(ws.Watch(time.After(500*time.Millisecond)),
"watch should have fired so ws.Watch should not timeout")
drainCh(notifyCh)
// Remove one of the auto-assigned proxies
_, err = state.RemoveProxy(svc2.ID)
require.NoError(err)
// Should have a notification
assert.NotEmpty(notifyCh)
drainCh(notifyCh)
// Should be able to create a new proxy for that service with the port (it
// should have been "freed").
p4 := p2
pstate4, err := state.AddProxy(&p4, "fake-token", "")
require.NoError(err)
svc4 := pstate4.Proxy.ProxyService
assert.Contains([]int{20000, 20001}, svc2.Port)
assert.Equal(svc4.Port, svc2.Port, "should get the same port back that we freed")
// Remove a proxy that doesn't exist should error
_, err = state.RemoveProxy("nope")
require.Error(err)
assert.Equal(&p4, state.Proxy(p4.ProxyService.ID).Proxy,
"should fetch the right proxy details")
assert.Nil(state.Proxy("nope"))
proxies := state.Proxies()
assert.Len(proxies, 3)
assert.Equal(&p1, proxies[svc.ID].Proxy)
assert.Equal(&p4, proxies[svc4.ID].Proxy)
assert.Equal(&p3, proxies[svc3.ID].Proxy)
tokens := make([]string, 4)
tokens[0] = state.Proxy(svc.ID).ProxyToken
// p2 not registered anymore but lets make sure p4 got a new token when it
// re-registered with same ID.
tokens[1] = p2token
tokens[2] = state.Proxy(svc2.ID).ProxyToken
tokens[3] = state.Proxy(svc3.ID).ProxyToken
// Quick check all are distinct
for i := 0; i < len(tokens)-1; i++ {
assert.Len(tokens[i], 36) // Sanity check for UUIDish thing.
for j := i + 1; j < len(tokens); j++ {
assert.NotEqual(tokens[i], tokens[j], "tokens for proxy %d and %d match",
i+1, j+1)
}
}
}
// Tests the logic for retaining tokens and ports through restore (i.e.
// proxy-service already restored and token passed in externally)
func TestStateProxyRestore(t *testing.T) {
t.Parallel()
state := local.NewState(local.Config{
// Wide random range to make it very unlikely to pass by chance
ProxyBindMinPort: 10000,
ProxyBindMaxPort: 20000,
}, log.New(os.Stderr, "", log.LstdFlags), &token.Store{})
// Stub state syncing
state.TriggerSyncChanges = func() {}
webSvc := structs.NodeService{
Service: "web",
}
p1 := structs.ConnectManagedProxy{
ExecMode: structs.ProxyExecModeDaemon,
Command: []string{"consul", "connect", "proxy"},
TargetServiceID: "web",
}
p2 := p1
require := require.New(t)
assert := assert.New(t)
// Add a target service
require.NoError(state.AddService(&webSvc, "fake-token-web"))
// Add the proxy for first time to get the proper service definition to
// register
pstate, err := state.AddProxy(&p1, "fake-token", "")
require.NoError(err)
// Now start again with a brand new state
state2 := local.NewState(local.Config{
// Wide random range to make it very unlikely to pass by chance
ProxyBindMinPort: 10000,
ProxyBindMaxPort: 20000,
}, log.New(os.Stderr, "", log.LstdFlags), &token.Store{})
// Stub state syncing
state2.TriggerSyncChanges = func() {}
// Register the target service
require.NoError(state2.AddService(&webSvc, "fake-token-web"))
// "Restore" the proxy service
require.NoError(state.AddService(p1.ProxyService, "fake-token-web"))
// Now we can AddProxy with the "restored" token
pstate2, err := state.AddProxy(&p2, "fake-token", pstate.ProxyToken)
require.NoError(err)
// Check it still has the same port and token as before
assert.Equal(pstate.ProxyToken, pstate2.ProxyToken)
assert.Equal(p1.ProxyService.Port, p2.ProxyService.Port)
}
// drainCh drains a channel by reading messages until it would block.
func drainCh(ch chan struct{}) {
for {
select {
case <-ch:
default:
return
}
}
}

19
agent/local/testing.go Normal file
View File

@ -0,0 +1,19 @@
package local
import (
"log"
"os"
"github.com/hashicorp/consul/agent/token"
"github.com/mitchellh/go-testing-interface"
)
// TestState returns a configured *State for testing.
func TestState(t testing.T) *State {
result := NewState(Config{
ProxyBindMinPort: 20000,
ProxyBindMaxPort: 20500,
}, log.New(os.Stderr, "", log.LstdFlags), &token.Store{})
result.TriggerSyncChanges = func() {}
return result
}

Some files were not shown because too many files have changed in this diff Show More