Merge remote-tracking branch 'connect/f-connect'
This commit is contained in:
commit
1da3c42867
|
@ -2,12 +2,15 @@
|
|||
|
||||
FEATURES:
|
||||
|
||||
* **Connect Feature Beta**: This version includes a major new feature for Consul named Connect. Connect enables secure service-to-service communication with automatic TLS encryption and identity-based authorization. For more details and links to demos and getting started guides, see the [announcement blog post](https://www.hashicorp.com/blog/consul-1-2-service-mesh).
|
||||
* Connect must be enabled explicitly in configuration so upgrading a cluster will not affect any existing functionality until it's enabled.
|
||||
* This is a Beta release, we don't recommend enabling this in production yet. Please see the documentation for more information.
|
||||
* dns: Enable PTR record lookups for services with IPs that have no registered node [[PR-4083](https://github.com/hashicorp/consul/pull/4083)]
|
||||
* ui: Default to serving the new UI. Setting the `CONSUL_UI_LEGACY` environment variable to `1` of `true` will revert to serving the old UI
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
* agent: A Consul user-agent string is now sent to providers when making retry-join requests [GH-4013](https://github.com/hashicorp/consul/pull/4013)
|
||||
* agent: A Consul user-agent string is now sent to providers when making retry-join requests [[GH-4013](https://github.com/hashicorp/consul/issues/4013)](https://github.com/hashicorp/consul/pull/4013)
|
||||
* client: Add metrics for failed RPCs [PR-4220](https://github.com/hashicorp/consul/pull/4220)
|
||||
* agent: Add configuration entry to control including TXT records for node meta in DNS responses [PR-4215](https://github.com/hashicorp/consul/pull/4215)
|
||||
* client: Make RPC rate limit configuration reloadable [GH-4012](https://github.com/hashicorp/consul/issues/4012)
|
||||
|
|
14
GNUmakefile
14
GNUmakefile
|
@ -145,10 +145,16 @@ test: other-consul dev-build vet
|
|||
@# _something_ to stop them terminating us due to inactivity...
|
||||
{ go test $(GOTEST_FLAGS) -tags '$(GOTAGS)' -timeout 5m $(GOTEST_PKGS) 2>&1 ; echo $$? > exit-code ; } | tee test.log | egrep '^(ok|FAIL)\s*github.com/hashicorp/consul'
|
||||
@echo "Exit code: $$(cat exit-code)" >> test.log
|
||||
@grep -A5 'DATA RACE' test.log || true
|
||||
@grep -A10 'panic: test timed out' test.log || true
|
||||
@grep -A1 -- '--- SKIP:' test.log || true
|
||||
@grep -A1 -- '--- FAIL:' test.log || true
|
||||
@# This prints all the race report between ====== lines
|
||||
@awk '/^WARNING: DATA RACE/ {do_print=1; print "=================="} do_print==1 {print} /^={10,}/ {do_print=0}' test.log || true
|
||||
@grep -A10 'panic: ' test.log || true
|
||||
@# Prints all the failure output until the next non-indented line - testify
|
||||
@# helpers often output multiple lines for readability but useless if we can't
|
||||
@# see them. Un-intuitive order of matches is necessary. No || true because
|
||||
@# awk always returns true even if there is no match and it breaks non-bash
|
||||
@# shells locally.
|
||||
@awk '/^[^[:space:]]/ {do_print=0} /--- SKIP/ {do_print=1} do_print==1 {print}' test.log
|
||||
@awk '/^[^[:space:]]/ {do_print=0} /--- FAIL/ {do_print=1} do_print==1 {print}' test.log
|
||||
@grep '^FAIL' test.log || true
|
||||
@if [ "$$(cat exit-code)" == "0" ] ; then echo "PASS" ; exit 0 ; else exit 1 ; fi
|
||||
|
||||
|
|
105
README.md
105
README.md
|
@ -1,75 +1,66 @@
|
|||
# Consul [![Build Status](https://travis-ci.org/hashicorp/consul.svg?branch=master)](https://travis-ci.org/hashicorp/consul) [![Join the chat at https://gitter.im/hashicorp-consul/Lobby](https://badges.gitter.im/hashicorp-consul/Lobby.svg)](https://gitter.im/hashicorp-consul/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
**This is a temporary README. We'll restore the old README prior to PR upstream.**
|
||||
|
||||
* Website: https://www.consul.io
|
||||
* Chat: [Gitter](https://gitter.im/hashicorp-consul/Lobby)
|
||||
* Mailing list: [Google Groups](https://groups.google.com/group/consul-tool/)
|
||||
# Consul Connect
|
||||
|
||||
Consul is a tool for service discovery and configuration. Consul is
|
||||
distributed, highly available, and extremely scalable.
|
||||
This repository is the forked repository for Consul Connect work to happen
|
||||
in private prior to public release. This README will explain how to safely
|
||||
use this fork, how to bring in upstream changes, etc.
|
||||
|
||||
Consul provides several key features:
|
||||
## Cloning
|
||||
|
||||
* **Service Discovery** - Consul makes it simple for services to register
|
||||
themselves and to discover other services via a DNS or HTTP interface.
|
||||
External services such as SaaS providers can be registered as well.
|
||||
To use this repository, clone it into your GOPATH as usual but you must
|
||||
**rename `consul-connect` to `consul`** so that Go imports continue working
|
||||
as usual.
|
||||
|
||||
* **Health Checking** - Health Checking enables Consul to quickly alert
|
||||
operators about any issues in a cluster. The integration with service
|
||||
discovery prevents routing traffic to unhealthy hosts and enables service
|
||||
level circuit breakers.
|
||||
## Important: Never Modify Master
|
||||
|
||||
* **Key/Value Storage** - A flexible key/value store enables storing
|
||||
dynamic configuration, feature flagging, coordination, leader election and
|
||||
more. The simple HTTP API makes it easy to use anywhere.
|
||||
**NEVER MODIFY MASTER! NEVER MODIFY MASTER!**
|
||||
|
||||
* **Multi-Datacenter** - Consul is built to be datacenter aware, and can
|
||||
support any number of regions without complex configuration.
|
||||
We want to keep the "master" branch equivalent to OSS master. This will make
|
||||
rebasing easy for master. Instead, we'll use the branch `f-connect`. All
|
||||
feature branches should branch from `f-connect` and make PRs against
|
||||
`f-connect`.
|
||||
|
||||
Consul runs on Linux, Mac OS X, FreeBSD, Solaris, and Windows. A commercial
|
||||
version called [Consul Enterprise](https://www.hashicorp.com/products/consul)
|
||||
is also available.
|
||||
When we're ready to merge back to upstream, we can make a single mega PR
|
||||
merging `f-connect` into OSS master. This way we don't have a sudden mega
|
||||
push to master on OSS.
|
||||
|
||||
## Quick Start
|
||||
## Creating a Feature Branch
|
||||
|
||||
An extensive quick start is viewable on the Consul website:
|
||||
To create a feature branch, branch from `f-connect`:
|
||||
|
||||
https://www.consul.io/intro/getting-started/install.html
|
||||
|
||||
## Documentation
|
||||
|
||||
Full, comprehensive documentation is viewable on the Consul website:
|
||||
|
||||
https://www.consul.io/docs
|
||||
|
||||
## Developing Consul
|
||||
|
||||
If you wish to work on Consul itself, you'll first need [Go](https://golang.org)
|
||||
installed (version 1.9+ is _required_). Make sure you have Go properly installed,
|
||||
including setting up your [GOPATH](https://golang.org/doc/code.html#GOPATH).
|
||||
|
||||
Next, clone this repository into `$GOPATH/src/github.com/hashicorp/consul` and
|
||||
then just type `make`. In a few moments, you'll have a working `consul` executable:
|
||||
|
||||
```
|
||||
$ make
|
||||
...
|
||||
$ bin/consul
|
||||
...
|
||||
```sh
|
||||
git checkout f-connect
|
||||
git checkout -b my-new-branch
|
||||
```
|
||||
|
||||
*Note: `make` will build all os/architecture combinations. Set the environment variable `CONSUL_DEV=1` to build it just for your local machine's os/architecture, or use `make dev`.*
|
||||
All merged Connect features will be in `f-connect`, so you want to work
|
||||
from that branch. When making a PR for your feature branch, target the
|
||||
`f-connect` branch as the merge target. You can do this by using the dropdowns
|
||||
in the GitHub UI when creating a PR.
|
||||
|
||||
*Note: `make` will also place a copy of the binary in the first part of your `$GOPATH`.*
|
||||
## Syncing Upstream
|
||||
|
||||
You can run tests by typing `make test`. The test suite may fail if
|
||||
over-parallelized, so if you are seeing stochastic failures try
|
||||
`GOTEST_FLAGS="-p 2 -parallel 2" make test`.
|
||||
First update our local master:
|
||||
|
||||
If you make any changes to the code, run `make format` in order to automatically
|
||||
format the code according to Go standards.
|
||||
```sh
|
||||
# This has to happen on forked master
|
||||
git checkout master
|
||||
|
||||
## Vendoring
|
||||
# Add upstream to OSS Consul
|
||||
git remote add upstream https://github.com/hashicorp/consul.git
|
||||
|
||||
Consul currently uses [govendor](https://github.com/kardianos/govendor) for
|
||||
vendoring and [vendorfmt](https://github.com/magiconair/vendorfmt) for formatting
|
||||
`vendor.json` to a more merge-friendly "one line per package" format.
|
||||
# Fetch it
|
||||
git fetch upstream
|
||||
|
||||
# Rebase forked master onto upstream. This should have no changes since
|
||||
# we're never modifying master.
|
||||
git rebase upstream master
|
||||
```
|
||||
|
||||
Next, update the `f-connect` branch:
|
||||
|
||||
```sh
|
||||
git checkout f-connect
|
||||
git rebase origin master
|
||||
```
|
||||
|
|
91
acl/acl.go
91
acl/acl.go
|
@ -60,6 +60,17 @@ type ACL interface {
|
|||
// EventWrite determines if a specific event may be fired.
|
||||
EventWrite(string) bool
|
||||
|
||||
// IntentionDefaultAllow determines the default authorized behavior
|
||||
// when no intentions match a Connect request.
|
||||
IntentionDefaultAllow() bool
|
||||
|
||||
// IntentionRead determines if a specific intention can be read.
|
||||
IntentionRead(string) bool
|
||||
|
||||
// IntentionWrite determines if a specific intention can be
|
||||
// created, modified, or deleted.
|
||||
IntentionWrite(string) bool
|
||||
|
||||
// KeyList checks for permission to list keys under a prefix
|
||||
KeyList(string) bool
|
||||
|
||||
|
@ -154,6 +165,18 @@ func (s *StaticACL) EventWrite(string) bool {
|
|||
return s.defaultAllow
|
||||
}
|
||||
|
||||
func (s *StaticACL) IntentionDefaultAllow() bool {
|
||||
return s.defaultAllow
|
||||
}
|
||||
|
||||
func (s *StaticACL) IntentionRead(string) bool {
|
||||
return s.defaultAllow
|
||||
}
|
||||
|
||||
func (s *StaticACL) IntentionWrite(string) bool {
|
||||
return s.defaultAllow
|
||||
}
|
||||
|
||||
func (s *StaticACL) KeyRead(string) bool {
|
||||
return s.defaultAllow
|
||||
}
|
||||
|
@ -275,6 +298,9 @@ type PolicyACL struct {
|
|||
// agentRules contains the agent policies
|
||||
agentRules *radix.Tree
|
||||
|
||||
// intentionRules contains the service intention policies
|
||||
intentionRules *radix.Tree
|
||||
|
||||
// keyRules contains the key policies
|
||||
keyRules *radix.Tree
|
||||
|
||||
|
@ -308,6 +334,7 @@ func New(parent ACL, policy *Policy, sentinel sentinel.Evaluator) (*PolicyACL, e
|
|||
p := &PolicyACL{
|
||||
parent: parent,
|
||||
agentRules: radix.New(),
|
||||
intentionRules: radix.New(),
|
||||
keyRules: radix.New(),
|
||||
nodeRules: radix.New(),
|
||||
serviceRules: radix.New(),
|
||||
|
@ -347,6 +374,25 @@ func New(parent ACL, policy *Policy, sentinel sentinel.Evaluator) (*PolicyACL, e
|
|||
sentinelPolicy: sp.Sentinel,
|
||||
}
|
||||
p.serviceRules.Insert(sp.Name, policyRule)
|
||||
|
||||
// Determine the intention. The intention could be blank (not set).
|
||||
// If the intention is not set, the value depends on the value of
|
||||
// the service policy.
|
||||
intention := sp.Intentions
|
||||
if intention == "" {
|
||||
switch sp.Policy {
|
||||
case PolicyRead, PolicyWrite:
|
||||
intention = PolicyRead
|
||||
default:
|
||||
intention = PolicyDeny
|
||||
}
|
||||
}
|
||||
|
||||
policyRule = PolicyRule{
|
||||
aclPolicy: intention,
|
||||
sentinelPolicy: sp.Sentinel,
|
||||
}
|
||||
p.intentionRules.Insert(sp.Name, policyRule)
|
||||
}
|
||||
|
||||
// Load the session policy
|
||||
|
@ -455,6 +501,51 @@ func (p *PolicyACL) EventWrite(name string) bool {
|
|||
return p.parent.EventWrite(name)
|
||||
}
|
||||
|
||||
// IntentionDefaultAllow returns whether the default behavior when there are
|
||||
// no matching intentions is to allow or deny.
|
||||
func (p *PolicyACL) IntentionDefaultAllow() bool {
|
||||
// We always go up, this can't be determined by a policy.
|
||||
return p.parent.IntentionDefaultAllow()
|
||||
}
|
||||
|
||||
// IntentionRead checks if writing (creating, updating, or deleting) of an
|
||||
// intention is allowed.
|
||||
func (p *PolicyACL) IntentionRead(prefix string) bool {
|
||||
// Check for an exact rule or catch-all
|
||||
_, rule, ok := p.intentionRules.LongestPrefix(prefix)
|
||||
if ok {
|
||||
pr := rule.(PolicyRule)
|
||||
switch pr.aclPolicy {
|
||||
case PolicyRead, PolicyWrite:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// No matching rule, use the parent.
|
||||
return p.parent.IntentionRead(prefix)
|
||||
}
|
||||
|
||||
// IntentionWrite checks if writing (creating, updating, or deleting) of an
|
||||
// intention is allowed.
|
||||
func (p *PolicyACL) IntentionWrite(prefix string) bool {
|
||||
// Check for an exact rule or catch-all
|
||||
_, rule, ok := p.intentionRules.LongestPrefix(prefix)
|
||||
if ok {
|
||||
pr := rule.(PolicyRule)
|
||||
switch pr.aclPolicy {
|
||||
case PolicyWrite:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// No matching rule, use the parent.
|
||||
return p.parent.IntentionWrite(prefix)
|
||||
}
|
||||
|
||||
// KeyRead returns if a key is allowed to be read
|
||||
func (p *PolicyACL) KeyRead(key string) bool {
|
||||
// Look for a matching rule
|
||||
|
|
|
@ -53,6 +53,12 @@ func TestStaticACL(t *testing.T) {
|
|||
if !all.EventWrite("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !all.IntentionDefaultAllow() {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !all.IntentionWrite("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !all.KeyRead("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
|
@ -123,6 +129,12 @@ func TestStaticACL(t *testing.T) {
|
|||
if none.EventWrite("") {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
if none.IntentionDefaultAllow() {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
if none.IntentionWrite("foo") {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
if none.KeyRead("foobar") {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
|
@ -187,6 +199,12 @@ func TestStaticACL(t *testing.T) {
|
|||
if !manage.EventWrite("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !manage.IntentionDefaultAllow() {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !manage.IntentionWrite("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
if !manage.KeyRead("foobar") {
|
||||
t.Fatalf("should allow")
|
||||
}
|
||||
|
@ -307,6 +325,12 @@ func TestPolicyACL(t *testing.T) {
|
|||
&ServicePolicy{
|
||||
Name: "barfoo",
|
||||
Policy: PolicyWrite,
|
||||
Intentions: PolicyWrite,
|
||||
},
|
||||
&ServicePolicy{
|
||||
Name: "intbaz",
|
||||
Policy: PolicyWrite,
|
||||
Intentions: PolicyDeny,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -344,6 +368,31 @@ func TestPolicyACL(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test the intentions
|
||||
type intentioncase struct {
|
||||
inp string
|
||||
read bool
|
||||
write bool
|
||||
}
|
||||
icases := []intentioncase{
|
||||
{"other", true, false},
|
||||
{"foo", true, false},
|
||||
{"bar", false, false},
|
||||
{"foobar", true, false},
|
||||
{"barfo", false, false},
|
||||
{"barfoo", true, true},
|
||||
{"barfoo2", true, true},
|
||||
{"intbaz", false, false},
|
||||
}
|
||||
for _, c := range icases {
|
||||
if c.read != acl.IntentionRead(c.inp) {
|
||||
t.Fatalf("Read fail: %#v", c)
|
||||
}
|
||||
if c.write != acl.IntentionWrite(c.inp) {
|
||||
t.Fatalf("Write fail: %#v", c)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the services
|
||||
type servicecase struct {
|
||||
inp string
|
||||
|
@ -414,6 +463,11 @@ func TestPolicyACL(t *testing.T) {
|
|||
t.Fatalf("Prepared query fail: %#v", c)
|
||||
}
|
||||
}
|
||||
|
||||
// Check default intentions bubble up
|
||||
if !acl.IntentionDefaultAllow() {
|
||||
t.Fatal("should allow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyACL_Parent(t *testing.T) {
|
||||
|
@ -567,6 +621,11 @@ func TestPolicyACL_Parent(t *testing.T) {
|
|||
if acl.Snapshot() {
|
||||
t.Fatalf("should not allow")
|
||||
}
|
||||
|
||||
// Check default intentions
|
||||
if acl.IntentionDefaultAllow() {
|
||||
t.Fatal("should not allow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyACL_Agent(t *testing.T) {
|
||||
|
|
|
@ -73,6 +73,11 @@ type ServicePolicy struct {
|
|||
Name string `hcl:",key"`
|
||||
Policy string
|
||||
Sentinel Sentinel
|
||||
|
||||
// Intentions is the policy for intentions where this service is the
|
||||
// destination. This may be empty, in which case the Policy determines
|
||||
// the intentions policy.
|
||||
Intentions string
|
||||
}
|
||||
|
||||
func (s *ServicePolicy) GoString() string {
|
||||
|
@ -197,6 +202,9 @@ func Parse(rules string, sentinel sentinel.Evaluator) (*Policy, error) {
|
|||
if !isPolicyValid(sp.Policy) {
|
||||
return nil, fmt.Errorf("Invalid service policy: %#v", sp)
|
||||
}
|
||||
if sp.Intentions != "" && !isPolicyValid(sp.Intentions) {
|
||||
return nil, fmt.Errorf("Invalid service intentions policy: %#v", sp)
|
||||
}
|
||||
if err := isSentinelValid(sentinel, sp.Policy, sp.Sentinel); err != nil {
|
||||
return nil, fmt.Errorf("Invalid service Sentinel policy: %#v, got error:%v", sp, err)
|
||||
}
|
||||
|
|
|
@ -4,8 +4,85 @@ import (
|
|||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParse_table(t *testing.T) {
|
||||
// Note that the table tests are newer than other tests. Many of the
|
||||
// other aspects of policy parsing are tested in older tests below. New
|
||||
// parsing tests should be added to this table as its easier to maintain.
|
||||
cases := []struct {
|
||||
Name string
|
||||
Input string
|
||||
Expected *Policy
|
||||
Err string
|
||||
}{
|
||||
{
|
||||
"service no intentions",
|
||||
`
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
}
|
||||
`,
|
||||
&Policy{
|
||||
Services: []*ServicePolicy{
|
||||
{
|
||||
Name: "foo",
|
||||
Policy: "write",
|
||||
},
|
||||
},
|
||||
},
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"service intentions",
|
||||
`
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
intentions = "read"
|
||||
}
|
||||
`,
|
||||
&Policy{
|
||||
Services: []*ServicePolicy{
|
||||
{
|
||||
Name: "foo",
|
||||
Policy: "write",
|
||||
Intentions: "read",
|
||||
},
|
||||
},
|
||||
},
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"service intention: invalid value",
|
||||
`
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
intentions = "foo"
|
||||
}
|
||||
`,
|
||||
nil,
|
||||
"service intentions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
actual, err := Parse(tc.Input, nil)
|
||||
assert.Equal(tc.Err != "", err != nil, err)
|
||||
if err != nil {
|
||||
assert.Contains(err.Error(), tc.Err)
|
||||
return
|
||||
}
|
||||
assert.Equal(tc.Expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestACLPolicy_Parse_HCL(t *testing.T) {
|
||||
inp := `
|
||||
agent "foo" {
|
||||
|
|
13
agent/acl.go
13
agent/acl.go
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/local"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/golang-lru"
|
||||
|
@ -239,6 +240,18 @@ func (a *Agent) resolveToken(id string) (acl.ACL, error) {
|
|||
return a.acls.lookupACL(a, id)
|
||||
}
|
||||
|
||||
// resolveProxyToken attempts to resolve an ACL ID to a local proxy token.
|
||||
// If a local proxy isn't found with that token, nil is returned.
|
||||
func (a *Agent) resolveProxyToken(id string) *local.ManagedProxy {
|
||||
for _, p := range a.State.Proxies() {
|
||||
if p.ProxyToken == id {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// vetServiceRegister makes sure the service registration action is allowed by
|
||||
// the given token.
|
||||
func (a *Agent) vetServiceRegister(token string, service *structs.NodeService) error {
|
||||
|
|
620
agent/agent.go
620
agent/agent.go
|
@ -21,16 +21,20 @@ import (
|
|||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/ae"
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/cache-types"
|
||||
"github.com/hashicorp/consul/agent/checks"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/consul"
|
||||
"github.com/hashicorp/consul/agent/local"
|
||||
"github.com/hashicorp/consul/agent/proxy"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/agent/systemd"
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/ipaddr"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/lib/file"
|
||||
"github.com/hashicorp/consul/logger"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/consul/watch"
|
||||
|
@ -46,6 +50,9 @@ const (
|
|||
// Path to save agent service definitions
|
||||
servicesDir = "services"
|
||||
|
||||
// Path to save agent proxy definitions
|
||||
proxyDir = "proxies"
|
||||
|
||||
// Path to save local agent checks
|
||||
checksDir = "checks"
|
||||
checkStateDir = "checks/state"
|
||||
|
@ -119,6 +126,9 @@ type Agent struct {
|
|||
// and the remote state.
|
||||
sync *ae.StateSyncer
|
||||
|
||||
// cache is the in-memory cache for data the Agent requests.
|
||||
cache *cache.Cache
|
||||
|
||||
// checkReapAfter maps the check ID to a timeout after which we should
|
||||
// reap its associated service
|
||||
checkReapAfter map[types.CheckID]time.Duration
|
||||
|
@ -195,6 +205,9 @@ type Agent struct {
|
|||
// be updated at runtime, so should always be used instead of going to
|
||||
// the configuration directly.
|
||||
tokens *token.Store
|
||||
|
||||
// proxyManager is the proxy process manager for managed Connect proxies.
|
||||
proxyManager *proxy.Manager
|
||||
}
|
||||
|
||||
func New(c *config.RuntimeConfig) (*Agent, error) {
|
||||
|
@ -247,6 +260,8 @@ func LocalConfig(cfg *config.RuntimeConfig) local.Config {
|
|||
NodeID: cfg.NodeID,
|
||||
NodeName: cfg.NodeName,
|
||||
TaggedAddresses: map[string]string{},
|
||||
ProxyBindMinPort: cfg.ConnectProxyBindMinPort,
|
||||
ProxyBindMaxPort: cfg.ConnectProxyBindMaxPort,
|
||||
}
|
||||
for k, v := range cfg.TaggedAddresses {
|
||||
lc.TaggedAddresses[k] = v
|
||||
|
@ -289,6 +304,9 @@ func (a *Agent) Start() error {
|
|||
// regular and on-demand state synchronizations (anti-entropy).
|
||||
a.sync = ae.NewStateSyncer(a.State, c.AEInterval, a.shutdownCh, a.logger)
|
||||
|
||||
// create the cache
|
||||
a.cache = cache.New(nil)
|
||||
|
||||
// create the config for the rpc server/client
|
||||
consulCfg, err := a.consulConfig()
|
||||
if err != nil {
|
||||
|
@ -325,10 +343,17 @@ func (a *Agent) Start() error {
|
|||
a.State.Delegate = a.delegate
|
||||
a.State.TriggerSyncChanges = a.sync.SyncChanges.Trigger
|
||||
|
||||
// Register the cache. We do this much later so the delegate is
|
||||
// populated from above.
|
||||
a.registerCache()
|
||||
|
||||
// Load checks/services/metadata.
|
||||
if err := a.loadServices(c); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.loadProxies(c); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.loadChecks(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -336,6 +361,28 @@ func (a *Agent) Start() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// create the proxy process manager and start it. This is purposely
|
||||
// done here after the local state above is loaded in so we can have
|
||||
// a more accurate initial state view.
|
||||
if !c.ConnectTestDisableManagedProxies {
|
||||
a.proxyManager = proxy.NewManager()
|
||||
a.proxyManager.AllowRoot = a.config.ConnectProxyAllowManagedRoot
|
||||
a.proxyManager.State = a.State
|
||||
a.proxyManager.Logger = a.logger
|
||||
if a.config.DataDir != "" {
|
||||
// DataDir is required for all non-dev mode agents, but we want
|
||||
// to allow setting the data dir for demos and so on for the agent,
|
||||
// so do the check above instead.
|
||||
a.proxyManager.DataDir = filepath.Join(a.config.DataDir, "proxy")
|
||||
|
||||
// Restore from our snapshot (if it exists)
|
||||
if err := a.proxyManager.Restore(a.proxyManager.SnapshotPath()); err != nil {
|
||||
a.logger.Printf("[WARN] agent: error restoring proxy state: %s", err)
|
||||
}
|
||||
}
|
||||
go a.proxyManager.Run()
|
||||
}
|
||||
|
||||
// Start watching for critical services to deregister, based on their
|
||||
// checks.
|
||||
go a.reapServices()
|
||||
|
@ -605,6 +652,16 @@ func (a *Agent) reloadWatches(cfg *config.RuntimeConfig) error {
|
|||
return fmt.Errorf("Handler type '%s' not recognized", params["handler_type"])
|
||||
}
|
||||
|
||||
// Don't let people use connect watches via this mechanism for now as it
|
||||
// needs thought about how to do securely and shouldn't be necessary. Note
|
||||
// that if the type assertion fails an type is not a string then
|
||||
// ParseExample below will error so we don't need to handle that case.
|
||||
if typ, ok := params["type"].(string); ok {
|
||||
if strings.HasPrefix(typ, "connect_") {
|
||||
return fmt.Errorf("Watch type %s is not allowed in agent config", typ)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the watches, excluding 'handler' and 'args'
|
||||
wp, err := watch.ParseExempt(params, []string{"handler", "args"})
|
||||
if err != nil {
|
||||
|
@ -878,6 +935,47 @@ func (a *Agent) consulConfig() (*consul.Config, error) {
|
|||
base.TLSCipherSuites = a.config.TLSCipherSuites
|
||||
base.TLSPreferServerCipherSuites = a.config.TLSPreferServerCipherSuites
|
||||
|
||||
// Copy the Connect CA bootstrap config
|
||||
if a.config.ConnectEnabled {
|
||||
base.ConnectEnabled = true
|
||||
|
||||
// Allow config to specify cluster_id provided it's a valid UUID. This is
|
||||
// meant only for tests where a deterministic ID makes fixtures much simpler
|
||||
// to work with but since it's only read on initial cluster bootstrap it's not
|
||||
// that much of a liability in production. The worst a user could do is
|
||||
// configure logically separate clusters with same ID by mistake but we can
|
||||
// avoid documenting this is even an option.
|
||||
if clusterID, ok := a.config.ConnectCAConfig["cluster_id"]; ok {
|
||||
if cIDStr, ok := clusterID.(string); ok {
|
||||
if _, err := uuid.ParseUUID(cIDStr); err == nil {
|
||||
// Valid UUID configured, use that
|
||||
base.CAConfig.ClusterID = cIDStr
|
||||
}
|
||||
}
|
||||
if base.CAConfig.ClusterID == "" {
|
||||
// If the tried to specify an ID but typoed it don't ignore as they will
|
||||
// then bootstrap with a new ID and have to throw away the whole cluster
|
||||
// and start again.
|
||||
a.logger.Println("[ERR] connect CA config cluster_id specified but " +
|
||||
"is not a valid UUID, aborting startup")
|
||||
return nil, fmt.Errorf("cluster_id was supplied but was not a valid UUID")
|
||||
}
|
||||
}
|
||||
|
||||
if a.config.ConnectCAProvider != "" {
|
||||
base.CAConfig.Provider = a.config.ConnectCAProvider
|
||||
|
||||
// Merge with the default config if it's the consul provider.
|
||||
if a.config.ConnectCAProvider == "consul" {
|
||||
for k, v := range a.config.ConnectCAConfig {
|
||||
base.CAConfig.Config[k] = v
|
||||
}
|
||||
} else {
|
||||
base.CAConfig.Config = a.config.ConnectCAConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup the user event callback
|
||||
base.UserEventHandler = func(e serf.UserEvent) {
|
||||
select {
|
||||
|
@ -1010,7 +1108,7 @@ func (a *Agent) setupNodeID(config *config.RuntimeConfig) error {
|
|||
}
|
||||
|
||||
// For dev mode we have no filesystem access so just make one.
|
||||
if a.config.DevMode {
|
||||
if a.config.DataDir == "" {
|
||||
id, err := a.makeNodeID()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1224,6 +1322,24 @@ func (a *Agent) ShutdownAgent() error {
|
|||
chk.Stop()
|
||||
}
|
||||
|
||||
// Stop the proxy manager
|
||||
if a.proxyManager != nil {
|
||||
// If persistence is disabled (implies DevMode but a subset of DevMode) then
|
||||
// don't leave the proxies running since the agent will not be able to
|
||||
// recover them later.
|
||||
if a.config.DataDir == "" {
|
||||
a.logger.Printf("[WARN] agent: dev mode disabled persistence, killing " +
|
||||
"all proxies since we can't recover them")
|
||||
if err := a.proxyManager.Kill(); err != nil {
|
||||
a.logger.Printf("[WARN] agent: error shutting down proxy manager: %s", err)
|
||||
}
|
||||
} else {
|
||||
if err := a.proxyManager.Close(); err != nil {
|
||||
a.logger.Printf("[WARN] agent: error shutting down proxy manager: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if a.delegate != nil {
|
||||
err = a.delegate.Shutdown()
|
||||
|
@ -1493,7 +1609,7 @@ func (a *Agent) persistService(service *structs.NodeService) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return writeFileAtomic(svcPath, encoded)
|
||||
return file.WriteAtomic(svcPath, encoded)
|
||||
}
|
||||
|
||||
// purgeService removes a persisted service definition file from the data dir
|
||||
|
@ -1505,6 +1621,39 @@ func (a *Agent) purgeService(serviceID string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// persistedProxy is used to wrap a proxy definition and bundle it with an Proxy
|
||||
// token so we can continue to authenticate the running proxy after a restart.
|
||||
type persistedProxy struct {
|
||||
ProxyToken string
|
||||
Proxy *structs.ConnectManagedProxy
|
||||
}
|
||||
|
||||
// persistProxy saves a proxy definition to a JSON file in the data dir
|
||||
func (a *Agent) persistProxy(proxy *local.ManagedProxy) error {
|
||||
proxyPath := filepath.Join(a.config.DataDir, proxyDir,
|
||||
stringHash(proxy.Proxy.ProxyService.ID))
|
||||
|
||||
wrapped := persistedProxy{
|
||||
ProxyToken: proxy.ProxyToken,
|
||||
Proxy: proxy.Proxy,
|
||||
}
|
||||
encoded, err := json.Marshal(wrapped)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return file.WriteAtomic(proxyPath, encoded)
|
||||
}
|
||||
|
||||
// purgeProxy removes a persisted proxy definition file from the data dir
|
||||
func (a *Agent) purgeProxy(proxyID string) error {
|
||||
proxyPath := filepath.Join(a.config.DataDir, proxyDir, stringHash(proxyID))
|
||||
if _, err := os.Stat(proxyPath); err == nil {
|
||||
return os.Remove(proxyPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistCheck saves a check definition to the local agent's state directory
|
||||
func (a *Agent) persistCheck(check *structs.HealthCheck, chkType *structs.CheckType) error {
|
||||
checkPath := filepath.Join(a.config.DataDir, checksDir, checkIDHash(check.CheckID))
|
||||
|
@ -1521,7 +1670,7 @@ func (a *Agent) persistCheck(check *structs.HealthCheck, chkType *structs.CheckT
|
|||
return err
|
||||
}
|
||||
|
||||
return writeFileAtomic(checkPath, encoded)
|
||||
return file.WriteAtomic(checkPath, encoded)
|
||||
}
|
||||
|
||||
// purgeCheck removes a persisted check definition file from the data dir
|
||||
|
@ -1533,43 +1682,6 @@ func (a *Agent) purgeCheck(checkID types.CheckID) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// writeFileAtomic writes the given contents to a temporary file in the same
|
||||
// directory, does an fsync and then renames the file to its real path
|
||||
func writeFileAtomic(path string, contents []byte) error {
|
||||
uuid, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tempPath := fmt.Sprintf("%s-%s.tmp", path, uuid)
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
fh, err := os.OpenFile(tempPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fh.Write(contents); err != nil {
|
||||
fh.Close()
|
||||
os.Remove(tempPath)
|
||||
return err
|
||||
}
|
||||
if err := fh.Sync(); err != nil {
|
||||
fh.Close()
|
||||
os.Remove(tempPath)
|
||||
return err
|
||||
}
|
||||
if err := fh.Close(); err != nil {
|
||||
os.Remove(tempPath)
|
||||
return err
|
||||
}
|
||||
if err := os.Rename(tempPath, path); err != nil {
|
||||
os.Remove(tempPath)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddService is used to add a service entry.
|
||||
// This entry is persistent and the agent will make a best effort to
|
||||
// ensure it is registered
|
||||
|
@ -1623,7 +1735,7 @@ func (a *Agent) AddService(service *structs.NodeService, chkTypes []*structs.Che
|
|||
a.State.AddService(service, token)
|
||||
|
||||
// Persist the service to a file
|
||||
if persist && !a.config.DevMode {
|
||||
if persist && a.config.DataDir != "" {
|
||||
if err := a.persistService(service); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1922,7 +2034,7 @@ func (a *Agent) AddCheck(check *structs.HealthCheck, chkType *structs.CheckType,
|
|||
}
|
||||
|
||||
// Persist the check
|
||||
if persist && !a.config.DevMode {
|
||||
if persist && a.config.DataDir != "" {
|
||||
return a.persistCheck(check, chkType)
|
||||
}
|
||||
|
||||
|
@ -1974,6 +2086,277 @@ func (a *Agent) RemoveCheck(checkID types.CheckID, persist bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// AddProxy adds a new local Connect Proxy instance to be managed by the agent.
|
||||
//
|
||||
// It REQUIRES that the service that is being proxied is already present in the
|
||||
// local state. Note that this is only used for agent-managed proxies so we can
|
||||
// ensure that we always make this true. For externally managed and registered
|
||||
// proxies we explicitly allow the proxy to be registered first to make
|
||||
// bootstrap ordering of a new service simpler but the same is not true here
|
||||
// since this is only ever called when setting up a _managed_ proxy which was
|
||||
// registered as part of a service registration either from config or HTTP API
|
||||
// call.
|
||||
//
|
||||
// The restoredProxyToken argument should only be used when restoring proxy
|
||||
// definitions from disk; new proxies must leave it blank to get a new token
|
||||
// assigned. We need to restore from disk to enable to continue authenticating
|
||||
// running proxies that already had that credential injected.
|
||||
func (a *Agent) AddProxy(proxy *structs.ConnectManagedProxy, persist bool,
|
||||
restoredProxyToken string) error {
|
||||
// Lookup the target service token in state if there is one.
|
||||
token := a.State.ServiceToken(proxy.TargetServiceID)
|
||||
|
||||
// Copy the basic proxy structure so it isn't modified w/ defaults
|
||||
proxyCopy := *proxy
|
||||
proxy = &proxyCopy
|
||||
if err := a.applyProxyDefaults(proxy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the proxy to local state first since we may need to assign a port which
|
||||
// needs to be coordinate under state lock. AddProxy will generate the
|
||||
// NodeService for the proxy populated with the allocated (or configured) port
|
||||
// and an ID, but it doesn't add it to the agent directly since that could
|
||||
// deadlock and we may need to coordinate adding it and persisting etc.
|
||||
proxyState, err := a.State.AddProxy(proxy, token, restoredProxyToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
proxyService := proxyState.Proxy.ProxyService
|
||||
|
||||
// Register proxy TCP check. The built in proxy doesn't listen publically
|
||||
// until it's loaded certs so this ensures we won't route traffic until it's
|
||||
// ready.
|
||||
proxyCfg, err := a.applyProxyConfigDefaults(proxyState.Proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chkTypes := []*structs.CheckType{
|
||||
&structs.CheckType{
|
||||
Name: "Connect Proxy Listening",
|
||||
TCP: fmt.Sprintf("%s:%d", proxyCfg["bind_address"],
|
||||
proxyCfg["bind_port"]),
|
||||
Interval: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
err = a.AddService(proxyService, chkTypes, persist, token)
|
||||
if err != nil {
|
||||
// Remove the state too
|
||||
a.State.RemoveProxy(proxyService.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
// Persist the proxy
|
||||
if persist && a.config.DataDir != "" {
|
||||
return a.persistProxy(proxyState)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyProxyConfigDefaults takes a *structs.ConnectManagedProxy and returns
|
||||
// it's Config map merged with any defaults from the Agent's config. It would be
|
||||
// nicer if this were defined as a method on structs.ConnectManagedProxy but we
|
||||
// can't do that because ot the import cycle it causes with agent/config.
|
||||
func (a *Agent) applyProxyConfigDefaults(p *structs.ConnectManagedProxy) (map[string]interface{}, error) {
|
||||
if p == nil || p.ProxyService == nil {
|
||||
// Should never happen but protect from panic
|
||||
return nil, fmt.Errorf("invalid proxy state")
|
||||
}
|
||||
|
||||
// Lookup the target service
|
||||
target := a.State.Service(p.TargetServiceID)
|
||||
if target == nil {
|
||||
// Can happen during deregistration race between proxy and scheduler.
|
||||
return nil, fmt.Errorf("unknown target service ID: %s", p.TargetServiceID)
|
||||
}
|
||||
|
||||
// Merge globals defaults
|
||||
config := make(map[string]interface{})
|
||||
for k, v := range a.config.ConnectProxyDefaultConfig {
|
||||
if _, ok := config[k]; !ok {
|
||||
config[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Copy config from the proxy
|
||||
for k, v := range p.Config {
|
||||
config[k] = v
|
||||
}
|
||||
|
||||
// Set defaults for anything that is still not specified but required.
|
||||
// Note that these are not included in the content hash. Since we expect
|
||||
// them to be static in general but some like the default target service
|
||||
// port might not be. In that edge case services can set that explicitly
|
||||
// when they re-register which will be caught though.
|
||||
if _, ok := config["bind_port"]; !ok {
|
||||
config["bind_port"] = p.ProxyService.Port
|
||||
}
|
||||
if _, ok := config["bind_address"]; !ok {
|
||||
// Default to binding to the same address the agent is configured to
|
||||
// bind to.
|
||||
config["bind_address"] = a.config.BindAddr.String()
|
||||
}
|
||||
if _, ok := config["local_service_address"]; !ok {
|
||||
// Default to localhost and the port the service registered with
|
||||
config["local_service_address"] = fmt.Sprintf("127.0.0.1:%d", target.Port)
|
||||
}
|
||||
|
||||
// Basic type conversions for expected types.
|
||||
if raw, ok := config["bind_port"]; ok {
|
||||
switch v := raw.(type) {
|
||||
case float64:
|
||||
// Common since HCL/JSON parse as float64
|
||||
config["bind_port"] = int(v)
|
||||
|
||||
// NOTE(mitchellh): No default case since errors and validation
|
||||
// are handled by the ServiceDefinition.Validate function.
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// applyProxyDefaults modifies the given proxy by applying any configured
|
||||
// defaults, such as the default execution mode, command, etc.
|
||||
func (a *Agent) applyProxyDefaults(proxy *structs.ConnectManagedProxy) error {
|
||||
// Set the default exec mode
|
||||
if proxy.ExecMode == structs.ProxyExecModeUnspecified {
|
||||
mode, err := structs.NewProxyExecMode(a.config.ConnectProxyDefaultExecMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxy.ExecMode = mode
|
||||
}
|
||||
if proxy.ExecMode == structs.ProxyExecModeUnspecified {
|
||||
proxy.ExecMode = structs.ProxyExecModeDaemon
|
||||
}
|
||||
|
||||
// Set the default command to the globally configured default
|
||||
if len(proxy.Command) == 0 {
|
||||
switch proxy.ExecMode {
|
||||
case structs.ProxyExecModeDaemon:
|
||||
proxy.Command = a.config.ConnectProxyDefaultDaemonCommand
|
||||
|
||||
case structs.ProxyExecModeScript:
|
||||
proxy.Command = a.config.ConnectProxyDefaultScriptCommand
|
||||
}
|
||||
}
|
||||
|
||||
// If there is no globally configured default we need to get the
|
||||
// default command so we can do "consul connect proxy"
|
||||
if len(proxy.Command) == 0 {
|
||||
command, err := defaultProxyCommand(a.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxy.Command = command
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProxy stops and removes a local proxy instance.
|
||||
func (a *Agent) RemoveProxy(proxyID string, persist bool) error {
|
||||
// Validate proxyID
|
||||
if proxyID == "" {
|
||||
return fmt.Errorf("proxyID missing")
|
||||
}
|
||||
|
||||
// Remove the proxy from the local state
|
||||
p, err := a.State.RemoveProxy(proxyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the proxy service as well. The proxy ID is also the ID
|
||||
// of the servie, but we might as well use the service pointer.
|
||||
if err := a.RemoveService(p.Proxy.ProxyService.ID, persist); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if persist && a.config.DataDir != "" {
|
||||
return a.purgeProxy(proxyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyProxyToken takes a token and attempts to verify it against the
|
||||
// targetService name. If targetProxy is specified, then the local proxy token
|
||||
// must exactly match the given proxy ID. cert, config, etc.).
|
||||
//
|
||||
// The given token may be a local-only proxy token or it may be an ACL token. We
|
||||
// will attempt to verify the local proxy token first.
|
||||
//
|
||||
// The effective ACL token is returned along with a boolean which is true if the
|
||||
// match was against a proxy token rather than an ACL token, and any error. In
|
||||
// the case the token matches a proxy token, then the ACL token used to register
|
||||
// that proxy's target service is returned for use in any RPC calls the proxy
|
||||
// needs to make on behalf of that service. If the token was an ACL token
|
||||
// already then it is always returned. Provided error is nil, a valid ACL token
|
||||
// is always returned.
|
||||
func (a *Agent) verifyProxyToken(token, targetService,
|
||||
targetProxy string) (string, bool, error) {
|
||||
// If we specify a target proxy, we look up that proxy directly. Otherwise,
|
||||
// we resolve with any proxy we can find.
|
||||
var proxy *local.ManagedProxy
|
||||
if targetProxy != "" {
|
||||
proxy = a.State.Proxy(targetProxy)
|
||||
if proxy == nil {
|
||||
return "", false, fmt.Errorf("unknown proxy service ID: %q", targetProxy)
|
||||
}
|
||||
|
||||
// If the token DOESN'T match, then we reset the proxy which will
|
||||
// cause the logic below to fall back to normal ACLs. Otherwise,
|
||||
// we keep the proxy set because we also have to verify that the
|
||||
// target service matches on the proxy.
|
||||
if token != proxy.ProxyToken {
|
||||
proxy = nil
|
||||
}
|
||||
} else {
|
||||
proxy = a.resolveProxyToken(token)
|
||||
}
|
||||
|
||||
// The existence of a token isn't enough, we also need to verify
|
||||
// that the service name of the matching proxy matches our target
|
||||
// service.
|
||||
if proxy != nil {
|
||||
// Get the target service since we only have the name. The nil
|
||||
// check below should never be true since a proxy token always
|
||||
// represents the existence of a local service.
|
||||
target := a.State.Service(proxy.Proxy.TargetServiceID)
|
||||
if target == nil {
|
||||
return "", false, fmt.Errorf("proxy target service not found: %q",
|
||||
proxy.Proxy.TargetServiceID)
|
||||
}
|
||||
|
||||
if target.Service != targetService {
|
||||
return "", false, acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Resolve the actual ACL token used to register the proxy/service and
|
||||
// return that for use in RPC calls.
|
||||
return a.State.ServiceToken(proxy.Proxy.TargetServiceID), true, nil
|
||||
}
|
||||
|
||||
// Doesn't match, we have to do a full token resolution. The required
|
||||
// permission for any proxy-related endpoint is service:write, since
|
||||
// to register a proxy you require that permission and sensitive data
|
||||
// is usually present in the configuration.
|
||||
rule, err := a.resolveToken(token)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
if rule != nil && !rule.ServiceWrite(targetService, nil) {
|
||||
return "", false, acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return token, false, nil
|
||||
}
|
||||
|
||||
func (a *Agent) cancelCheckMonitors(checkID types.CheckID) {
|
||||
// Stop any monitors
|
||||
delete(a.checkReapAfter, checkID)
|
||||
|
@ -2018,7 +2401,7 @@ func (a *Agent) updateTTLCheck(checkID types.CheckID, status, output string) err
|
|||
check.SetStatus(status, output)
|
||||
|
||||
// We don't write any files in dev mode so bail here.
|
||||
if a.config.DevMode {
|
||||
if a.config.DataDir == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -2367,6 +2750,96 @@ func (a *Agent) unloadChecks() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// loadProxies will load connect proxy definitions from configuration and
|
||||
// persisted definitions on disk, and load them into the local agent.
|
||||
func (a *Agent) loadProxies(conf *config.RuntimeConfig) error {
|
||||
for _, svc := range conf.Services {
|
||||
if svc.Connect != nil {
|
||||
proxy, err := svc.ConnectManagedProxy()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding proxy: %s", err)
|
||||
}
|
||||
if proxy == nil {
|
||||
continue
|
||||
}
|
||||
if err := a.AddProxy(proxy, false, ""); err != nil {
|
||||
return fmt.Errorf("failed adding proxy: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load any persisted proxies
|
||||
proxyDir := filepath.Join(a.config.DataDir, proxyDir)
|
||||
files, err := ioutil.ReadDir(proxyDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Failed reading proxies dir %q: %s", proxyDir, err)
|
||||
}
|
||||
for _, fi := range files {
|
||||
// Skip all dirs
|
||||
if fi.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip all partially written temporary files
|
||||
if strings.HasSuffix(fi.Name(), "tmp") {
|
||||
a.logger.Printf("[WARN] agent: Ignoring temporary proxy file %v", fi.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
// Open the file for reading
|
||||
file := filepath.Join(proxyDir, fi.Name())
|
||||
fh, err := os.Open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed opening proxy file %q: %s", file, err)
|
||||
}
|
||||
|
||||
// Read the contents into a buffer
|
||||
buf, err := ioutil.ReadAll(fh)
|
||||
fh.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading proxy file %q: %s", file, err)
|
||||
}
|
||||
|
||||
// Try decoding the proxy definition
|
||||
var p persistedProxy
|
||||
if err := json.Unmarshal(buf, &p); err != nil {
|
||||
a.logger.Printf("[ERR] agent: Failed decoding proxy file %q: %s", file, err)
|
||||
continue
|
||||
}
|
||||
proxyID := p.Proxy.ProxyService.ID
|
||||
|
||||
if a.State.Proxy(proxyID) != nil {
|
||||
// Purge previously persisted proxy. This allows config to be preferred
|
||||
// over services persisted from the API.
|
||||
a.logger.Printf("[DEBUG] agent: proxy %q exists, not restoring from %q",
|
||||
proxyID, file)
|
||||
if err := a.purgeProxy(proxyID); err != nil {
|
||||
return fmt.Errorf("failed purging proxy %q: %s", proxyID, err)
|
||||
}
|
||||
} else {
|
||||
a.logger.Printf("[DEBUG] agent: restored proxy definition %q from %q",
|
||||
proxyID, file)
|
||||
if err := a.AddProxy(p.Proxy, false, p.ProxyToken); err != nil {
|
||||
return fmt.Errorf("failed adding proxy %q: %s", proxyID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// unloadProxies will deregister all proxies known to the local agent.
|
||||
func (a *Agent) unloadProxies() error {
|
||||
for id := range a.State.Proxies() {
|
||||
if err := a.RemoveProxy(id, false); err != nil {
|
||||
return fmt.Errorf("Failed deregistering proxy '%s': %s", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// snapshotCheckState is used to snapshot the current state of the health
|
||||
// checks. This is done before we reload our checks, so that we can properly
|
||||
// restore into the same state.
|
||||
|
@ -2508,6 +2981,9 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
|
|||
|
||||
// First unload all checks, services, and metadata. This lets us begin the reload
|
||||
// with a clean slate.
|
||||
if err := a.unloadProxies(); err != nil {
|
||||
return fmt.Errorf("Failed unloading proxies: %s", err)
|
||||
}
|
||||
if err := a.unloadServices(); err != nil {
|
||||
return fmt.Errorf("Failed unloading services: %s", err)
|
||||
}
|
||||
|
@ -2520,6 +2996,9 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
|
|||
if err := a.loadServices(newCfg); err != nil {
|
||||
return fmt.Errorf("Failed reloading services: %s", err)
|
||||
}
|
||||
if err := a.loadProxies(newCfg); err != nil {
|
||||
return fmt.Errorf("Failed reloading proxies: %s", err)
|
||||
}
|
||||
if err := a.loadChecks(newCfg); err != nil {
|
||||
return fmt.Errorf("Failed reloading checks: %s", err)
|
||||
}
|
||||
|
@ -2544,9 +3023,62 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
|
|||
}
|
||||
|
||||
// Update filtered metrics
|
||||
metrics.UpdateFilter(newCfg.TelemetryAllowedPrefixes, newCfg.TelemetryBlockedPrefixes)
|
||||
metrics.UpdateFilter(newCfg.Telemetry.AllowedPrefixes,
|
||||
newCfg.Telemetry.BlockedPrefixes)
|
||||
|
||||
a.State.SetDiscardCheckOutput(newCfg.DiscardCheckOutput)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerCache configures the cache and registers all the supported
|
||||
// types onto the cache. This is NOT safe to call multiple times so
|
||||
// care should be taken to call this exactly once after the cache
|
||||
// field has been initialized.
|
||||
func (a *Agent) registerCache() {
|
||||
a.cache.RegisterType(cachetype.ConnectCARootName, &cachetype.ConnectCARoot{
|
||||
RPC: a.delegate,
|
||||
}, &cache.RegisterOptions{
|
||||
// Maintain a blocking query, retry dropped connections quickly
|
||||
Refresh: true,
|
||||
RefreshTimer: 0 * time.Second,
|
||||
RefreshTimeout: 10 * time.Minute,
|
||||
})
|
||||
|
||||
a.cache.RegisterType(cachetype.ConnectCALeafName, &cachetype.ConnectCALeaf{
|
||||
RPC: a.delegate,
|
||||
Cache: a.cache,
|
||||
}, &cache.RegisterOptions{
|
||||
// Maintain a blocking query, retry dropped connections quickly
|
||||
Refresh: true,
|
||||
RefreshTimer: 0 * time.Second,
|
||||
RefreshTimeout: 10 * time.Minute,
|
||||
})
|
||||
|
||||
a.cache.RegisterType(cachetype.IntentionMatchName, &cachetype.IntentionMatch{
|
||||
RPC: a.delegate,
|
||||
}, &cache.RegisterOptions{
|
||||
// Maintain a blocking query, retry dropped connections quickly
|
||||
Refresh: true,
|
||||
RefreshTimer: 0 * time.Second,
|
||||
RefreshTimeout: 10 * time.Minute,
|
||||
})
|
||||
}
|
||||
|
||||
// defaultProxyCommand returns the default Connect managed proxy command.
|
||||
func defaultProxyCommand(agentCfg *config.RuntimeConfig) ([]string, error) {
|
||||
// Get the path to the current exectuable. This is cached once by the
|
||||
// library so this is effectively just a variable read.
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// "consul connect proxy" default value for managed daemon proxy
|
||||
cmd := []string{execPath, "connect", "proxy"}
|
||||
|
||||
if agentCfg != nil && agentCfg.LogLevel != "INFO" {
|
||||
cmd = append(cmd, "-log-level", agentCfg.LogLevel)
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
|
|
@ -4,12 +4,21 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/mitchellh/hashstructure"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/cache-types"
|
||||
"github.com/hashicorp/consul/agent/checks"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/ipaddr"
|
||||
|
@ -97,7 +106,7 @@ func (s *HTTPServer) AgentMetrics(resp http.ResponseWriter, req *http.Request) (
|
|||
return nil, acl.ErrPermissionDenied
|
||||
}
|
||||
if enablePrometheusOutput(req) {
|
||||
if s.agent.config.TelemetryPrometheusRetentionTime < 1 {
|
||||
if s.agent.config.Telemetry.PrometheusRetentionTime < 1 {
|
||||
resp.WriteHeader(http.StatusUnsupportedMediaType)
|
||||
fmt.Fprint(resp, "Prometheus is not enable since its retention time is not positive")
|
||||
return nil, nil
|
||||
|
@ -153,25 +162,49 @@ func (s *HTTPServer) AgentServices(resp http.ResponseWriter, req *http.Request)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
proxies := s.agent.State.Proxies()
|
||||
|
||||
// Convert into api.AgentService since that includes Connect config but so far
|
||||
// NodeService doesn't need to internally. They are otherwise identical since
|
||||
// that is the struct used in client for reading the one we output here
|
||||
// anyway.
|
||||
agentSvcs := make(map[string]*api.AgentService)
|
||||
|
||||
// Use empty list instead of nil
|
||||
for id, s := range services {
|
||||
if s.Tags == nil || s.Meta == nil {
|
||||
clone := *s
|
||||
if s.Tags == nil {
|
||||
clone.Tags = make([]string, 0)
|
||||
} else {
|
||||
clone.Tags = s.Tags
|
||||
as := &api.AgentService{
|
||||
Kind: api.ServiceKind(s.Kind),
|
||||
ID: s.ID,
|
||||
Service: s.Service,
|
||||
Tags: s.Tags,
|
||||
Meta: s.Meta,
|
||||
Port: s.Port,
|
||||
Address: s.Address,
|
||||
EnableTagOverride: s.EnableTagOverride,
|
||||
CreateIndex: s.CreateIndex,
|
||||
ModifyIndex: s.ModifyIndex,
|
||||
ProxyDestination: s.ProxyDestination,
|
||||
}
|
||||
if s.Meta == nil {
|
||||
clone.Meta = make(map[string]string)
|
||||
} else {
|
||||
clone.Meta = s.Meta
|
||||
if as.Tags == nil {
|
||||
as.Tags = []string{}
|
||||
}
|
||||
services[id] = &clone
|
||||
if as.Meta == nil {
|
||||
as.Meta = map[string]string{}
|
||||
}
|
||||
// Attach Connect configs if the exist
|
||||
if proxy, ok := proxies[id+"-proxy"]; ok {
|
||||
as.Connect = &api.AgentServiceConnect{
|
||||
Proxy: &api.AgentServiceConnectProxy{
|
||||
ExecMode: api.ProxyExecMode(proxy.Proxy.ExecMode.String()),
|
||||
Command: proxy.Proxy.Command,
|
||||
Config: proxy.Proxy.Config,
|
||||
},
|
||||
}
|
||||
}
|
||||
agentSvcs[id] = as
|
||||
}
|
||||
|
||||
return services, nil
|
||||
return agentSvcs, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) AgentChecks(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
|
@ -554,6 +587,14 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Run validation. This is the same validation that would happen on
|
||||
// the catalog endpoint so it helps ensure the sync will work properly.
|
||||
if err := ns.Validate(); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Verify the check type.
|
||||
chkTypes, err := args.CheckTypes()
|
||||
if err != nil {
|
||||
|
@ -576,10 +617,30 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Get any proxy registrations
|
||||
proxy, err := args.ConnectManagedProxy()
|
||||
if err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If we have a proxy, verify that we're allowed to add a proxy via the API
|
||||
if proxy != nil && !s.agent.config.ConnectProxyAllowManagedAPIRegistration {
|
||||
return nil, &BadRequestError{
|
||||
Reason: "Managed proxy registration via the API is disallowed."}
|
||||
}
|
||||
|
||||
// Add the service.
|
||||
if err := s.agent.AddService(ns, chkTypes, true, token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Add proxy (which will add proxy service so do it before we trigger sync)
|
||||
if proxy != nil {
|
||||
if err := s.agent.AddProxy(proxy, true, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
s.syncChanges()
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -594,9 +655,27 @@ func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Verify this isn't a proxy
|
||||
if s.agent.State.Proxy(serviceID) != nil {
|
||||
return nil, &BadRequestError{
|
||||
Reason: "Managed proxy service cannot be deregistered directly. " +
|
||||
"Deregister the service that has a managed proxy to automatically " +
|
||||
"deregister the managed proxy itself."}
|
||||
}
|
||||
|
||||
if err := s.agent.RemoveService(serviceID, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove the associated managed proxy if it exists
|
||||
for proxyID, p := range s.agent.State.Proxies() {
|
||||
if p.Proxy.TargetServiceID == serviceID {
|
||||
if err := s.agent.RemoveProxy(proxyID, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.syncChanges()
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -828,3 +907,394 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in
|
|||
s.agent.logger.Printf("[INFO] agent: Updated agent's ACL token %q", target)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// AgentConnectCARoots returns the trusted CA roots.
|
||||
func (s *HTTPServer) AgentConnectCARoots(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
var args structs.DCSpecificRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
raw, m, err := s.agent.cache.Get(cachetype.ConnectCARootName, &args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer setCacheMeta(resp, &m)
|
||||
|
||||
// Add cache hit
|
||||
|
||||
reply, ok := raw.(*structs.IndexedCARoots)
|
||||
if !ok {
|
||||
// This should never happen, but we want to protect against panics
|
||||
return nil, fmt.Errorf("internal error: response type not correct")
|
||||
}
|
||||
defer setMeta(resp, &reply.QueryMeta)
|
||||
|
||||
return *reply, nil
|
||||
}
|
||||
|
||||
// AgentConnectCALeafCert returns the certificate bundle for a service
|
||||
// instance. This supports blocking queries to update the returned bundle.
|
||||
func (s *HTTPServer) AgentConnectCALeafCert(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Get the service name. Note that this is the name of the sevice,
|
||||
// not the ID of the service instance.
|
||||
serviceName := strings.TrimPrefix(req.URL.Path, "/v1/agent/connect/ca/leaf/")
|
||||
|
||||
args := cachetype.ConnectCALeafRequest{
|
||||
Service: serviceName, // Need name not ID
|
||||
}
|
||||
var qOpts structs.QueryOptions
|
||||
// Store DC in the ConnectCALeafRequest but query opts separately
|
||||
if done := s.parse(resp, req, &args.Datacenter, &qOpts); done {
|
||||
return nil, nil
|
||||
}
|
||||
args.MinQueryIndex = qOpts.MinQueryIndex
|
||||
|
||||
// Verify the proxy token. This will check both the local proxy token
|
||||
// as well as the ACL if the token isn't local.
|
||||
effectiveToken, _, err := s.agent.verifyProxyToken(qOpts.Token, serviceName, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args.Token = effectiveToken
|
||||
|
||||
raw, m, err := s.agent.cache.Get(cachetype.ConnectCALeafName, &args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer setCacheMeta(resp, &m)
|
||||
|
||||
reply, ok := raw.(*structs.IssuedCert)
|
||||
if !ok {
|
||||
// This should never happen, but we want to protect against panics
|
||||
return nil, fmt.Errorf("internal error: response type not correct")
|
||||
}
|
||||
setIndex(resp, reply.ModifyIndex)
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// GET /v1/agent/connect/proxy/:proxy_service_id
|
||||
//
|
||||
// Returns the local proxy config for the identified proxy. Requires token=
|
||||
// param with the correct local ProxyToken (not ACL token).
|
||||
func (s *HTTPServer) AgentConnectProxyConfig(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Get the proxy ID. Note that this is the ID of a proxy's service instance.
|
||||
id := strings.TrimPrefix(req.URL.Path, "/v1/agent/connect/proxy/")
|
||||
|
||||
// Maybe block
|
||||
var queryOpts structs.QueryOptions
|
||||
if parseWait(resp, req, &queryOpts) {
|
||||
// parseWait returns an error itself
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Parse the token
|
||||
var token string
|
||||
s.parseToken(req, &token)
|
||||
|
||||
// Parse hash specially since it's only this endpoint that uses it currently.
|
||||
// Eventually this should happen in parseWait and end up in QueryOptions but I
|
||||
// didn't want to make very general changes right away.
|
||||
hash := req.URL.Query().Get("hash")
|
||||
|
||||
return s.agentLocalBlockingQuery(resp, hash, &queryOpts,
|
||||
func(ws memdb.WatchSet) (string, interface{}, error) {
|
||||
// Retrieve the proxy specified
|
||||
proxy := s.agent.State.Proxy(id)
|
||||
if proxy == nil {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(resp, "unknown proxy service ID: %s", id)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Lookup the target service as a convenience
|
||||
target := s.agent.State.Service(proxy.Proxy.TargetServiceID)
|
||||
if target == nil {
|
||||
// Not found since this endpoint is only useful for agent-managed proxies so
|
||||
// service missing means the service was deregistered racily with this call.
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(resp, "unknown target service ID: %s", proxy.Proxy.TargetServiceID)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Validate the ACL token
|
||||
_, isProxyToken, err := s.agent.verifyProxyToken(token, target.Service, id)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Watch the proxy for changes
|
||||
ws.Add(proxy.WatchCh)
|
||||
|
||||
hash, err := hashstructure.Hash(proxy.Proxy, nil)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
contentHash := fmt.Sprintf("%x", hash)
|
||||
|
||||
// Set defaults
|
||||
config, err := s.agent.applyProxyConfigDefaults(proxy.Proxy)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Only merge in telemetry config from agent if the requested is
|
||||
// authorized with a proxy token. This prevents us leaking potentially
|
||||
// sensitive config like Circonus API token via a public endpoint. Proxy
|
||||
// tokens are only ever generated in-memory and passed via ENV to a child
|
||||
// proxy process so potential for abuse here seems small. This endpoint in
|
||||
// general is only useful for managed proxies now so it should _always_ be
|
||||
// true that auth is via a proxy token but inconvenient for testing if we
|
||||
// lock it down so strictly.
|
||||
if isProxyToken {
|
||||
// Add telemetry config. Copy the global config so we can customize the
|
||||
// prefix.
|
||||
telemetryCfg := s.agent.config.Telemetry
|
||||
telemetryCfg.MetricsPrefix = telemetryCfg.MetricsPrefix + ".proxy." + target.ID
|
||||
|
||||
// First see if the user has specified telemetry
|
||||
if userRaw, ok := config["telemetry"]; ok {
|
||||
// User specified domething, see if it is compatible with agent
|
||||
// telemetry config:
|
||||
var uCfg lib.TelemetryConfig
|
||||
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
Result: &uCfg,
|
||||
// Make sure that if the user passes something that isn't just a
|
||||
// simple override of a valid TelemetryConfig that we fail so that we
|
||||
// don't clobber their custom config.
|
||||
ErrorUnused: true,
|
||||
})
|
||||
if err == nil {
|
||||
if err = dec.Decode(userRaw); err == nil {
|
||||
// It did decode! Merge any unspecified fields from agent config.
|
||||
uCfg.MergeDefaults(&telemetryCfg)
|
||||
config["telemetry"] = uCfg
|
||||
}
|
||||
}
|
||||
// Failed to decode, just keep user's config["telemetry"] verbatim
|
||||
// with no agent merge.
|
||||
} else {
|
||||
// Add agent telemetry config.
|
||||
config["telemetry"] = telemetryCfg
|
||||
}
|
||||
}
|
||||
|
||||
reply := &api.ConnectProxyConfig{
|
||||
ProxyServiceID: proxy.Proxy.ProxyService.ID,
|
||||
TargetServiceID: target.ID,
|
||||
TargetServiceName: target.Service,
|
||||
ContentHash: contentHash,
|
||||
ExecMode: api.ProxyExecMode(proxy.Proxy.ExecMode.String()),
|
||||
Command: proxy.Proxy.Command,
|
||||
Config: config,
|
||||
}
|
||||
return contentHash, reply, nil
|
||||
})
|
||||
}
|
||||
|
||||
type agentLocalBlockingFunc func(ws memdb.WatchSet) (string, interface{}, error)
|
||||
|
||||
// agentLocalBlockingQuery performs a blocking query in a generic way against
|
||||
// local agent state that has no RPC or raft to back it. It uses `hash` paramter
|
||||
// instead of an `index`. The resp is needed to write the `X-Consul-ContentHash`
|
||||
// header back on return no Status nor body content is ever written to it.
|
||||
func (s *HTTPServer) agentLocalBlockingQuery(resp http.ResponseWriter, hash string,
|
||||
queryOpts *structs.QueryOptions, fn agentLocalBlockingFunc) (interface{}, error) {
|
||||
|
||||
// If we are not blocking we can skip tracking and allocating - nil WatchSet
|
||||
// is still valid to call Add on and will just be a no op.
|
||||
var ws memdb.WatchSet
|
||||
var timeout *time.Timer
|
||||
|
||||
if hash != "" {
|
||||
// TODO(banks) at least define these defaults somewhere in a const. Would be
|
||||
// nice not to duplicate the ones in consul/rpc.go too...
|
||||
wait := queryOpts.MaxQueryTime
|
||||
if wait == 0 {
|
||||
wait = 5 * time.Minute
|
||||
}
|
||||
if wait > 10*time.Minute {
|
||||
wait = 10 * time.Minute
|
||||
}
|
||||
// Apply a small amount of jitter to the request.
|
||||
wait += lib.RandomStagger(wait / 16)
|
||||
timeout = time.NewTimer(wait)
|
||||
}
|
||||
|
||||
for {
|
||||
// Must reset this every loop in case the Watch set is already closed but
|
||||
// hash remains same. In that case we'll need to re-block on ws.Watch()
|
||||
// again.
|
||||
ws = memdb.NewWatchSet()
|
||||
curHash, curResp, err := fn(ws)
|
||||
if err != nil {
|
||||
return curResp, err
|
||||
}
|
||||
// Return immediately if there is no timeout, the hash is different or the
|
||||
// Watch returns true (indicating timeout fired). Note that Watch on a nil
|
||||
// WatchSet immediately returns false which would incorrectly cause this to
|
||||
// loop and repeat again, however we rely on the invariant that ws == nil
|
||||
// IFF timeout == nil in which case the Watch call is never invoked.
|
||||
if timeout == nil || hash != curHash || ws.Watch(timeout.C) {
|
||||
resp.Header().Set("X-Consul-ContentHash", curHash)
|
||||
return curResp, err
|
||||
}
|
||||
// Watch returned false indicating a change was detected, loop and repeat
|
||||
// the callback to load the new value.
|
||||
}
|
||||
}
|
||||
|
||||
// AgentConnectAuthorize
|
||||
//
|
||||
// POST /v1/agent/connect/authorize
|
||||
//
|
||||
// Note: when this logic changes, consider if the Intention.Check RPC method
|
||||
// also needs to be updated.
|
||||
func (s *HTTPServer) AgentConnectAuthorize(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Fetch the token
|
||||
var token string
|
||||
s.parseToken(req, &token)
|
||||
|
||||
// Decode the request from the request body
|
||||
var authReq structs.ConnectAuthorizeRequest
|
||||
if err := decodeBody(req, &authReq, nil); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Request decode failed: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// We need to have a target to check intentions
|
||||
if authReq.Target == "" {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Target service must be specified")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Parse the certificate URI from the client ID
|
||||
uriRaw, err := url.Parse(authReq.ClientCertURI)
|
||||
if err != nil {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: false,
|
||||
Reason: fmt.Sprintf("Client ID must be a URI: %s", err),
|
||||
}, nil
|
||||
}
|
||||
uri, err := connect.ParseCertURI(uriRaw)
|
||||
if err != nil {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: false,
|
||||
Reason: fmt.Sprintf("Invalid client ID: %s", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
uriService, ok := uri.(*connect.SpiffeIDService)
|
||||
if !ok {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: false,
|
||||
Reason: "Client ID must be a valid SPIFFE service URI",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// We need to verify service:write permissions for the given token.
|
||||
// We do this manually here since the RPC request below only verifies
|
||||
// service:read.
|
||||
rule, err := s.agent.resolveToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rule != nil && !rule.ServiceWrite(authReq.Target, nil) {
|
||||
return nil, acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Validate the trust domain matches ours. Later we will support explicit
|
||||
// external federation but not built yet.
|
||||
rootArgs := &structs.DCSpecificRequest{Datacenter: s.agent.config.Datacenter}
|
||||
raw, _, err := s.agent.cache.Get(cachetype.ConnectCARootName, rootArgs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roots, ok := raw.(*structs.IndexedCARoots)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal error: roots response type not correct")
|
||||
}
|
||||
if roots.TrustDomain == "" {
|
||||
return nil, fmt.Errorf("connect CA not bootstrapped yet")
|
||||
}
|
||||
if roots.TrustDomain != strings.ToLower(uriService.Host) {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: false,
|
||||
Reason: fmt.Sprintf("Identity from an external trust domain: %s",
|
||||
uriService.Host),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO(banks): Implement revocation list checking here.
|
||||
|
||||
// Get the intentions for this target service.
|
||||
args := &structs.IntentionQueryRequest{
|
||||
Datacenter: s.agent.config.Datacenter,
|
||||
Match: &structs.IntentionQueryMatch{
|
||||
Type: structs.IntentionMatchDestination,
|
||||
Entries: []structs.IntentionMatchEntry{
|
||||
{
|
||||
Namespace: structs.IntentionDefaultNamespace,
|
||||
Name: authReq.Target,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
args.Token = token
|
||||
|
||||
raw, m, err := s.agent.cache.Get(cachetype.IntentionMatchName, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setCacheMeta(resp, &m)
|
||||
|
||||
reply, ok := raw.(*structs.IndexedIntentionMatches)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal error: response type not correct")
|
||||
}
|
||||
if len(reply.Matches) != 1 {
|
||||
return nil, fmt.Errorf("Internal error loading matches")
|
||||
}
|
||||
|
||||
// Test the authorization for each match
|
||||
for _, ixn := range reply.Matches[0] {
|
||||
if auth, ok := uriService.Authorize(ixn); ok {
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: auth,
|
||||
Reason: fmt.Sprintf("Matched intention: %s", ixn.String()),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// No match, we need to determine the default behavior. We do this by
|
||||
// specifying the anonymous token token, which will get that behavior.
|
||||
// The default behavior if ACLs are disabled is to allow connections
|
||||
// to mimic the behavior of Consul itself: everything is allowed if
|
||||
// ACLs are disabled.
|
||||
rule, err = s.agent.resolveToken("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authz := true
|
||||
reason := "ACLs disabled, access is allowed by default"
|
||||
if rule != nil {
|
||||
authz = rule.IntentionDefaultAllow()
|
||||
reason = "Default behavior configured by ACLs"
|
||||
}
|
||||
|
||||
return &connectAuthorizeResp{
|
||||
Authorized: authz,
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// connectAuthorizeResp is the response format/structure for the
|
||||
// /v1/agent/connect/authorize endpoint.
|
||||
type connectAuthorizeResp struct {
|
||||
Authorized bool // True if authorized, false if not
|
||||
Reason string // Reason for the Authorized value (whether true or false)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -16,6 +16,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/checks"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
|
@ -23,6 +24,8 @@ import (
|
|||
"github.com/hashicorp/consul/types"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func externalIP() (string, error) {
|
||||
|
@ -51,10 +54,62 @@ func TestAgent_MultiStartStop(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAgent_ConnectClusterIDConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hcl string
|
||||
wantClusterID string
|
||||
wantPanic bool
|
||||
}{
|
||||
{
|
||||
name: "default TestAgent has fixed cluster id",
|
||||
hcl: "",
|
||||
wantClusterID: connect.TestClusterID,
|
||||
},
|
||||
{
|
||||
name: "no cluster ID specified sets to test ID",
|
||||
hcl: "connect { enabled = true }",
|
||||
wantClusterID: connect.TestClusterID,
|
||||
},
|
||||
{
|
||||
name: "non-UUID cluster_id is fatal",
|
||||
hcl: `connect {
|
||||
enabled = true
|
||||
ca_config {
|
||||
cluster_id = "fake-id"
|
||||
}
|
||||
}`,
|
||||
wantClusterID: "",
|
||||
wantPanic: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Indirection to support panic recovery cleanly
|
||||
testFn := func() {
|
||||
a := &TestAgent{Name: "test", HCL: tt.hcl}
|
||||
a.ExpectConfigError = tt.wantPanic
|
||||
a.Start()
|
||||
defer a.Shutdown()
|
||||
|
||||
cfg := a.consulConfig()
|
||||
assert.Equal(t, tt.wantClusterID, cfg.CAConfig.ClusterID)
|
||||
}
|
||||
|
||||
if tt.wantPanic {
|
||||
require.Panics(t, testFn)
|
||||
} else {
|
||||
testFn()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_StartStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
// defer a.Shutdown()
|
||||
defer a.Shutdown()
|
||||
|
||||
if err := a.Leave(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
|
@ -1294,6 +1349,187 @@ func TestAgent_PurgeServiceOnDuplicate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAgent_PersistProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
|
||||
cfg := `
|
||||
server = false
|
||||
bootstrap = false
|
||||
data_dir = "` + dataDir + `"
|
||||
`
|
||||
a := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
|
||||
a.Start()
|
||||
defer os.RemoveAll(dataDir)
|
||||
defer a.Shutdown()
|
||||
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
// Add a service to proxy (precondition for AddProxy)
|
||||
svc1 := &structs.NodeService{
|
||||
ID: "redis",
|
||||
Service: "redis",
|
||||
Tags: []string{"foo"},
|
||||
Port: 8000,
|
||||
}
|
||||
require.NoError(a.AddService(svc1, nil, true, ""))
|
||||
|
||||
// Add a proxy for it
|
||||
proxy := &structs.ConnectManagedProxy{
|
||||
TargetServiceID: svc1.ID,
|
||||
Command: []string{"/bin/sleep", "3600"},
|
||||
}
|
||||
|
||||
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash("redis-proxy"))
|
||||
|
||||
// Proxy is not persisted unless requested
|
||||
require.NoError(a.AddProxy(proxy, false, ""))
|
||||
_, err := os.Stat(file)
|
||||
require.Error(err, "proxy should not be persisted")
|
||||
|
||||
// Proxy is persisted if requested
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
_, err = os.Stat(file)
|
||||
require.NoError(err, "proxy should be persisted")
|
||||
|
||||
content, err := ioutil.ReadFile(file)
|
||||
require.NoError(err)
|
||||
|
||||
var gotProxy persistedProxy
|
||||
require.NoError(json.Unmarshal(content, &gotProxy))
|
||||
assert.Equal(proxy.Command, gotProxy.Proxy.Command)
|
||||
assert.Len(gotProxy.ProxyToken, 36) // sanity check for UUID
|
||||
|
||||
// Updates service definition on disk
|
||||
proxy.Config = map[string]interface{}{
|
||||
"foo": "bar",
|
||||
}
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
|
||||
content, err = ioutil.ReadFile(file)
|
||||
require.NoError(err)
|
||||
|
||||
require.NoError(json.Unmarshal(content, &gotProxy))
|
||||
assert.Equal(gotProxy.Proxy.Command, proxy.Command)
|
||||
assert.Equal(gotProxy.Proxy.Config, proxy.Config)
|
||||
assert.Len(gotProxy.ProxyToken, 36) // sanity check for UUID
|
||||
|
||||
a.Shutdown()
|
||||
|
||||
// Should load it back during later start
|
||||
a2 := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
|
||||
a2.Start()
|
||||
defer a2.Shutdown()
|
||||
|
||||
restored := a2.State.Proxy("redis-proxy")
|
||||
require.NotNil(restored)
|
||||
assert.Equal(gotProxy.ProxyToken, restored.ProxyToken)
|
||||
// Ensure the port that was auto picked at random is the same again
|
||||
assert.Equal(gotProxy.Proxy.ProxyService.Port, restored.Proxy.ProxyService.Port)
|
||||
assert.Equal(gotProxy.Proxy.Command, restored.Proxy.Command)
|
||||
}
|
||||
|
||||
func TestAgent_PurgeProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
// Add a service to proxy (precondition for AddProxy)
|
||||
svc1 := &structs.NodeService{
|
||||
ID: "redis",
|
||||
Service: "redis",
|
||||
Tags: []string{"foo"},
|
||||
Port: 8000,
|
||||
}
|
||||
require.NoError(a.AddService(svc1, nil, true, ""))
|
||||
|
||||
// Add a proxy for it
|
||||
proxy := &structs.ConnectManagedProxy{
|
||||
TargetServiceID: svc1.ID,
|
||||
Command: []string{"/bin/sleep", "3600"},
|
||||
}
|
||||
proxyID := "redis-proxy"
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
|
||||
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash("redis-proxy"))
|
||||
|
||||
// Not removed
|
||||
require.NoError(a.RemoveProxy(proxyID, false))
|
||||
_, err := os.Stat(file)
|
||||
require.NoError(err, "should not be removed")
|
||||
|
||||
// Re-add the proxy
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
|
||||
// Removed
|
||||
require.NoError(a.RemoveProxy(proxyID, true))
|
||||
_, err = os.Stat(file)
|
||||
require.Error(err, "should be removed")
|
||||
}
|
||||
|
||||
func TestAgent_PurgeProxyOnDuplicate(t *testing.T) {
|
||||
t.Parallel()
|
||||
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
|
||||
cfg := `
|
||||
data_dir = "` + dataDir + `"
|
||||
server = false
|
||||
bootstrap = false
|
||||
`
|
||||
a := &TestAgent{Name: t.Name(), HCL: cfg, DataDir: dataDir}
|
||||
a.Start()
|
||||
defer a.Shutdown()
|
||||
defer os.RemoveAll(dataDir)
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
// Add a service to proxy (precondition for AddProxy)
|
||||
svc1 := &structs.NodeService{
|
||||
ID: "redis",
|
||||
Service: "redis",
|
||||
Tags: []string{"foo"},
|
||||
Port: 8000,
|
||||
}
|
||||
require.NoError(a.AddService(svc1, nil, true, ""))
|
||||
|
||||
// Add a proxy for it
|
||||
proxy := &structs.ConnectManagedProxy{
|
||||
TargetServiceID: svc1.ID,
|
||||
Command: []string{"/bin/sleep", "3600"},
|
||||
}
|
||||
proxyID := "redis-proxy"
|
||||
require.NoError(a.AddProxy(proxy, true, ""))
|
||||
|
||||
a.Shutdown()
|
||||
|
||||
// Try bringing the agent back up with the service already
|
||||
// existing in the config
|
||||
a2 := &TestAgent{Name: t.Name() + "-a2", HCL: cfg + `
|
||||
service = {
|
||||
id = "redis"
|
||||
name = "redis"
|
||||
tags = ["bar"]
|
||||
port = 9000
|
||||
connect {
|
||||
proxy {
|
||||
command = ["/bin/sleep", "3600"]
|
||||
}
|
||||
}
|
||||
}
|
||||
`, DataDir: dataDir}
|
||||
a2.Start()
|
||||
defer a2.Shutdown()
|
||||
|
||||
file := filepath.Join(a.Config.DataDir, proxyDir, stringHash(proxyID))
|
||||
_, err := os.Stat(file)
|
||||
require.Error(err, "should have removed remote state")
|
||||
|
||||
result := a2.State.Proxy(proxyID)
|
||||
require.NotNil(result)
|
||||
require.Equal(proxy.Command, result.Proxy.Command)
|
||||
}
|
||||
|
||||
func TestAgent_PersistCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
dataDir := testutil.TempDir(t, "agent") // we manage the data dir
|
||||
|
@ -1629,6 +1865,96 @@ func TestAgent_unloadServices(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAgent_loadProxies(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
service = {
|
||||
id = "rabbitmq"
|
||||
name = "rabbitmq"
|
||||
port = 5672
|
||||
token = "abc123"
|
||||
connect {
|
||||
proxy {
|
||||
config {
|
||||
bind_port = 1234
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
|
||||
services := a.State.Services()
|
||||
if _, ok := services["rabbitmq"]; !ok {
|
||||
t.Fatalf("missing service")
|
||||
}
|
||||
if token := a.State.ServiceToken("rabbitmq"); token != "abc123" {
|
||||
t.Fatalf("bad: %s", token)
|
||||
}
|
||||
if _, ok := services["rabbitmq-proxy"]; !ok {
|
||||
t.Fatalf("missing proxy service")
|
||||
}
|
||||
if token := a.State.ServiceToken("rabbitmq-proxy"); token != "abc123" {
|
||||
t.Fatalf("bad: %s", token)
|
||||
}
|
||||
proxies := a.State.Proxies()
|
||||
if _, ok := proxies["rabbitmq-proxy"]; !ok {
|
||||
t.Fatalf("missing proxy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_loadProxies_nilProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
service = {
|
||||
id = "rabbitmq"
|
||||
name = "rabbitmq"
|
||||
port = 5672
|
||||
token = "abc123"
|
||||
connect {
|
||||
}
|
||||
}
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
|
||||
services := a.State.Services()
|
||||
require.Contains(t, services, "rabbitmq")
|
||||
require.Equal(t, "abc123", a.State.ServiceToken("rabbitmq"))
|
||||
require.NotContains(t, services, "rabbitme-proxy")
|
||||
require.Empty(t, a.State.Proxies())
|
||||
}
|
||||
|
||||
func TestAgent_unloadProxies(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
service = {
|
||||
id = "rabbitmq"
|
||||
name = "rabbitmq"
|
||||
port = 5672
|
||||
token = "abc123"
|
||||
connect {
|
||||
proxy {
|
||||
config {
|
||||
bind_port = 1234
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
|
||||
// Sanity check it's there
|
||||
require.NotNil(t, a.State.Proxy("rabbitmq-proxy"))
|
||||
|
||||
// Unload all proxies
|
||||
if err := a.unloadProxies(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if len(a.State.Proxies()) != 0 {
|
||||
t.Fatalf("should have unloaded proxies")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_Service_MaintenanceMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
|
@ -2179,6 +2505,18 @@ func TestAgent_reloadWatches(t *testing.T) {
|
|||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
|
||||
// Should fail to reload with connect watches
|
||||
newConf.Watches = []map[string]interface{}{
|
||||
{
|
||||
"type": "connect_roots",
|
||||
"key": "asdf",
|
||||
"args": []interface{}{"ls"},
|
||||
},
|
||||
}
|
||||
if err := a.reloadWatches(&newConf); err == nil || !strings.Contains(err.Error(), "not allowed in agent config") {
|
||||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
|
||||
// Should still succeed with only HTTPS addresses
|
||||
newConf.HTTPSAddrs = newConf.HTTPAddrs
|
||||
newConf.HTTPAddrs = make([]net.Addr, 0)
|
||||
|
@ -2226,3 +2564,217 @@ func TestAgent_reloadWatchesHTTPS(t *testing.T) {
|
|||
t.Fatalf("bad: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
node_name = "node1"
|
||||
|
||||
connect {
|
||||
proxy_defaults {
|
||||
exec_mode = "script"
|
||||
daemon_command = ["foo", "bar"]
|
||||
script_command = ["bar", "foo"]
|
||||
}
|
||||
}
|
||||
|
||||
ports {
|
||||
proxy_min_port = 20000
|
||||
proxy_max_port = 20000
|
||||
}
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register a target service we can use
|
||||
reg := &structs.NodeService{
|
||||
Service: "web",
|
||||
Port: 8080,
|
||||
}
|
||||
require.NoError(t, a.AddService(reg, nil, false, ""))
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
proxy, wantProxy *structs.ConnectManagedProxy
|
||||
wantTCPCheck string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
desc: "basic proxy adding, unregistered service",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
TargetServiceID: "db", // non-existent service.
|
||||
},
|
||||
// Target service must be registered.
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
desc: "basic proxy adding, registered service",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
desc: "default global exec mode",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantProxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeScript,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
desc: "default daemon command",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantProxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"foo", "bar"},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
desc: "default script command",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeScript,
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantProxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeScript,
|
||||
Command: []string{"bar", "foo"},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
desc: "managed proxy with custom bind port",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bind_address": "127.10.10.10",
|
||||
"bind_port": 1234,
|
||||
},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantTCPCheck: "127.10.10.10:1234",
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
{
|
||||
// This test is necessary since JSON and HCL both will parse
|
||||
// numbers as a float64.
|
||||
desc: "managed proxy with custom bind port (float64)",
|
||||
proxy: &structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bind_address": "127.10.10.10",
|
||||
"bind_port": float64(1234),
|
||||
},
|
||||
TargetServiceID: "web",
|
||||
},
|
||||
wantTCPCheck: "127.10.10.10:1234",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
err := a.AddProxy(tt.proxy, false, "")
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
// Test the ID was created as we expect.
|
||||
got := a.State.Proxy("web-proxy")
|
||||
wantProxy := tt.wantProxy
|
||||
if wantProxy == nil {
|
||||
wantProxy = tt.proxy
|
||||
}
|
||||
wantProxy.ProxyService = got.Proxy.ProxyService
|
||||
require.Equal(wantProxy, got.Proxy)
|
||||
|
||||
// Ensure a TCP check was created for the service.
|
||||
gotCheck := a.State.Check("service:web-proxy")
|
||||
require.NotNil(gotCheck)
|
||||
require.Equal("Connect Proxy Listening", gotCheck.Name)
|
||||
|
||||
// Confusingly, a.State.Check("service:web-proxy") will return the state
|
||||
// but it's Definition field will be empty. This appears to be expected
|
||||
// when adding Checks as part of `AddService`. Notice how `AddService`
|
||||
// tests in this file don't assert on that state but instead look at the
|
||||
// agent's check state directly to ensure the right thing was registered.
|
||||
// We'll do the same for now.
|
||||
gotTCP, ok := a.checkTCPs["service:web-proxy"]
|
||||
require.True(ok)
|
||||
wantTCPCheck := tt.wantTCPCheck
|
||||
if wantTCPCheck == "" {
|
||||
wantTCPCheck = "127.0.0.1:20000"
|
||||
}
|
||||
require.Equal(wantTCPCheck, gotTCP.TCP)
|
||||
require.Equal(10*time.Second, gotTCP.Interval)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), `
|
||||
node_name = "node1"
|
||||
`)
|
||||
defer a.Shutdown()
|
||||
require := require.New(t)
|
||||
|
||||
// Register a target service we can use
|
||||
reg := &structs.NodeService{
|
||||
Service: "web",
|
||||
Port: 8080,
|
||||
}
|
||||
require.NoError(a.AddService(reg, nil, false, ""))
|
||||
|
||||
// Add a proxy for web
|
||||
pReg := &structs.ConnectManagedProxy{
|
||||
TargetServiceID: "web",
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"foo"},
|
||||
}
|
||||
require.NoError(a.AddProxy(pReg, false, ""))
|
||||
|
||||
// Test the ID was created as we expect.
|
||||
gotProxy := a.State.Proxy("web-proxy")
|
||||
require.NotNil(gotProxy)
|
||||
|
||||
err := a.RemoveProxy("web-proxy", false)
|
||||
require.NoError(err)
|
||||
|
||||
gotProxy = a.State.Proxy("web-proxy")
|
||||
require.Nil(gotProxy)
|
||||
require.Nil(a.State.Service("web-proxy"), "web-proxy service")
|
||||
|
||||
// Removing invalid proxy should be an error
|
||||
err = a.RemoveProxy("foobar", false)
|
||||
require.Error(err)
|
||||
}
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,240 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// Recommended name for registration.
|
||||
const ConnectCALeafName = "connect-ca-leaf"
|
||||
|
||||
// ConnectCALeaf supports fetching and generating Connect leaf
|
||||
// certificates.
|
||||
type ConnectCALeaf struct {
|
||||
caIndex uint64 // Current index for CA roots
|
||||
|
||||
issuedCertsLock sync.RWMutex
|
||||
issuedCerts map[string]*structs.IssuedCert
|
||||
|
||||
RPC RPC // RPC client for remote requests
|
||||
Cache *cache.Cache // Cache that has CA root certs via ConnectCARoot
|
||||
}
|
||||
|
||||
func (c *ConnectCALeaf) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
|
||||
var result cache.FetchResult
|
||||
|
||||
// Get the correct type
|
||||
reqReal, ok := req.(*ConnectCALeafRequest)
|
||||
if !ok {
|
||||
return result, fmt.Errorf(
|
||||
"Internal cache failure: request wrong type: %T", req)
|
||||
}
|
||||
|
||||
// This channel watches our overall timeout. The other goroutines
|
||||
// launched in this function should end all around the same time so
|
||||
// they clean themselves up.
|
||||
timeoutCh := time.After(opts.Timeout)
|
||||
|
||||
// Kick off the goroutine that waits for new CA roots. The channel buffer
|
||||
// is so that the goroutine doesn't block forever if we return for other
|
||||
// reasons.
|
||||
newRootCACh := make(chan error, 1)
|
||||
go c.waitNewRootCA(reqReal.Datacenter, newRootCACh, opts.Timeout)
|
||||
|
||||
// Get our prior cert (if we had one) and use that to determine our
|
||||
// expiration time. If no cert exists, we expire immediately since we
|
||||
// need to generate.
|
||||
c.issuedCertsLock.RLock()
|
||||
lastCert := c.issuedCerts[reqReal.Service]
|
||||
c.issuedCertsLock.RUnlock()
|
||||
|
||||
var leafExpiryCh <-chan time.Time
|
||||
if lastCert != nil {
|
||||
// Determine how long we wait until triggering. If we've already
|
||||
// expired, we trigger immediately.
|
||||
if expiryDur := lastCert.ValidBefore.Sub(time.Now()); expiryDur > 0 {
|
||||
leafExpiryCh = time.After(expiryDur - 1*time.Hour)
|
||||
// TODO(mitchellh): 1 hour buffer is hardcoded above
|
||||
}
|
||||
}
|
||||
|
||||
if leafExpiryCh == nil {
|
||||
// If the channel is still nil then it means we need to generate
|
||||
// a cert no matter what: we either don't have an existing one or
|
||||
// it is expired.
|
||||
leafExpiryCh = time.After(0)
|
||||
}
|
||||
|
||||
// Block on the events that wake us up.
|
||||
select {
|
||||
case <-timeoutCh:
|
||||
// On a timeout, we just return the empty result and no error.
|
||||
// It isn't an error to timeout, its just the limit of time the
|
||||
// caching system wants us to block for. By returning an empty result
|
||||
// the caching system will ignore.
|
||||
return result, nil
|
||||
|
||||
case err := <-newRootCACh:
|
||||
// A new root CA triggers us to refresh the leaf certificate.
|
||||
// If there was an error while getting the root CA then we return.
|
||||
// Otherwise, we leave the select statement and move to generation.
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
case <-leafExpiryCh:
|
||||
// The existing leaf certificate is expiring soon, so we generate a
|
||||
// new cert with a healthy overlapping validity period (determined
|
||||
// by the above channel).
|
||||
}
|
||||
|
||||
// Need to lookup RootCAs response to discover trust domain. First just lookup
|
||||
// with no blocking info - this should be a cache hit most of the time.
|
||||
rawRoots, _, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{
|
||||
Datacenter: reqReal.Datacenter,
|
||||
})
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
roots, ok := rawRoots.(*structs.IndexedCARoots)
|
||||
if !ok {
|
||||
return result, errors.New("invalid RootCA response type")
|
||||
}
|
||||
if roots.TrustDomain == "" {
|
||||
return result, errors.New("cluster has no CA bootstrapped")
|
||||
}
|
||||
|
||||
// Build the service ID
|
||||
serviceID := &connect.SpiffeIDService{
|
||||
Host: roots.TrustDomain,
|
||||
Datacenter: reqReal.Datacenter,
|
||||
Namespace: "default",
|
||||
Service: reqReal.Service,
|
||||
}
|
||||
|
||||
// Create a new private key
|
||||
pk, pkPEM, err := connect.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Create a CSR.
|
||||
csr, err := connect.CreateCSR(serviceID, pk)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Request signing
|
||||
var reply structs.IssuedCert
|
||||
args := structs.CASignRequest{
|
||||
WriteRequest: structs.WriteRequest{Token: reqReal.Token},
|
||||
Datacenter: reqReal.Datacenter,
|
||||
CSR: csr,
|
||||
}
|
||||
if err := c.RPC.RPC("ConnectCA.Sign", &args, &reply); err != nil {
|
||||
return result, err
|
||||
}
|
||||
reply.PrivateKeyPEM = pkPEM
|
||||
|
||||
// Lock the issued certs map so we can insert it. We only insert if
|
||||
// we didn't happen to get a newer one. This should never happen since
|
||||
// the Cache should ensure only one Fetch per service, but we sanity
|
||||
// check just in case.
|
||||
c.issuedCertsLock.Lock()
|
||||
defer c.issuedCertsLock.Unlock()
|
||||
lastCert = c.issuedCerts[reqReal.Service]
|
||||
if lastCert == nil || lastCert.ModifyIndex < reply.ModifyIndex {
|
||||
if c.issuedCerts == nil {
|
||||
c.issuedCerts = make(map[string]*structs.IssuedCert)
|
||||
}
|
||||
|
||||
c.issuedCerts[reqReal.Service] = &reply
|
||||
lastCert = &reply
|
||||
}
|
||||
|
||||
result.Value = lastCert
|
||||
result.Index = lastCert.ModifyIndex
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// waitNewRootCA blocks until a new root CA is available or the timeout is
|
||||
// reached (on timeout ErrTimeout is returned on the channel).
|
||||
func (c *ConnectCALeaf) waitNewRootCA(datacenter string, ch chan<- error,
|
||||
timeout time.Duration) {
|
||||
// We always want to block on at least an initial value. If this isn't
|
||||
minIndex := atomic.LoadUint64(&c.caIndex)
|
||||
if minIndex == 0 {
|
||||
minIndex = 1
|
||||
}
|
||||
|
||||
// Fetch some new roots. This will block until our MinQueryIndex is
|
||||
// matched or the timeout is reached.
|
||||
rawRoots, _, err := c.Cache.Get(ConnectCARootName, &structs.DCSpecificRequest{
|
||||
Datacenter: datacenter,
|
||||
QueryOptions: structs.QueryOptions{
|
||||
MinQueryIndex: minIndex,
|
||||
MaxQueryTime: timeout,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ch <- err
|
||||
return
|
||||
}
|
||||
|
||||
roots, ok := rawRoots.(*structs.IndexedCARoots)
|
||||
if !ok {
|
||||
// This should never happen but we don't want to even risk a panic
|
||||
ch <- fmt.Errorf(
|
||||
"internal error: CA root cache returned bad type: %T", rawRoots)
|
||||
return
|
||||
}
|
||||
|
||||
// We do a loop here because there can be multiple waitNewRootCA calls
|
||||
// happening simultaneously. Each Fetch kicks off one call. These are
|
||||
// multiplexed through Cache.Get which should ensure we only ever
|
||||
// actually make a single RPC call. However, there is a race to set
|
||||
// the caIndex field so do a basic CAS loop here.
|
||||
for {
|
||||
// We only set our index if its newer than what is previously set.
|
||||
old := atomic.LoadUint64(&c.caIndex)
|
||||
if old == roots.Index || old > roots.Index {
|
||||
break
|
||||
}
|
||||
|
||||
// Set the new index atomically. If the caIndex value changed
|
||||
// in the meantime, retry.
|
||||
if atomic.CompareAndSwapUint64(&c.caIndex, old, roots.Index) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger the channel since we updated.
|
||||
ch <- nil
|
||||
}
|
||||
|
||||
// ConnectCALeafRequest is the cache.Request implementation for the
|
||||
// ConnectCALeaf cache type. This is implemented here and not in structs
|
||||
// since this is only used for cache-related requests and not forwarded
|
||||
// directly to any Consul servers.
|
||||
type ConnectCALeafRequest struct {
|
||||
Token string
|
||||
Datacenter string
|
||||
Service string // Service name, not ID
|
||||
MinQueryIndex uint64
|
||||
}
|
||||
|
||||
func (r *ConnectCALeafRequest) CacheInfo() cache.RequestInfo {
|
||||
return cache.RequestInfo{
|
||||
Token: r.Token,
|
||||
Key: r.Service,
|
||||
Datacenter: r.Datacenter,
|
||||
MinIndex: r.MinQueryIndex,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test that after an initial signing, new CA roots (new ID) will
|
||||
// trigger a blocking query to execute.
|
||||
func TestConnectCALeaf_changingRoots(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
|
||||
typ, rootsCh := testCALeafType(t, rpc)
|
||||
defer close(rootsCh)
|
||||
rootsCh <- structs.IndexedCARoots{
|
||||
ActiveRootID: "1",
|
||||
TrustDomain: "fake-trust-domain.consul",
|
||||
QueryMeta: structs.QueryMeta{Index: 1},
|
||||
}
|
||||
|
||||
// Instrument ConnectCA.Sign to return signed cert
|
||||
var resp *structs.IssuedCert
|
||||
var idx uint64
|
||||
rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
reply := args.Get(2).(*structs.IssuedCert)
|
||||
reply.ValidBefore = time.Now().Add(12 * time.Hour)
|
||||
reply.CreateIndex = atomic.AddUint64(&idx, 1)
|
||||
reply.ModifyIndex = reply.CreateIndex
|
||||
resp = reply
|
||||
})
|
||||
|
||||
// We'll reuse the fetch options and request
|
||||
opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second}
|
||||
req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"}
|
||||
|
||||
// First fetch should return immediately
|
||||
fetchCh := TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("shouldn't block waiting for fetch")
|
||||
case result := <-fetchCh:
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 1,
|
||||
}, result)
|
||||
}
|
||||
|
||||
// Second fetch should block with set index
|
||||
fetchCh = TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case result := <-fetchCh:
|
||||
t.Fatalf("should not return: %#v", result)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Let's send in new roots, which should trigger the sign req
|
||||
rootsCh <- structs.IndexedCARoots{
|
||||
ActiveRootID: "2",
|
||||
TrustDomain: "fake-trust-domain.consul",
|
||||
QueryMeta: structs.QueryMeta{Index: 2},
|
||||
}
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("shouldn't block waiting for fetch")
|
||||
case result := <-fetchCh:
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 2,
|
||||
}, result)
|
||||
}
|
||||
|
||||
// Third fetch should block
|
||||
fetchCh = TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case result := <-fetchCh:
|
||||
t.Fatalf("should not return: %#v", result)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that after an initial signing, an expiringLeaf will trigger a
|
||||
// blocking query to resign.
|
||||
func TestConnectCALeaf_expiringLeaf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
|
||||
typ, rootsCh := testCALeafType(t, rpc)
|
||||
defer close(rootsCh)
|
||||
rootsCh <- structs.IndexedCARoots{
|
||||
ActiveRootID: "1",
|
||||
TrustDomain: "fake-trust-domain.consul",
|
||||
QueryMeta: structs.QueryMeta{Index: 1},
|
||||
}
|
||||
|
||||
// Instrument ConnectCA.Sign to
|
||||
var resp *structs.IssuedCert
|
||||
var idx uint64
|
||||
rpc.On("RPC", "ConnectCA.Sign", mock.Anything, mock.Anything).Return(nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
reply := args.Get(2).(*structs.IssuedCert)
|
||||
reply.CreateIndex = atomic.AddUint64(&idx, 1)
|
||||
reply.ModifyIndex = reply.CreateIndex
|
||||
|
||||
// This sets the validity to 0 on the first call, and
|
||||
// 12 hours+ on subsequent calls. This means that our first
|
||||
// cert expires immediately.
|
||||
reply.ValidBefore = time.Now().Add((12 * time.Hour) *
|
||||
time.Duration(reply.CreateIndex-1))
|
||||
|
||||
resp = reply
|
||||
})
|
||||
|
||||
// We'll reuse the fetch options and request
|
||||
opts := cache.FetchOptions{MinIndex: 0, Timeout: 10 * time.Second}
|
||||
req := &ConnectCALeafRequest{Datacenter: "dc1", Service: "web"}
|
||||
|
||||
// First fetch should return immediately
|
||||
fetchCh := TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("shouldn't block waiting for fetch")
|
||||
case result := <-fetchCh:
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 1,
|
||||
}, result)
|
||||
}
|
||||
|
||||
// Second fetch should return immediately despite there being
|
||||
// no updated CA roots, because we issued an expired cert.
|
||||
fetchCh = TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("shouldn't block waiting for fetch")
|
||||
case result := <-fetchCh:
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 2,
|
||||
}, result)
|
||||
}
|
||||
|
||||
// Third fetch should block since the cert is not expiring and
|
||||
// we also didn't update CA certs.
|
||||
fetchCh = TestFetchCh(t, typ, opts, req)
|
||||
select {
|
||||
case result := <-fetchCh:
|
||||
t.Fatalf("should not return: %#v", result)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
// testCALeafType returns a *ConnectCALeaf that is pre-configured to
|
||||
// use the given RPC implementation for "ConnectCA.Sign" operations.
|
||||
func testCALeafType(t *testing.T, rpc RPC) (*ConnectCALeaf, chan structs.IndexedCARoots) {
|
||||
// This creates an RPC implementation that will block until the
|
||||
// value is sent on the channel. This lets us control when the
|
||||
// next values show up.
|
||||
rootsCh := make(chan structs.IndexedCARoots, 10)
|
||||
rootsRPC := &testGatedRootsRPC{ValueCh: rootsCh}
|
||||
|
||||
// Create a cache
|
||||
c := cache.TestCache(t)
|
||||
c.RegisterType(ConnectCARootName, &ConnectCARoot{RPC: rootsRPC}, &cache.RegisterOptions{
|
||||
// Disable refresh so that the gated channel controls the
|
||||
// request directly. Otherwise, we get background refreshes and
|
||||
// it screws up the ordering of the channel reads of the
|
||||
// testGatedRootsRPC implementation.
|
||||
Refresh: false,
|
||||
})
|
||||
|
||||
// Create the leaf type
|
||||
return &ConnectCALeaf{RPC: rpc, Cache: c}, rootsCh
|
||||
}
|
||||
|
||||
// testGatedRootsRPC will send each subsequent value on the channel as the
|
||||
// RPC response, blocking if it is waiting for a value on the channel. This
|
||||
// can be used to control when background fetches are returned and what they
|
||||
// return.
|
||||
//
|
||||
// This should be used with Refresh = false for the registration options so
|
||||
// automatic refreshes don't mess up the channel read ordering.
|
||||
type testGatedRootsRPC struct {
|
||||
ValueCh chan structs.IndexedCARoots
|
||||
}
|
||||
|
||||
func (r *testGatedRootsRPC) RPC(method string, args interface{}, reply interface{}) error {
|
||||
if method != "ConnectCA.Roots" {
|
||||
return fmt.Errorf("invalid RPC method: %s", method)
|
||||
}
|
||||
|
||||
replyReal := reply.(*structs.IndexedCARoots)
|
||||
*replyReal = <-r.ValueCh
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// Recommended name for registration.
|
||||
const ConnectCARootName = "connect-ca-root"
|
||||
|
||||
// ConnectCARoot supports fetching the Connect CA roots. This is a
|
||||
// straightforward cache type since it only has to block on the given
|
||||
// index and return the data.
|
||||
type ConnectCARoot struct {
|
||||
RPC RPC
|
||||
}
|
||||
|
||||
func (c *ConnectCARoot) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
|
||||
var result cache.FetchResult
|
||||
|
||||
// The request should be a DCSpecificRequest.
|
||||
reqReal, ok := req.(*structs.DCSpecificRequest)
|
||||
if !ok {
|
||||
return result, fmt.Errorf(
|
||||
"Internal cache failure: request wrong type: %T", req)
|
||||
}
|
||||
|
||||
// Set the minimum query index to our current index so we block
|
||||
reqReal.QueryOptions.MinQueryIndex = opts.MinIndex
|
||||
reqReal.QueryOptions.MaxQueryTime = opts.Timeout
|
||||
|
||||
// Fetch
|
||||
var reply structs.IndexedCARoots
|
||||
if err := c.RPC.RPC("ConnectCA.Roots", reqReal, &reply); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
result.Value = &reply
|
||||
result.Index = reply.QueryMeta.Index
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnectCARoot(t *testing.T) {
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
typ := &ConnectCARoot{RPC: rpc}
|
||||
|
||||
// Expect the proper RPC call. This also sets the expected value
|
||||
// since that is return-by-pointer in the arguments.
|
||||
var resp *structs.IndexedCARoots
|
||||
rpc.On("RPC", "ConnectCA.Roots", mock.Anything, mock.Anything).Return(nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
req := args.Get(1).(*structs.DCSpecificRequest)
|
||||
require.Equal(uint64(24), req.QueryOptions.MinQueryIndex)
|
||||
require.Equal(1*time.Second, req.QueryOptions.MaxQueryTime)
|
||||
|
||||
reply := args.Get(2).(*structs.IndexedCARoots)
|
||||
reply.QueryMeta.Index = 48
|
||||
resp = reply
|
||||
})
|
||||
|
||||
// Fetch
|
||||
result, err := typ.Fetch(cache.FetchOptions{
|
||||
MinIndex: 24,
|
||||
Timeout: 1 * time.Second,
|
||||
}, &structs.DCSpecificRequest{Datacenter: "dc1"})
|
||||
require.Nil(err)
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 48,
|
||||
}, result)
|
||||
}
|
||||
|
||||
func TestConnectCARoot_badReqType(t *testing.T) {
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
typ := &ConnectCARoot{RPC: rpc}
|
||||
|
||||
// Fetch
|
||||
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
|
||||
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
|
||||
require.NotNil(err)
|
||||
require.Contains(err.Error(), "wrong type")
|
||||
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// Recommended name for registration.
|
||||
const IntentionMatchName = "intention-match"
|
||||
|
||||
// IntentionMatch supports fetching the intentions via match queries.
|
||||
type IntentionMatch struct {
|
||||
RPC RPC
|
||||
}
|
||||
|
||||
func (c *IntentionMatch) Fetch(opts cache.FetchOptions, req cache.Request) (cache.FetchResult, error) {
|
||||
var result cache.FetchResult
|
||||
|
||||
// The request should be an IntentionQueryRequest.
|
||||
reqReal, ok := req.(*structs.IntentionQueryRequest)
|
||||
if !ok {
|
||||
return result, fmt.Errorf(
|
||||
"Internal cache failure: request wrong type: %T", req)
|
||||
}
|
||||
|
||||
// Set the minimum query index to our current index so we block
|
||||
reqReal.MinQueryIndex = opts.MinIndex
|
||||
reqReal.MaxQueryTime = opts.Timeout
|
||||
|
||||
// Fetch
|
||||
var reply structs.IndexedIntentionMatches
|
||||
if err := c.RPC.RPC("Intention.Match", reqReal, &reply); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
result.Value = &reply
|
||||
result.Index = reply.Index
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntentionMatch(t *testing.T) {
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
typ := &IntentionMatch{RPC: rpc}
|
||||
|
||||
// Expect the proper RPC call. This also sets the expected value
|
||||
// since that is return-by-pointer in the arguments.
|
||||
var resp *structs.IndexedIntentionMatches
|
||||
rpc.On("RPC", "Intention.Match", mock.Anything, mock.Anything).Return(nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
req := args.Get(1).(*structs.IntentionQueryRequest)
|
||||
require.Equal(uint64(24), req.MinQueryIndex)
|
||||
require.Equal(1*time.Second, req.MaxQueryTime)
|
||||
|
||||
reply := args.Get(2).(*structs.IndexedIntentionMatches)
|
||||
reply.Index = 48
|
||||
resp = reply
|
||||
})
|
||||
|
||||
// Fetch
|
||||
result, err := typ.Fetch(cache.FetchOptions{
|
||||
MinIndex: 24,
|
||||
Timeout: 1 * time.Second,
|
||||
}, &structs.IntentionQueryRequest{Datacenter: "dc1"})
|
||||
require.NoError(err)
|
||||
require.Equal(cache.FetchResult{
|
||||
Value: resp,
|
||||
Index: 48,
|
||||
}, result)
|
||||
}
|
||||
|
||||
func TestIntentionMatch_badReqType(t *testing.T) {
|
||||
require := require.New(t)
|
||||
rpc := TestRPC(t)
|
||||
defer rpc.AssertExpectations(t)
|
||||
typ := &IntentionMatch{RPC: rpc}
|
||||
|
||||
// Fetch
|
||||
_, err := typ.Fetch(cache.FetchOptions{}, cache.TestRequest(
|
||||
t, cache.RequestInfo{Key: "foo", MinIndex: 64}))
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), "wrong type")
|
||||
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Code generated by mockery v1.0.0
|
||||
package cachetype
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockRPC is an autogenerated mock type for the RPC type
|
||||
type MockRPC struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// RPC provides a mock function with given fields: method, args, reply
|
||||
func (_m *MockRPC) RPC(method string, args interface{}, reply interface{}) error {
|
||||
ret := _m.Called(method, args, reply)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok {
|
||||
r0 = rf(method, args, reply)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
package cachetype
|
||||
|
||||
//go:generate mockery -all -inpkg
|
||||
|
||||
// RPC is an interface that an RPC client must implement. This is a helper
|
||||
// interface that is implemented by the agent delegate so that Type
|
||||
// implementations can request RPC access.
|
||||
type RPC interface {
|
||||
RPC(method string, args interface{}, reply interface{}) error
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package cachetype
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
)
|
||||
|
||||
// TestRPC returns a mock implementation of the RPC interface.
|
||||
func TestRPC(t testing.T) *MockRPC {
|
||||
// This function is relatively useless but this allows us to perhaps
|
||||
// perform some initialization later.
|
||||
return &MockRPC{}
|
||||
}
|
||||
|
||||
// TestFetchCh returns a channel that returns the result of the Fetch call.
|
||||
// This is useful for testing timing and concurrency with Fetch calls.
|
||||
// Errors will show up as an error type on the resulting channel so a
|
||||
// type switch should be used.
|
||||
func TestFetchCh(
|
||||
t testing.T,
|
||||
typ cache.Type,
|
||||
opts cache.FetchOptions,
|
||||
req cache.Request) <-chan interface{} {
|
||||
resultCh := make(chan interface{})
|
||||
go func() {
|
||||
result, err := typ.Fetch(opts, req)
|
||||
if err != nil {
|
||||
resultCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
resultCh <- result
|
||||
}()
|
||||
|
||||
return resultCh
|
||||
}
|
||||
|
||||
// TestFetchChResult tests that the result from TestFetchCh matches
|
||||
// within a reasonable period of time (it expects it to be "immediate" but
|
||||
// waits some milliseconds).
|
||||
func TestFetchChResult(t testing.T, ch <-chan interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case result := <-ch:
|
||||
if err, ok := result.(error); ok {
|
||||
t.Fatalf("Result was error: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected)
|
||||
}
|
||||
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
|
@ -0,0 +1,536 @@
|
|||
// Package cache provides caching features for data from a Consul server.
|
||||
//
|
||||
// While this is similar in some ways to the "agent/ae" package, a key
|
||||
// difference is that with anti-entropy, the agent is the authoritative
|
||||
// source so it resolves differences the server may have. With caching (this
|
||||
// package), the server is the authoritative source and we do our best to
|
||||
// balance performance and correctness, depending on the type of data being
|
||||
// requested.
|
||||
//
|
||||
// The types of data that can be cached is configurable via the Type interface.
|
||||
// This allows specialized behavior for certain types of data. Each type of
|
||||
// Consul data (CA roots, leaf certs, intentions, KV, catalog, etc.) will
|
||||
// have to be manually implemented. This usually is not much work, see
|
||||
// the "agent/cache-types" package.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
)
|
||||
|
||||
//go:generate mockery -all -inpkg
|
||||
|
||||
// Constants related to refresh backoff. We probably don't ever need to
|
||||
// make these configurable knobs since they primarily exist to lower load.
|
||||
const (
|
||||
CacheRefreshBackoffMin = 3 // 3 attempts before backing off
|
||||
CacheRefreshMaxWait = 1 * time.Minute // maximum backoff wait time
|
||||
)
|
||||
|
||||
// Cache is a agent-local cache of Consul data. Create a Cache using the
|
||||
// New function. A zero-value Cache is not ready for usage and will result
|
||||
// in a panic.
|
||||
//
|
||||
// The types of data to be cached must be registered via RegisterType. Then,
|
||||
// calls to Get specify the type and a Request implementation. The
|
||||
// implementation of Request is usually done directly on the standard RPC
|
||||
// struct in agent/structs. This API makes cache usage a mostly drop-in
|
||||
// replacement for non-cached RPC calls.
|
||||
//
|
||||
// The cache is partitioned by ACL and datacenter. This allows the cache
|
||||
// to be safe for multi-DC queries and for queries where the data is modified
|
||||
// due to ACLs all without the cache having to have any clever logic, at
|
||||
// the slight expense of a less perfect cache.
|
||||
//
|
||||
// The Cache exposes various metrics via go-metrics. Please view the source
|
||||
// searching for "metrics." to see the various metrics exposed. These can be
|
||||
// used to explore the performance of the cache.
|
||||
type Cache struct {
|
||||
// types stores the list of data types that the cache knows how to service.
|
||||
// These can be dynamically registered with RegisterType.
|
||||
typesLock sync.RWMutex
|
||||
types map[string]typeEntry
|
||||
|
||||
// entries contains the actual cache data. Access to entries and
|
||||
// entriesExpiryHeap must be protected by entriesLock.
|
||||
//
|
||||
// entriesExpiryHeap is a heap of *cacheEntry values ordered by
|
||||
// expiry, with the soonest to expire being first in the list (index 0).
|
||||
//
|
||||
// NOTE(mitchellh): The entry map key is currently a string in the format
|
||||
// of "<DC>/<ACL token>/<Request key>" in order to properly partition
|
||||
// requests to different datacenters and ACL tokens. This format has some
|
||||
// big drawbacks: we can't evict by datacenter, ACL token, etc. For an
|
||||
// initial implementation this works and the tests are agnostic to the
|
||||
// internal storage format so changing this should be possible safely.
|
||||
entriesLock sync.RWMutex
|
||||
entries map[string]cacheEntry
|
||||
entriesExpiryHeap *expiryHeap
|
||||
}
|
||||
|
||||
// typeEntry is a single type that is registered with a Cache.
|
||||
type typeEntry struct {
|
||||
Type Type
|
||||
Opts *RegisterOptions
|
||||
}
|
||||
|
||||
// ResultMeta is returned from Get calls along with the value and can be used
|
||||
// to expose information about the cache status for debugging or testing.
|
||||
type ResultMeta struct {
|
||||
// Return whether or not the request was a cache hit
|
||||
Hit bool
|
||||
}
|
||||
|
||||
// Options are options for the Cache.
|
||||
type Options struct {
|
||||
// Nothing currently, reserved.
|
||||
}
|
||||
|
||||
// New creates a new cache with the given RPC client and reasonable defaults.
|
||||
// Further settings can be tweaked on the returned value.
|
||||
func New(*Options) *Cache {
|
||||
// Initialize the heap. The buffer of 1 is really important because
|
||||
// its possible for the expiry loop to trigger the heap to update
|
||||
// itself and it'd block forever otherwise.
|
||||
h := &expiryHeap{NotifyCh: make(chan struct{}, 1)}
|
||||
heap.Init(h)
|
||||
|
||||
c := &Cache{
|
||||
types: make(map[string]typeEntry),
|
||||
entries: make(map[string]cacheEntry),
|
||||
entriesExpiryHeap: h,
|
||||
}
|
||||
|
||||
// Start the expiry watcher
|
||||
go c.runExpiryLoop()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// RegisterOptions are options that can be associated with a type being
|
||||
// registered for the cache. This changes the behavior of the cache for
|
||||
// this type.
|
||||
type RegisterOptions struct {
|
||||
// LastGetTTL is the time that the values returned by this type remain
|
||||
// in the cache after the last get operation. If a value isn't accessed
|
||||
// within this duration, the value is purged from the cache and
|
||||
// background refreshing will cease.
|
||||
LastGetTTL time.Duration
|
||||
|
||||
// Refresh configures whether the data is actively refreshed or if
|
||||
// the data is only refreshed on an explicit Get. The default (false)
|
||||
// is to only request data on explicit Get.
|
||||
Refresh bool
|
||||
|
||||
// RefreshTimer is the time between attempting to refresh data.
|
||||
// If this is zero, then data is refreshed immediately when a fetch
|
||||
// is returned.
|
||||
//
|
||||
// RefreshTimeout determines the maximum query time for a refresh
|
||||
// operation. This is specified as part of the query options and is
|
||||
// expected to be implemented by the Type itself.
|
||||
//
|
||||
// Using these values, various "refresh" mechanisms can be implemented:
|
||||
//
|
||||
// * With a high timer duration and a low timeout, a timer-based
|
||||
// refresh can be set that minimizes load on the Consul servers.
|
||||
//
|
||||
// * With a low timer and high timeout duration, a blocking-query-based
|
||||
// refresh can be set so that changes in server data are recognized
|
||||
// within the cache very quickly.
|
||||
//
|
||||
RefreshTimer time.Duration
|
||||
RefreshTimeout time.Duration
|
||||
}
|
||||
|
||||
// RegisterType registers a cacheable type.
|
||||
//
|
||||
// This makes the type available for Get but does not automatically perform
|
||||
// any prefetching. In order to populate the cache, Get must be called.
|
||||
func (c *Cache) RegisterType(n string, typ Type, opts *RegisterOptions) {
|
||||
if opts == nil {
|
||||
opts = &RegisterOptions{}
|
||||
}
|
||||
if opts.LastGetTTL == 0 {
|
||||
opts.LastGetTTL = 72 * time.Hour // reasonable default is days
|
||||
}
|
||||
|
||||
c.typesLock.Lock()
|
||||
defer c.typesLock.Unlock()
|
||||
c.types[n] = typeEntry{Type: typ, Opts: opts}
|
||||
}
|
||||
|
||||
// Get loads the data for the given type and request. If data satisfying the
|
||||
// minimum index is present in the cache, it is returned immediately. Otherwise,
|
||||
// this will block until the data is available or the request timeout is
|
||||
// reached.
|
||||
//
|
||||
// Multiple Get calls for the same Request (matching CacheKey value) will
|
||||
// block on a single network request.
|
||||
//
|
||||
// The timeout specified by the Request will be the timeout on the cache
|
||||
// Get, and does not correspond to the timeout of any background data
|
||||
// fetching. If the timeout is reached before data satisfying the minimum
|
||||
// index is retrieved, the last known value (maybe nil) is returned. No
|
||||
// error is returned on timeout. This matches the behavior of Consul blocking
|
||||
// queries.
|
||||
func (c *Cache) Get(t string, r Request) (interface{}, ResultMeta, error) {
|
||||
info := r.CacheInfo()
|
||||
if info.Key == "" {
|
||||
metrics.IncrCounter([]string{"consul", "cache", "bypass"}, 1)
|
||||
|
||||
// If no key is specified, then we do not cache this request.
|
||||
// Pass directly through to the backend.
|
||||
return c.fetchDirect(t, r)
|
||||
}
|
||||
|
||||
// Get the actual key for our entry
|
||||
key := c.entryKey(&info)
|
||||
|
||||
// First time through
|
||||
first := true
|
||||
|
||||
// timeoutCh for watching our timeout
|
||||
var timeoutCh <-chan time.Time
|
||||
|
||||
RETRY_GET:
|
||||
// Get the current value
|
||||
c.entriesLock.RLock()
|
||||
entry, ok := c.entries[key]
|
||||
c.entriesLock.RUnlock()
|
||||
|
||||
// If we have a current value and the index is greater than the
|
||||
// currently stored index then we return that right away. If the
|
||||
// index is zero and we have something in the cache we accept whatever
|
||||
// we have.
|
||||
if ok && entry.Valid {
|
||||
if info.MinIndex == 0 || info.MinIndex < entry.Index {
|
||||
meta := ResultMeta{}
|
||||
if first {
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "hit"}, 1)
|
||||
meta.Hit = true
|
||||
}
|
||||
|
||||
// Touch the expiration and fix the heap.
|
||||
c.entriesLock.Lock()
|
||||
entry.Expiry.Reset()
|
||||
c.entriesExpiryHeap.Fix(entry.Expiry)
|
||||
c.entriesLock.Unlock()
|
||||
|
||||
// We purposely do not return an error here since the cache
|
||||
// only works with fetching values that either have a value
|
||||
// or have an error, but not both. The Error may be non-nil
|
||||
// in the entry because of this to note future fetch errors.
|
||||
return entry.Value, meta, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If this isn't our first time through and our last value has an error,
|
||||
// then we return the error. This has the behavior that we don't sit in
|
||||
// a retry loop getting the same error for the entire duration of the
|
||||
// timeout. Instead, we make one effort to fetch a new value, and if
|
||||
// there was an error, we return.
|
||||
if !first && entry.Error != nil {
|
||||
return entry.Value, ResultMeta{}, entry.Error
|
||||
}
|
||||
|
||||
if first {
|
||||
// We increment two different counters for cache misses depending on
|
||||
// whether we're missing because we didn't have the data at all,
|
||||
// or if we're missing because we're blocking on a set index.
|
||||
if info.MinIndex == 0 {
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "miss_new"}, 1)
|
||||
} else {
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "miss_block"}, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// No longer our first time through
|
||||
first = false
|
||||
|
||||
// Set our timeout channel if we must
|
||||
if info.Timeout > 0 && timeoutCh == nil {
|
||||
timeoutCh = time.After(info.Timeout)
|
||||
}
|
||||
|
||||
// At this point, we know we either don't have a value at all or the
|
||||
// value we have is too old. We need to wait for new data.
|
||||
waiterCh, err := c.fetch(t, key, r, true, 0)
|
||||
if err != nil {
|
||||
return nil, ResultMeta{}, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-waiterCh:
|
||||
// Our fetch returned, retry the get from the cache
|
||||
goto RETRY_GET
|
||||
|
||||
case <-timeoutCh:
|
||||
// Timeout on the cache read, just return whatever we have.
|
||||
return entry.Value, ResultMeta{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// entryKey returns the key for the entry in the cache. See the note
|
||||
// about the entry key format in the structure docs for Cache.
|
||||
func (c *Cache) entryKey(r *RequestInfo) string {
|
||||
return fmt.Sprintf("%s/%s/%s", r.Datacenter, r.Token, r.Key)
|
||||
}
|
||||
|
||||
// fetch triggers a new background fetch for the given Request. If a
|
||||
// background fetch is already running for a matching Request, the waiter
|
||||
// channel for that request is returned. The effect of this is that there
|
||||
// is only ever one blocking query for any matching requests.
|
||||
//
|
||||
// If allowNew is true then the fetch should create the cache entry
|
||||
// if it doesn't exist. If this is false, then fetch will do nothing
|
||||
// if the entry doesn't exist. This latter case is to support refreshing.
|
||||
func (c *Cache) fetch(t, key string, r Request, allowNew bool, attempt uint) (<-chan struct{}, error) {
|
||||
// Get the type that we're fetching
|
||||
c.typesLock.RLock()
|
||||
tEntry, ok := c.types[t]
|
||||
c.typesLock.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown type in cache: %s", t)
|
||||
}
|
||||
|
||||
// We acquire a write lock because we may have to set Fetching to true.
|
||||
c.entriesLock.Lock()
|
||||
defer c.entriesLock.Unlock()
|
||||
entry, ok := c.entries[key]
|
||||
|
||||
// If we aren't allowing new values and we don't have an existing value,
|
||||
// return immediately. We return an immediately-closed channel so nothing
|
||||
// blocks.
|
||||
if !ok && !allowNew {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// If we already have an entry and it is actively fetching, then return
|
||||
// the currently active waiter.
|
||||
if ok && entry.Fetching {
|
||||
return entry.Waiter, nil
|
||||
}
|
||||
|
||||
// If we don't have an entry, then create it. The entry must be marked
|
||||
// as invalid so that it isn't returned as a valid value for a zero index.
|
||||
if !ok {
|
||||
entry = cacheEntry{Valid: false, Waiter: make(chan struct{})}
|
||||
}
|
||||
|
||||
// Set that we're fetching to true, which makes it so that future
|
||||
// identical calls to fetch will return the same waiter rather than
|
||||
// perform multiple fetches.
|
||||
entry.Fetching = true
|
||||
c.entries[key] = entry
|
||||
metrics.SetGauge([]string{"consul", "cache", "entries_count"}, float32(len(c.entries)))
|
||||
|
||||
// The actual Fetch must be performed in a goroutine.
|
||||
go func() {
|
||||
// Start building the new entry by blocking on the fetch.
|
||||
result, err := tEntry.Type.Fetch(FetchOptions{
|
||||
MinIndex: entry.Index,
|
||||
Timeout: tEntry.Opts.RefreshTimeout,
|
||||
}, r)
|
||||
|
||||
// Copy the existing entry to start.
|
||||
newEntry := entry
|
||||
newEntry.Fetching = false
|
||||
if result.Value != nil {
|
||||
// A new value was given, so we create a brand new entry.
|
||||
newEntry.Value = result.Value
|
||||
newEntry.Index = result.Index
|
||||
if newEntry.Index < 1 {
|
||||
// Less than one is invalid unless there was an error and in this case
|
||||
// there wasn't since a value was returned. If a badly behaved RPC
|
||||
// returns 0 when it has no data, we might get into a busy loop here. We
|
||||
// set this to minimum of 1 which is safe because no valid user data can
|
||||
// ever be written at raft index 1 due to the bootstrap process for
|
||||
// raft. This insure that any subsequent background refresh request will
|
||||
// always block, but allows the initial request to return immediately
|
||||
// even if there is no data.
|
||||
newEntry.Index = 1
|
||||
}
|
||||
|
||||
// This is a valid entry with a result
|
||||
newEntry.Valid = true
|
||||
}
|
||||
|
||||
// Error handling
|
||||
if err == nil {
|
||||
metrics.IncrCounter([]string{"consul", "cache", "fetch_success"}, 1)
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "fetch_success"}, 1)
|
||||
|
||||
if result.Index > 0 {
|
||||
// Reset the attempts counter so we don't have any backoff
|
||||
attempt = 0
|
||||
} else {
|
||||
// Result having a zero index is an implicit error case. There was no
|
||||
// actual error but it implies the RPC found in index (nothing written
|
||||
// yet for that type) but didn't take care to return safe "1" index. We
|
||||
// don't want to actually treat it like an error by setting
|
||||
// newEntry.Error to something non-nil, but we should guard against 100%
|
||||
// CPU burn hot loops caused by that case which will never block but
|
||||
// also won't backoff either. So we treat it as a failed attempt so that
|
||||
// at least the failure backoff will save our CPU while still
|
||||
// periodically refreshing so normal service can resume when the servers
|
||||
// actually have something to return from the RPC. If we get in this
|
||||
// state it can be considered a bug in the RPC implementation (to ever
|
||||
// return a zero index) however since it can happen this is a safety net
|
||||
// for the future.
|
||||
attempt++
|
||||
}
|
||||
} else {
|
||||
metrics.IncrCounter([]string{"consul", "cache", "fetch_error"}, 1)
|
||||
metrics.IncrCounter([]string{"consul", "cache", t, "fetch_error"}, 1)
|
||||
|
||||
// Increment attempt counter
|
||||
attempt++
|
||||
|
||||
// Always set the error. We don't override the value here because
|
||||
// if Valid is true, then we can reuse the Value in the case a
|
||||
// specific index isn't requested. However, for blocking queries,
|
||||
// we want Error to be set so that we can return early with the
|
||||
// error.
|
||||
newEntry.Error = err
|
||||
}
|
||||
|
||||
// Create a new waiter that will be used for the next fetch.
|
||||
newEntry.Waiter = make(chan struct{})
|
||||
|
||||
// Set our entry
|
||||
c.entriesLock.Lock()
|
||||
|
||||
// If this is a new entry (not in the heap yet), then setup the
|
||||
// initial expiry information and insert. If we're already in
|
||||
// the heap we do nothing since we're reusing the same entry.
|
||||
if newEntry.Expiry == nil || newEntry.Expiry.HeapIndex == -1 {
|
||||
newEntry.Expiry = &cacheEntryExpiry{
|
||||
Key: key,
|
||||
TTL: tEntry.Opts.LastGetTTL,
|
||||
}
|
||||
newEntry.Expiry.Reset()
|
||||
heap.Push(c.entriesExpiryHeap, newEntry.Expiry)
|
||||
}
|
||||
|
||||
c.entries[key] = newEntry
|
||||
c.entriesLock.Unlock()
|
||||
|
||||
// Trigger the old waiter
|
||||
close(entry.Waiter)
|
||||
|
||||
// If refresh is enabled, run the refresh in due time. The refresh
|
||||
// below might block, but saves us from spawning another goroutine.
|
||||
if tEntry.Opts.Refresh {
|
||||
c.refresh(tEntry.Opts, attempt, t, key, r)
|
||||
}
|
||||
}()
|
||||
|
||||
return entry.Waiter, nil
|
||||
}
|
||||
|
||||
// fetchDirect fetches the given request with no caching. Because this
|
||||
// bypasses the caching entirely, multiple matching requests will result
|
||||
// in multiple actual RPC calls (unlike fetch).
|
||||
func (c *Cache) fetchDirect(t string, r Request) (interface{}, ResultMeta, error) {
|
||||
// Get the type that we're fetching
|
||||
c.typesLock.RLock()
|
||||
tEntry, ok := c.types[t]
|
||||
c.typesLock.RUnlock()
|
||||
if !ok {
|
||||
return nil, ResultMeta{}, fmt.Errorf("unknown type in cache: %s", t)
|
||||
}
|
||||
|
||||
// Fetch it with the min index specified directly by the request.
|
||||
result, err := tEntry.Type.Fetch(FetchOptions{
|
||||
MinIndex: r.CacheInfo().MinIndex,
|
||||
}, r)
|
||||
if err != nil {
|
||||
return nil, ResultMeta{}, err
|
||||
}
|
||||
|
||||
// Return the result and ignore the rest
|
||||
return result.Value, ResultMeta{}, nil
|
||||
}
|
||||
|
||||
// refresh triggers a fetch for a specific Request according to the
|
||||
// registration options.
|
||||
func (c *Cache) refresh(opts *RegisterOptions, attempt uint, t string, key string, r Request) {
|
||||
// Sanity-check, we should not schedule anything that has refresh disabled
|
||||
if !opts.Refresh {
|
||||
return
|
||||
}
|
||||
|
||||
// If we're over the attempt minimum, start an exponential backoff.
|
||||
if attempt > CacheRefreshBackoffMin {
|
||||
waitTime := (1 << (attempt - CacheRefreshBackoffMin)) * time.Second
|
||||
if waitTime > CacheRefreshMaxWait {
|
||||
waitTime = CacheRefreshMaxWait
|
||||
}
|
||||
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
|
||||
// If we have a timer, wait for it
|
||||
if opts.RefreshTimer > 0 {
|
||||
time.Sleep(opts.RefreshTimer)
|
||||
}
|
||||
|
||||
// Trigger. The "allowNew" field is false because in the time we were
|
||||
// waiting to refresh we may have expired and got evicted. If that
|
||||
// happened, we don't want to create a new entry.
|
||||
c.fetch(t, key, r, false, attempt)
|
||||
}
|
||||
|
||||
// runExpiryLoop is a blocking function that watches the expiration
|
||||
// heap and invalidates entries that have expired.
|
||||
func (c *Cache) runExpiryLoop() {
|
||||
var expiryTimer *time.Timer
|
||||
for {
|
||||
// If we have a previous timer, stop it.
|
||||
if expiryTimer != nil {
|
||||
expiryTimer.Stop()
|
||||
}
|
||||
|
||||
// Get the entry expiring soonest
|
||||
var entry *cacheEntryExpiry
|
||||
var expiryCh <-chan time.Time
|
||||
c.entriesLock.RLock()
|
||||
if len(c.entriesExpiryHeap.Entries) > 0 {
|
||||
entry = c.entriesExpiryHeap.Entries[0]
|
||||
expiryTimer = time.NewTimer(entry.Expires.Sub(time.Now()))
|
||||
expiryCh = expiryTimer.C
|
||||
}
|
||||
c.entriesLock.RUnlock()
|
||||
|
||||
select {
|
||||
case <-c.entriesExpiryHeap.NotifyCh:
|
||||
// Entries changed, so the heap may have changed. Restart loop.
|
||||
|
||||
case <-expiryCh:
|
||||
c.entriesLock.Lock()
|
||||
|
||||
// Entry expired! Remove it.
|
||||
delete(c.entries, entry.Key)
|
||||
heap.Remove(c.entriesExpiryHeap, entry.HeapIndex)
|
||||
|
||||
// This is subtle but important: if we race and simultaneously
|
||||
// evict and fetch a new value, then we set this to -1 to
|
||||
// have it treated as a new value so that the TTL is extended.
|
||||
entry.HeapIndex = -1
|
||||
|
||||
// Set some metrics
|
||||
metrics.IncrCounter([]string{"consul", "cache", "evict_expired"}, 1)
|
||||
metrics.SetGauge([]string{"consul", "cache", "entries_count"}, float32(len(c.entries)))
|
||||
|
||||
c.entriesLock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,760 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test a basic Get with no indexes (and therefore no blocking queries).
|
||||
func TestCacheGet_noIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(1)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should not fetch since we already have a satisfying value
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.True(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test a basic Get with no index and a failed fetch.
|
||||
func TestCacheGet_initError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
fetcherr := fmt.Errorf("error")
|
||||
typ.Static(FetchResult{}, fetcherr).Times(2)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.Error(err)
|
||||
require.Nil(result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should fetch again since our last fetch was an error
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.Error(err)
|
||||
require.Nil(result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test a Get with a request that returns a blank cache key. This should
|
||||
// force a backend request and skip the cache entirely.
|
||||
func TestCacheGet_blankCacheKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(2)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: ""})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should not fetch since we already have a satisfying value
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test that Get blocks on the initial value
|
||||
func TestCacheGet_blockingInitSameKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
triggerCh := make(chan time.Time)
|
||||
typ.Static(FetchResult{Value: 42}, nil).WaitUntil(triggerCh).Times(1)
|
||||
|
||||
// Perform multiple gets
|
||||
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
|
||||
// They should block
|
||||
select {
|
||||
case <-getCh1:
|
||||
t.Fatal("should block (ch1)")
|
||||
case <-getCh2:
|
||||
t.Fatal("should block (ch2)")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Trigger it
|
||||
close(triggerCh)
|
||||
|
||||
// Should return
|
||||
TestCacheGetChResult(t, getCh1, 42)
|
||||
TestCacheGetChResult(t, getCh2, 42)
|
||||
}
|
||||
|
||||
// Test that Get with different cache keys both block on initial value
|
||||
// but that the fetches were both properly called.
|
||||
func TestCacheGet_blockingInitDiffKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Keep track of the keys
|
||||
var keysLock sync.Mutex
|
||||
var keys []string
|
||||
|
||||
// Configure the type
|
||||
triggerCh := make(chan time.Time)
|
||||
typ.Static(FetchResult{Value: 42}, nil).
|
||||
WaitUntil(triggerCh).
|
||||
Times(2).
|
||||
Run(func(args mock.Arguments) {
|
||||
keysLock.Lock()
|
||||
defer keysLock.Unlock()
|
||||
keys = append(keys, args.Get(1).(Request).CacheInfo().Key)
|
||||
})
|
||||
|
||||
// Perform multiple gets
|
||||
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "goodbye"}))
|
||||
|
||||
// They should block
|
||||
select {
|
||||
case <-getCh1:
|
||||
t.Fatal("should block (ch1)")
|
||||
case <-getCh2:
|
||||
t.Fatal("should block (ch2)")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Trigger it
|
||||
close(triggerCh)
|
||||
|
||||
// Should return both!
|
||||
TestCacheGetChResult(t, getCh1, 42)
|
||||
TestCacheGetChResult(t, getCh2, 42)
|
||||
|
||||
// Verify proper keys
|
||||
sort.Strings(keys)
|
||||
require.Equal([]string{"goodbye", "hello"}, keys)
|
||||
}
|
||||
|
||||
// Test a get with an index set will wait until an index that is higher
|
||||
// is set in the cache.
|
||||
func TestCacheGet_blockingIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
triggerCh := make(chan time.Time)
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 42, Index: 6}, nil).WaitUntil(triggerCh)
|
||||
|
||||
// Fetch should block
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Key: "hello", MinIndex: 5}))
|
||||
|
||||
// Should block
|
||||
select {
|
||||
case <-resultCh:
|
||||
t.Fatal("should block")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Wait a bit
|
||||
close(triggerCh)
|
||||
|
||||
// Should return
|
||||
TestCacheGetChResult(t, resultCh, 42)
|
||||
}
|
||||
|
||||
// Test a get with an index set will timeout if the fetch doesn't return
|
||||
// anything.
|
||||
func TestCacheGet_blockingIndexTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
triggerCh := make(chan time.Time)
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 42, Index: 6}, nil).WaitUntil(triggerCh)
|
||||
|
||||
// Fetch should block
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Key: "hello", MinIndex: 5, Timeout: 200 * time.Millisecond}))
|
||||
|
||||
// Should block
|
||||
select {
|
||||
case <-resultCh:
|
||||
t.Fatal("should block")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Should return after more of the timeout
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
require.Equal(t, 12, result)
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
t.Fatal("should've returned")
|
||||
}
|
||||
}
|
||||
|
||||
// Test a get with an index set with requests returning an error
|
||||
// will return that error.
|
||||
func TestCacheGet_blockingIndexError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
var retries uint32
|
||||
fetchErr := fmt.Errorf("test fetch error")
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: nil, Index: 5}, fetchErr).Run(func(args mock.Arguments) {
|
||||
atomic.AddUint32(&retries, 1)
|
||||
})
|
||||
|
||||
// First good fetch to populate catch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Fetch should not block and should return error
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Key: "hello", MinIndex: 7, Timeout: 1 * time.Minute}))
|
||||
TestCacheGetChResult(t, resultCh, nil)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check the number
|
||||
actual := atomic.LoadUint32(&retries)
|
||||
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
|
||||
}
|
||||
|
||||
// Test that if a Type returns an empty value on Fetch that the previous
|
||||
// value is preserved.
|
||||
func TestCacheGet_emptyFetchResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, nil)
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42, Index: 1}, nil).Times(1)
|
||||
typ.Static(FetchResult{Value: nil}, nil)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should not fetch since we already have a satisfying value
|
||||
req = TestRequest(t, RequestInfo{
|
||||
Key: "hello", MinIndex: 1, Timeout: 100 * time.Millisecond})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test that a type registered with a periodic refresh will perform
|
||||
// that refresh after the timer is up.
|
||||
func TestCacheGet_periodicRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 100 * time.Millisecond,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// This is a bit weird, but we do this to ensure that the final
|
||||
// call to the Fetch (if it happens, depends on timing) just blocks.
|
||||
triggerCh := make(chan time.Time)
|
||||
defer close(triggerCh)
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).WaitUntil(triggerCh)
|
||||
|
||||
// Fetch should block
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Fetch again almost immediately should return old result
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Wait for the timer
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 12)
|
||||
}
|
||||
|
||||
// Test that a type registered with a periodic refresh will perform
|
||||
// that refresh after the timer is up.
|
||||
func TestCacheGet_periodicRefreshMultiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 0 * time.Millisecond,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// This is a bit weird, but we do this to ensure that the final
|
||||
// call to the Fetch (if it happens, depends on timing) just blocks.
|
||||
trigger := make([]chan time.Time, 3)
|
||||
for i := range trigger {
|
||||
trigger[i] = make(chan time.Time)
|
||||
}
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: 12, Index: 5}, nil).Once().WaitUntil(trigger[0])
|
||||
typ.Static(FetchResult{Value: 24, Index: 6}, nil).Once().WaitUntil(trigger[1])
|
||||
typ.Static(FetchResult{Value: 42, Index: 7}, nil).WaitUntil(trigger[2])
|
||||
|
||||
// Fetch should block
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Fetch again almost immediately should return old result
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Trigger the next, sleep a bit, and verify we get the next result
|
||||
close(trigger[0])
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 12)
|
||||
|
||||
// Trigger the next, sleep a bit, and verify we get the next result
|
||||
close(trigger[1])
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 24)
|
||||
}
|
||||
|
||||
// Test that a refresh performs a backoff.
|
||||
func TestCacheGet_periodicRefreshErrorBackoff(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 0,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
var retries uint32
|
||||
fetchErr := fmt.Errorf("test fetch error")
|
||||
typ.Static(FetchResult{Value: 1, Index: 4}, nil).Once()
|
||||
typ.Static(FetchResult{Value: nil, Index: 5}, fetchErr).Run(func(args mock.Arguments) {
|
||||
atomic.AddUint32(&retries, 1)
|
||||
})
|
||||
|
||||
// Fetch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Sleep a bit. The refresh will quietly fail in the background. What we
|
||||
// want to verify is that it doesn't retry too much. "Too much" is hard
|
||||
// to measure since its CPU dependent if this test is failing. But due
|
||||
// to the short sleep below, we can calculate about what we'd expect if
|
||||
// backoff IS working.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Fetch should work, we should get a 1 still. Errors are ignored.
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 1)
|
||||
|
||||
// Check the number
|
||||
actual := atomic.LoadUint32(&retries)
|
||||
require.True(t, actual < 10, fmt.Sprintf("actual: %d", actual))
|
||||
}
|
||||
|
||||
// Test that a badly behaved RPC that returns 0 index will perform a backoff.
|
||||
func TestCacheGet_periodicRefreshBadRPCZeroIndexErrorBackoff(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 0,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
var retries uint32
|
||||
typ.Static(FetchResult{Value: 0, Index: 0}, nil).Run(func(args mock.Arguments) {
|
||||
atomic.AddUint32(&retries, 1)
|
||||
})
|
||||
|
||||
// Fetch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 0)
|
||||
|
||||
// Sleep a bit. The refresh will quietly fail in the background. What we
|
||||
// want to verify is that it doesn't retry too much. "Too much" is hard
|
||||
// to measure since its CPU dependent if this test is failing. But due
|
||||
// to the short sleep below, we can calculate about what we'd expect if
|
||||
// backoff IS working.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Fetch should work, we should get a 0 still. Errors are ignored.
|
||||
resultCh = TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 0)
|
||||
|
||||
// Check the number
|
||||
actual := atomic.LoadUint32(&retries)
|
||||
require.True(t, actual < 10, fmt.Sprintf("%d retries, should be < 10", actual))
|
||||
}
|
||||
|
||||
// Test that fetching with no index makes an initial request with no index, but
|
||||
// then ensures all background refreshes have > 0. This ensures we don't end up
|
||||
// with any index 0 loops from background refreshed while also returning
|
||||
// immediately on the initial request if there is no data written to that table
|
||||
// yet.
|
||||
func TestCacheGet_noIndexSetsOne(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
Refresh: true,
|
||||
RefreshTimer: 0,
|
||||
RefreshTimeout: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Simulate "well behaved" RPC with no data yet but returning 1
|
||||
{
|
||||
first := int32(1)
|
||||
|
||||
typ.Static(FetchResult{Value: 0, Index: 1}, nil).Run(func(args mock.Arguments) {
|
||||
opts := args.Get(0).(FetchOptions)
|
||||
isFirst := atomic.SwapInt32(&first, 0)
|
||||
if isFirst == 1 {
|
||||
assert.Equal(t, uint64(0), opts.MinIndex)
|
||||
} else {
|
||||
assert.True(t, opts.MinIndex > 0, "minIndex > 0")
|
||||
}
|
||||
})
|
||||
|
||||
// Fetch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 0)
|
||||
|
||||
// Sleep a bit so background refresh happens
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Same for "badly behaved" RPC that returns 0 index and no data
|
||||
{
|
||||
first := int32(1)
|
||||
|
||||
typ.Static(FetchResult{Value: 0, Index: 0}, nil).Run(func(args mock.Arguments) {
|
||||
opts := args.Get(0).(FetchOptions)
|
||||
isFirst := atomic.SwapInt32(&first, 0)
|
||||
if isFirst == 1 {
|
||||
assert.Equal(t, uint64(0), opts.MinIndex)
|
||||
} else {
|
||||
assert.True(t, opts.MinIndex > 0, "minIndex > 0")
|
||||
}
|
||||
})
|
||||
|
||||
// Fetch
|
||||
resultCh := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{Key: "hello"}))
|
||||
TestCacheGetChResult(t, resultCh, 0)
|
||||
|
||||
// Sleep a bit so background refresh happens
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the backend fetch sets the proper timeout.
|
||||
func TestCacheGet_fetchTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
|
||||
// Register the type with a timeout
|
||||
timeout := 10 * time.Minute
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
RefreshTimeout: timeout,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
var actual time.Duration
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(1).Run(func(args mock.Arguments) {
|
||||
opts := args.Get(0).(FetchOptions)
|
||||
actual = opts.Timeout
|
||||
})
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Test the timeout
|
||||
require.Equal(timeout, actual)
|
||||
}
|
||||
|
||||
// Test that entries expire
|
||||
func TestCacheGet_expire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
|
||||
// Register the type with a timeout
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
LastGetTTL: 400 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(2)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Get, should not fetch, verified via the mock assertions above
|
||||
req = TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.True(meta.Hit)
|
||||
|
||||
// Sleep for the expiry
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Get, should fetch
|
||||
req = TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test that entries reset their TTL on Get
|
||||
func TestCacheGet_expireResetGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
typ := TestType(t)
|
||||
defer typ.AssertExpectations(t)
|
||||
c := TestCache(t)
|
||||
|
||||
// Register the type with a timeout
|
||||
c.RegisterType("t", typ, &RegisterOptions{
|
||||
LastGetTTL: 150 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Configure the type
|
||||
typ.Static(FetchResult{Value: 42}, nil).Times(2)
|
||||
|
||||
// Get, should fetch
|
||||
req := TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err := c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Fetch multiple times, where the total time is well beyond
|
||||
// the TTL. We should not trigger any fetches during this time.
|
||||
for i := 0; i < 5; i++ {
|
||||
// Sleep a bit
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Get, should not fetch
|
||||
req = TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.True(meta.Hit)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Get, should fetch
|
||||
req = TestRequest(t, RequestInfo{Key: "hello"})
|
||||
result, meta, err = c.Get("t", req)
|
||||
require.NoError(err)
|
||||
require.Equal(42, result)
|
||||
require.False(meta.Hit)
|
||||
|
||||
// Sleep a tiny bit just to let maybe some background calls happen
|
||||
// then verify that we still only got the one call
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
typ.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Test that Get partitions the caches based on DC so two equivalent requests
|
||||
// to different datacenters are automatically cached even if their keys are
|
||||
// the same.
|
||||
func TestCacheGet_partitionDC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", &testPartitionType{}, nil)
|
||||
|
||||
// Perform multiple gets
|
||||
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Datacenter: "dc1", Key: "hello"}))
|
||||
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Datacenter: "dc9", Key: "hello"}))
|
||||
|
||||
// Should return both!
|
||||
TestCacheGetChResult(t, getCh1, "dc1")
|
||||
TestCacheGetChResult(t, getCh2, "dc9")
|
||||
}
|
||||
|
||||
// Test that Get partitions the caches based on token so two equivalent requests
|
||||
// with different ACL tokens do not return the same result.
|
||||
func TestCacheGet_partitionToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := TestCache(t)
|
||||
c.RegisterType("t", &testPartitionType{}, nil)
|
||||
|
||||
// Perform multiple gets
|
||||
getCh1 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Token: "", Key: "hello"}))
|
||||
getCh2 := TestCacheGetCh(t, c, "t", TestRequest(t, RequestInfo{
|
||||
Token: "foo", Key: "hello"}))
|
||||
|
||||
// Should return both!
|
||||
TestCacheGetChResult(t, getCh1, "")
|
||||
TestCacheGetChResult(t, getCh2, "foo")
|
||||
}
|
||||
|
||||
// testPartitionType implements Type for testing that simply returns a value
|
||||
// comprised of the request DC and ACL token, used for testing cache
|
||||
// partitioning.
|
||||
type testPartitionType struct{}
|
||||
|
||||
func (t *testPartitionType) Fetch(opts FetchOptions, r Request) (FetchResult, error) {
|
||||
info := r.CacheInfo()
|
||||
return FetchResult{
|
||||
Value: fmt.Sprintf("%s%s", info.Datacenter, info.Token),
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cacheEntry stores a single cache entry.
|
||||
//
|
||||
// Note that this isn't a very optimized structure currently. There are
|
||||
// a lot of improvements that can be made here in the long term.
|
||||
type cacheEntry struct {
|
||||
// Fields pertaining to the actual value
|
||||
Value interface{}
|
||||
Error error
|
||||
Index uint64
|
||||
|
||||
// Metadata that is used for internal accounting
|
||||
Valid bool // True if the Value is set
|
||||
Fetching bool // True if a fetch is already active
|
||||
Waiter chan struct{} // Closed when this entry is invalidated
|
||||
|
||||
// Expiry contains information about the expiration of this
|
||||
// entry. This is a pointer as its shared as a value in the
|
||||
// expiryHeap as well.
|
||||
Expiry *cacheEntryExpiry
|
||||
}
|
||||
|
||||
// cacheEntryExpiry contains the expiration information for a cache
|
||||
// entry. Any modifications to this struct should be done only while
|
||||
// the Cache entriesLock is held.
|
||||
type cacheEntryExpiry struct {
|
||||
Key string // Key in the cache map
|
||||
Expires time.Time // Time when entry expires (monotonic clock)
|
||||
TTL time.Duration // TTL for this entry to extend when resetting
|
||||
HeapIndex int // Index in the heap
|
||||
}
|
||||
|
||||
// Reset resets the expiration to be the ttl duration from now.
|
||||
func (e *cacheEntryExpiry) Reset() {
|
||||
e.Expires = time.Now().Add(e.TTL)
|
||||
}
|
||||
|
||||
// expiryHeap is a heap implementation that stores information about
|
||||
// when entires expire. Implements container/heap.Interface.
|
||||
//
|
||||
// All operations on the heap and read/write of the heap contents require
|
||||
// the proper entriesLock to be held on Cache.
|
||||
type expiryHeap struct {
|
||||
Entries []*cacheEntryExpiry
|
||||
|
||||
// NotifyCh is sent a value whenever the 0 index value of the heap
|
||||
// changes. This can be used to detect when the earliest value
|
||||
// changes.
|
||||
//
|
||||
// There is a single edge case where the heap will not automatically
|
||||
// send a notification: if heap.Fix is called manually and the index
|
||||
// changed is 0 and the change doesn't result in any moves (stays at index
|
||||
// 0), then we won't detect the change. To work around this, please
|
||||
// always call the expiryHeap.Fix method instead.
|
||||
NotifyCh chan struct{}
|
||||
}
|
||||
|
||||
// Identical to heap.Fix for this heap instance but will properly handle
|
||||
// the edge case where idx == 0 and no heap modification is necessary,
|
||||
// and still notify the NotifyCh.
|
||||
//
|
||||
// This is important for cache expiry since the expiry time may have been
|
||||
// extended and if we don't send a message to the NotifyCh then we'll never
|
||||
// reset the timer and the entry will be evicted early.
|
||||
func (h *expiryHeap) Fix(entry *cacheEntryExpiry) {
|
||||
idx := entry.HeapIndex
|
||||
heap.Fix(h, idx)
|
||||
|
||||
// This is the edge case we handle: if the prev (idx) and current (HeapIndex)
|
||||
// is zero, it means the head-of-line didn't change while the value
|
||||
// changed. Notify to reset our expiry worker.
|
||||
if idx == 0 && entry.HeapIndex == 0 {
|
||||
h.notify()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *expiryHeap) Len() int { return len(h.Entries) }
|
||||
|
||||
func (h *expiryHeap) Swap(i, j int) {
|
||||
h.Entries[i], h.Entries[j] = h.Entries[j], h.Entries[i]
|
||||
h.Entries[i].HeapIndex = i
|
||||
h.Entries[j].HeapIndex = j
|
||||
|
||||
// If we're moving the 0 index, update the channel since we need
|
||||
// to re-update the timer we're waiting on for the soonest expiring
|
||||
// value.
|
||||
if i == 0 || j == 0 {
|
||||
h.notify()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *expiryHeap) Less(i, j int) bool {
|
||||
// The usage of Before here is important (despite being obvious):
|
||||
// this function uses the monotonic time that should be available
|
||||
// on the time.Time value so the heap is immune to wall clock changes.
|
||||
return h.Entries[i].Expires.Before(h.Entries[j].Expires)
|
||||
}
|
||||
|
||||
// heap.Interface, this isn't expected to be called directly.
|
||||
func (h *expiryHeap) Push(x interface{}) {
|
||||
entry := x.(*cacheEntryExpiry)
|
||||
|
||||
// Set initial heap index, if we're going to the end then Swap
|
||||
// won't be called so we need to initialize
|
||||
entry.HeapIndex = len(h.Entries)
|
||||
|
||||
// For the first entry, we need to trigger a channel send because
|
||||
// Swap won't be called; nothing to swap! We can call it right away
|
||||
// because all heap operations are within a lock.
|
||||
if len(h.Entries) == 0 {
|
||||
h.notify()
|
||||
}
|
||||
|
||||
h.Entries = append(h.Entries, entry)
|
||||
}
|
||||
|
||||
// heap.Interface, this isn't expected to be called directly.
|
||||
func (h *expiryHeap) Pop() interface{} {
|
||||
old := h.Entries
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
h.Entries = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
func (h *expiryHeap) notify() {
|
||||
select {
|
||||
case h.NotifyCh <- struct{}{}:
|
||||
// Good
|
||||
|
||||
default:
|
||||
// If the send would've blocked, we just ignore it. The reason this
|
||||
// is safe is because NotifyCh should always be a buffered channel.
|
||||
// If this blocks, it means that there is a pending message anyways
|
||||
// so the receiver will restart regardless.
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExpiryHeap_impl(t *testing.T) {
|
||||
var _ heap.Interface = new(expiryHeap)
|
||||
}
|
||||
|
||||
func TestExpiryHeap(t *testing.T) {
|
||||
require := require.New(t)
|
||||
now := time.Now()
|
||||
ch := make(chan struct{}, 10) // buffered to prevent blocking in tests
|
||||
h := &expiryHeap{NotifyCh: ch}
|
||||
|
||||
// Init, shouldn't trigger anything
|
||||
heap.Init(h)
|
||||
testNoMessage(t, ch)
|
||||
|
||||
// Push an initial value, expect one message
|
||||
entry := &cacheEntryExpiry{Key: "foo", HeapIndex: -1, Expires: now.Add(100)}
|
||||
heap.Push(h, entry)
|
||||
require.Equal(0, entry.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testNoMessage(t, ch) // exactly one asserted above
|
||||
|
||||
// Push another that goes earlier than entry
|
||||
entry2 := &cacheEntryExpiry{Key: "bar", HeapIndex: -1, Expires: now.Add(50)}
|
||||
heap.Push(h, entry2)
|
||||
require.Equal(0, entry2.HeapIndex)
|
||||
require.Equal(1, entry.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testNoMessage(t, ch) // exactly one asserted above
|
||||
|
||||
// Push another that goes at the end
|
||||
entry3 := &cacheEntryExpiry{Key: "bar", HeapIndex: -1, Expires: now.Add(1000)}
|
||||
heap.Push(h, entry3)
|
||||
require.Equal(2, entry3.HeapIndex)
|
||||
testNoMessage(t, ch) // no notify cause index 0 stayed the same
|
||||
|
||||
// Remove the first entry (not Pop, since we don't use Pop, but that works too)
|
||||
remove := h.Entries[0]
|
||||
heap.Remove(h, remove.HeapIndex)
|
||||
require.Equal(0, entry.HeapIndex)
|
||||
require.Equal(1, entry3.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testMessage(t, ch) // we have two because two swaps happen
|
||||
testNoMessage(t, ch)
|
||||
|
||||
// Let's change entry 3 to be early, and fix it
|
||||
entry3.Expires = now.Add(10)
|
||||
h.Fix(entry3)
|
||||
require.Equal(1, entry.HeapIndex)
|
||||
require.Equal(0, entry3.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testNoMessage(t, ch)
|
||||
|
||||
// Let's change entry 3 again, this is an edge case where if the 0th
|
||||
// element changed, we didn't trigger the channel. Our Fix func should.
|
||||
entry.Expires = now.Add(20)
|
||||
h.Fix(entry3)
|
||||
require.Equal(1, entry.HeapIndex) // no move
|
||||
require.Equal(0, entry3.HeapIndex)
|
||||
testMessage(t, ch)
|
||||
testNoMessage(t, ch) // one message
|
||||
}
|
||||
|
||||
func testNoMessage(t *testing.T, ch <-chan struct{}) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("should not have a message")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func testMessage(t *testing.T, ch <-chan struct{}) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
t.Fatal("should have a message")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Code generated by mockery v1.0.0
|
||||
package cache
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockRequest is an autogenerated mock type for the Request type
|
||||
type MockRequest struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// CacheInfo provides a mock function with given fields:
|
||||
func (_m *MockRequest) CacheInfo() RequestInfo {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 RequestInfo
|
||||
if rf, ok := ret.Get(0).(func() RequestInfo); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(RequestInfo)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
// Code generated by mockery v1.0.0
|
||||
package cache
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockType is an autogenerated mock type for the Type type
|
||||
type MockType struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Fetch provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockType) Fetch(_a0 FetchOptions, _a1 Request) (FetchResult, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 FetchResult
|
||||
if rf, ok := ret.Get(0).(func(FetchOptions, Request) FetchResult); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
r0 = ret.Get(0).(FetchResult)
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(FetchOptions, Request) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Request is a cacheable request.
|
||||
//
|
||||
// This interface is typically implemented by request structures in
|
||||
// the agent/structs package.
|
||||
type Request interface {
|
||||
// CacheInfo returns information used for caching this request.
|
||||
CacheInfo() RequestInfo
|
||||
}
|
||||
|
||||
// RequestInfo represents cache information for a request. The caching
|
||||
// framework uses this to control the behavior of caching and to determine
|
||||
// cacheability.
|
||||
type RequestInfo struct {
|
||||
// Key is a unique cache key for this request. This key should
|
||||
// be globally unique to identify this request, since any conflicting
|
||||
// cache keys could result in invalid data being returned from the cache.
|
||||
// The Key does not need to include ACL or DC information, since the
|
||||
// cache already partitions by these values prior to using this key.
|
||||
Key string
|
||||
|
||||
// Token is the ACL token associated with this request.
|
||||
//
|
||||
// Datacenter is the datacenter that the request is targeting.
|
||||
//
|
||||
// Both of these values are used to partition the cache. The cache framework
|
||||
// today partitions data on these values to simplify behavior: by
|
||||
// partitioning ACL tokens, the cache doesn't need to be smart about
|
||||
// filtering results. By filtering datacenter results, the cache can
|
||||
// service the multi-DC nature of Consul. This comes at the expense of
|
||||
// working set size, but in general the effect is minimal.
|
||||
Token string
|
||||
Datacenter string
|
||||
|
||||
// MinIndex is the minimum index being queried. This is used to
|
||||
// determine if we already have data satisfying the query or if we need
|
||||
// to block until new data is available. If no index is available, the
|
||||
// default value (zero) is acceptable.
|
||||
MinIndex uint64
|
||||
|
||||
// Timeout is the timeout for waiting on a blocking query. When the
|
||||
// timeout is reached, the last known value is returned (or maybe nil
|
||||
// if there was no prior value). This "last known value" behavior matches
|
||||
// normal Consul blocking queries.
|
||||
Timeout time.Duration
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// TestCache returns a Cache instance configuring for testing.
|
||||
func TestCache(t testing.T) *Cache {
|
||||
// Simple but lets us do some fine-tuning later if we want to.
|
||||
return New(nil)
|
||||
}
|
||||
|
||||
// TestCacheGetCh returns a channel that returns the result of the Get call.
|
||||
// This is useful for testing timing and concurrency with Get calls. Any
|
||||
// error will be logged, so the result value should always be asserted.
|
||||
func TestCacheGetCh(t testing.T, c *Cache, typ string, r Request) <-chan interface{} {
|
||||
resultCh := make(chan interface{})
|
||||
go func() {
|
||||
result, _, err := c.Get(typ, r)
|
||||
if err != nil {
|
||||
t.Logf("Error: %s", err)
|
||||
close(resultCh)
|
||||
return
|
||||
}
|
||||
|
||||
resultCh <- result
|
||||
}()
|
||||
|
||||
return resultCh
|
||||
}
|
||||
|
||||
// TestCacheGetChResult tests that the result from TestCacheGetCh matches
|
||||
// within a reasonable period of time (it expects it to be "immediate" but
|
||||
// waits some milliseconds).
|
||||
func TestCacheGetChResult(t testing.T, ch <-chan interface{}, expected interface{}) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case result := <-ch:
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Fatalf("Result doesn't match!\n\n%#v\n\n%#v", result, expected)
|
||||
}
|
||||
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
t.Fatalf("Result not sent on channel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequest returns a Request that returns the given cache key and index.
|
||||
// The Reset method can be called to reset it for custom usage.
|
||||
func TestRequest(t testing.T, info RequestInfo) *MockRequest {
|
||||
req := &MockRequest{}
|
||||
req.On("CacheInfo").Return(info)
|
||||
return req
|
||||
}
|
||||
|
||||
// TestType returns a MockType that can be used to setup expectations
|
||||
// on data fetching.
|
||||
func TestType(t testing.T) *MockType {
|
||||
typ := &MockType{}
|
||||
return typ
|
||||
}
|
||||
|
||||
// A bit weird, but we add methods to the auto-generated structs here so that
|
||||
// they don't get clobbered. The helper methods are conveniences.
|
||||
|
||||
// Static sets a static value to return for a call to Fetch.
|
||||
func (m *MockType) Static(r FetchResult, err error) *mock.Call {
|
||||
return m.Mock.On("Fetch", mock.Anything, mock.Anything).Return(r, err)
|
||||
}
|
||||
|
||||
func (m *MockRequest) Reset() {
|
||||
m.Mock = mock.Mock{}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Type implements the logic to fetch certain types of data.
|
||||
type Type interface {
|
||||
// Fetch fetches a single unique item.
|
||||
//
|
||||
// The FetchOptions contain the index and timeouts for blocking queries.
|
||||
// The MinIndex value on the Request itself should NOT be used
|
||||
// as the blocking index since a request may be reused multiple times
|
||||
// as part of Refresh behavior.
|
||||
//
|
||||
// The return value is a FetchResult which contains information about
|
||||
// the fetch. If an error is given, the FetchResult is ignored. The
|
||||
// cache does not support backends that return partial values.
|
||||
//
|
||||
// On timeout, FetchResult can behave one of two ways. First, it can
|
||||
// return the last known value. This is the default behavior of blocking
|
||||
// RPC calls in Consul so this allows cache types to be implemented with
|
||||
// no extra logic. Second, FetchResult can return an unset value and index.
|
||||
// In this case, the cache will reuse the last value automatically.
|
||||
Fetch(FetchOptions, Request) (FetchResult, error)
|
||||
}
|
||||
|
||||
// FetchOptions are various settable options when a Fetch is called.
|
||||
type FetchOptions struct {
|
||||
// MinIndex is the minimum index to be used for blocking queries.
|
||||
// If blocking queries aren't supported for data being returned,
|
||||
// this value can be ignored.
|
||||
MinIndex uint64
|
||||
|
||||
// Timeout is the maximum time for the query. This must be implemented
|
||||
// in the Fetch itself.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// FetchResult is the result of a Type Fetch operation and contains the
|
||||
// data along with metadata gathered from that operation.
|
||||
type FetchResult struct {
|
||||
// Value is the result of the fetch.
|
||||
Value interface{}
|
||||
|
||||
// Index is the corresponding index value for this data.
|
||||
Index uint64
|
||||
}
|
|
@ -157,12 +157,27 @@ RETRY_ONCE:
|
|||
return out.Services, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) CatalogConnectServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
return s.catalogServiceNodes(resp, req, true)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_service_nodes"}, 1,
|
||||
return s.catalogServiceNodes(resp, req, false)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) catalogServiceNodes(resp http.ResponseWriter, req *http.Request, connect bool) (interface{}, error) {
|
||||
metricsKey := "catalog_service_nodes"
|
||||
pathPrefix := "/v1/catalog/service/"
|
||||
if connect {
|
||||
metricsKey = "catalog_connect_service_nodes"
|
||||
pathPrefix = "/v1/catalog/connect/"
|
||||
}
|
||||
|
||||
metrics.IncrCounterWithLabels([]string{"client", "api", metricsKey}, 1,
|
||||
[]metrics.Label{{Name: "node", Value: s.nodeName()}})
|
||||
|
||||
// Set default DC
|
||||
args := structs.ServiceSpecificRequest{}
|
||||
args := structs.ServiceSpecificRequest{Connect: connect}
|
||||
s.parseSource(req, &args.Source)
|
||||
args.NodeMetaFilters = s.parseMetaFilter(req)
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
|
@ -177,7 +192,7 @@ func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Req
|
|||
}
|
||||
|
||||
// Pull out the service name
|
||||
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/catalog/service/")
|
||||
args.ServiceName = strings.TrimPrefix(req.URL.Path, pathPrefix)
|
||||
if args.ServiceName == "" {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Missing service name")
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCatalogRegister_Service_InvalidAddress(t *testing.T) {
|
||||
|
@ -750,6 +751,60 @@ func TestCatalogServiceNodes_DistanceSort(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test that connect proxies can be queried via /v1/catalog/service/:service
|
||||
// directly and that their results contain the proxy fields.
|
||||
func TestCatalogServiceNodes_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/catalog/service/%s", args.Service.Service), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.CatalogServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
nodes := obj.(structs.ServiceNodes)
|
||||
assert.Len(nodes, 1)
|
||||
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
|
||||
}
|
||||
|
||||
// Test that the Connect-compatible endpoints can be queried for a
|
||||
// service via /v1/catalog/connect/:service.
|
||||
func TestCatalogConnectServiceNodes_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Address = "127.0.0.55"
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/catalog/connect/%s", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.CatalogConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
nodes := obj.(structs.ServiceNodes)
|
||||
assert.Len(nodes, 1)
|
||||
assert.Equal(structs.ServiceKindConnectProxy, nodes[0].ServiceKind)
|
||||
assert.Equal(args.Service.Address, nodes[0].ServiceAddress)
|
||||
}
|
||||
|
||||
func TestCatalogNodeServices(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
|
@ -785,6 +840,33 @@ func TestCatalogNodeServices(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test that the services on a node contain all the Connect proxies on
|
||||
// the node as well with their fields properly populated.
|
||||
func TestCatalogNodeServices_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/catalog/node/%s", args.Node), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.CatalogNodeServices(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
ns := obj.(*structs.NodeServices)
|
||||
assert.Len(ns.Services, 1)
|
||||
v := ns.Services[args.Service.Service]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
|
||||
}
|
||||
|
||||
func TestCatalogNodeServices_WanTranslation(t *testing.T) {
|
||||
t.Parallel()
|
||||
a1 := NewTestAgent(t.Name(), `
|
||||
|
|
|
@ -14,9 +14,11 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/consul"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/ipaddr"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
|
@ -340,6 +342,12 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
|||
serverPort := b.portVal("ports.server", c.Ports.Server)
|
||||
serfPortLAN := b.portVal("ports.serf_lan", c.Ports.SerfLAN)
|
||||
serfPortWAN := b.portVal("ports.serf_wan", c.Ports.SerfWAN)
|
||||
proxyMinPort := b.portVal("ports.proxy_min_port", c.Ports.ProxyMinPort)
|
||||
proxyMaxPort := b.portVal("ports.proxy_max_port", c.Ports.ProxyMaxPort)
|
||||
if proxyMaxPort < proxyMinPort {
|
||||
return RuntimeConfig{}, fmt.Errorf(
|
||||
"proxy_min_port must be less than proxy_max_port. To disable, set both to zero.")
|
||||
}
|
||||
|
||||
// determine the default bind and advertise address
|
||||
//
|
||||
|
@ -520,6 +528,30 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
|||
consulRaftHeartbeatTimeout := b.durationVal("consul.raft.heartbeat_timeout", c.Consul.Raft.HeartbeatTimeout) * time.Duration(performanceRaftMultiplier)
|
||||
consulRaftLeaderLeaseTimeout := b.durationVal("consul.raft.leader_lease_timeout", c.Consul.Raft.LeaderLeaseTimeout) * time.Duration(performanceRaftMultiplier)
|
||||
|
||||
// Connect proxy defaults.
|
||||
connectEnabled := b.boolVal(c.Connect.Enabled)
|
||||
connectCAProvider := b.stringVal(c.Connect.CAProvider)
|
||||
connectCAConfig := c.Connect.CAConfig
|
||||
if connectCAConfig != nil {
|
||||
TranslateKeys(connectCAConfig, map[string]string{
|
||||
// Consul CA config
|
||||
"private_key": "PrivateKey",
|
||||
"root_cert": "RootCert",
|
||||
"rotation_period": "RotationPeriod",
|
||||
|
||||
// Vault CA config
|
||||
"address": "Address",
|
||||
"token": "Token",
|
||||
"root_pki_path": "RootPKIPath",
|
||||
"intermediate_pki_path": "IntermediatePKIPath",
|
||||
})
|
||||
}
|
||||
|
||||
proxyDefaultExecMode := b.stringVal(c.Connect.ProxyDefaults.ExecMode)
|
||||
proxyDefaultDaemonCommand := c.Connect.ProxyDefaults.DaemonCommand
|
||||
proxyDefaultScriptCommand := c.Connect.ProxyDefaults.ScriptCommand
|
||||
proxyDefaultConfig := c.Connect.ProxyDefaults.Config
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
// build runtime config
|
||||
//
|
||||
|
@ -603,29 +635,31 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
|||
HTTPResponseHeaders: c.HTTPConfig.ResponseHeaders,
|
||||
|
||||
// Telemetry
|
||||
TelemetryCirconusAPIApp: b.stringVal(c.Telemetry.CirconusAPIApp),
|
||||
TelemetryCirconusAPIToken: b.stringVal(c.Telemetry.CirconusAPIToken),
|
||||
TelemetryCirconusAPIURL: b.stringVal(c.Telemetry.CirconusAPIURL),
|
||||
TelemetryCirconusBrokerID: b.stringVal(c.Telemetry.CirconusBrokerID),
|
||||
TelemetryCirconusBrokerSelectTag: b.stringVal(c.Telemetry.CirconusBrokerSelectTag),
|
||||
TelemetryCirconusCheckDisplayName: b.stringVal(c.Telemetry.CirconusCheckDisplayName),
|
||||
TelemetryCirconusCheckForceMetricActivation: b.stringVal(c.Telemetry.CirconusCheckForceMetricActivation),
|
||||
TelemetryCirconusCheckID: b.stringVal(c.Telemetry.CirconusCheckID),
|
||||
TelemetryCirconusCheckInstanceID: b.stringVal(c.Telemetry.CirconusCheckInstanceID),
|
||||
TelemetryCirconusCheckSearchTag: b.stringVal(c.Telemetry.CirconusCheckSearchTag),
|
||||
TelemetryCirconusCheckTags: b.stringVal(c.Telemetry.CirconusCheckTags),
|
||||
TelemetryCirconusSubmissionInterval: b.stringVal(c.Telemetry.CirconusSubmissionInterval),
|
||||
TelemetryCirconusSubmissionURL: b.stringVal(c.Telemetry.CirconusSubmissionURL),
|
||||
TelemetryDisableHostname: b.boolVal(c.Telemetry.DisableHostname),
|
||||
TelemetryDogstatsdAddr: b.stringVal(c.Telemetry.DogstatsdAddr),
|
||||
TelemetryDogstatsdTags: c.Telemetry.DogstatsdTags,
|
||||
TelemetryPrometheusRetentionTime: b.durationVal("prometheus_retention_time", c.Telemetry.PrometheusRetentionTime),
|
||||
TelemetryFilterDefault: b.boolVal(c.Telemetry.FilterDefault),
|
||||
TelemetryAllowedPrefixes: telemetryAllowedPrefixes,
|
||||
TelemetryBlockedPrefixes: telemetryBlockedPrefixes,
|
||||
TelemetryMetricsPrefix: b.stringVal(c.Telemetry.MetricsPrefix),
|
||||
TelemetryStatsdAddr: b.stringVal(c.Telemetry.StatsdAddr),
|
||||
TelemetryStatsiteAddr: b.stringVal(c.Telemetry.StatsiteAddr),
|
||||
Telemetry: lib.TelemetryConfig{
|
||||
CirconusAPIApp: b.stringVal(c.Telemetry.CirconusAPIApp),
|
||||
CirconusAPIToken: b.stringVal(c.Telemetry.CirconusAPIToken),
|
||||
CirconusAPIURL: b.stringVal(c.Telemetry.CirconusAPIURL),
|
||||
CirconusBrokerID: b.stringVal(c.Telemetry.CirconusBrokerID),
|
||||
CirconusBrokerSelectTag: b.stringVal(c.Telemetry.CirconusBrokerSelectTag),
|
||||
CirconusCheckDisplayName: b.stringVal(c.Telemetry.CirconusCheckDisplayName),
|
||||
CirconusCheckForceMetricActivation: b.stringVal(c.Telemetry.CirconusCheckForceMetricActivation),
|
||||
CirconusCheckID: b.stringVal(c.Telemetry.CirconusCheckID),
|
||||
CirconusCheckInstanceID: b.stringVal(c.Telemetry.CirconusCheckInstanceID),
|
||||
CirconusCheckSearchTag: b.stringVal(c.Telemetry.CirconusCheckSearchTag),
|
||||
CirconusCheckTags: b.stringVal(c.Telemetry.CirconusCheckTags),
|
||||
CirconusSubmissionInterval: b.stringVal(c.Telemetry.CirconusSubmissionInterval),
|
||||
CirconusSubmissionURL: b.stringVal(c.Telemetry.CirconusSubmissionURL),
|
||||
DisableHostname: b.boolVal(c.Telemetry.DisableHostname),
|
||||
DogstatsdAddr: b.stringVal(c.Telemetry.DogstatsdAddr),
|
||||
DogstatsdTags: c.Telemetry.DogstatsdTags,
|
||||
PrometheusRetentionTime: b.durationVal("prometheus_retention_time", c.Telemetry.PrometheusRetentionTime),
|
||||
FilterDefault: b.boolVal(c.Telemetry.FilterDefault),
|
||||
AllowedPrefixes: telemetryAllowedPrefixes,
|
||||
BlockedPrefixes: telemetryBlockedPrefixes,
|
||||
MetricsPrefix: b.stringVal(c.Telemetry.MetricsPrefix),
|
||||
StatsdAddr: b.stringVal(c.Telemetry.StatsdAddr),
|
||||
StatsiteAddr: b.stringVal(c.Telemetry.StatsiteAddr),
|
||||
},
|
||||
|
||||
// Agent
|
||||
AdvertiseAddrLAN: advertiseAddrLAN,
|
||||
|
@ -639,6 +673,17 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
|
|||
CheckUpdateInterval: b.durationVal("check_update_interval", c.CheckUpdateInterval),
|
||||
Checks: checks,
|
||||
ClientAddrs: clientAddrs,
|
||||
ConnectEnabled: connectEnabled,
|
||||
ConnectCAProvider: connectCAProvider,
|
||||
ConnectCAConfig: connectCAConfig,
|
||||
ConnectProxyAllowManagedRoot: b.boolVal(c.Connect.Proxy.AllowManagedRoot),
|
||||
ConnectProxyAllowManagedAPIRegistration: b.boolVal(c.Connect.Proxy.AllowManagedAPIRegistration),
|
||||
ConnectProxyBindMinPort: proxyMinPort,
|
||||
ConnectProxyBindMaxPort: proxyMaxPort,
|
||||
ConnectProxyDefaultExecMode: proxyDefaultExecMode,
|
||||
ConnectProxyDefaultDaemonCommand: proxyDefaultDaemonCommand,
|
||||
ConnectProxyDefaultScriptCommand: proxyDefaultScriptCommand,
|
||||
ConnectProxyDefaultConfig: proxyDefaultConfig,
|
||||
DataDir: b.stringVal(c.DataDir),
|
||||
Datacenter: strings.ToLower(b.stringVal(c.Datacenter)),
|
||||
DevMode: b.boolVal(b.Flags.DevMode),
|
||||
|
@ -882,6 +927,34 @@ func (b *Builder) Validate(rt RuntimeConfig) error {
|
|||
return b.err
|
||||
}
|
||||
|
||||
// Check for errors in the service definitions
|
||||
for _, s := range rt.Services {
|
||||
if err := s.Validate(); err != nil {
|
||||
return fmt.Errorf("service %q: %s", s.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the given Connect CA provider config
|
||||
validCAProviders := map[string]bool{
|
||||
"": true,
|
||||
structs.ConsulCAProvider: true,
|
||||
structs.VaultCAProvider: true,
|
||||
}
|
||||
if _, ok := validCAProviders[rt.ConnectCAProvider]; !ok {
|
||||
return fmt.Errorf("%s is not a valid CA provider", rt.ConnectCAProvider)
|
||||
} else {
|
||||
switch rt.ConnectCAProvider {
|
||||
case structs.ConsulCAProvider:
|
||||
if _, err := ca.ParseConsulCAConfig(rt.ConnectCAConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
case structs.VaultCAProvider:
|
||||
if _, err := ca.ParseVaultCAConfig(rt.ConnectCAConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
// warnings
|
||||
//
|
||||
|
@ -1008,6 +1081,26 @@ func (b *Builder) serviceVal(v *ServiceDefinition) *structs.ServiceDefinition {
|
|||
Token: b.stringVal(v.Token),
|
||||
EnableTagOverride: b.boolVal(v.EnableTagOverride),
|
||||
Checks: checks,
|
||||
Connect: b.serviceConnectVal(v.Connect),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Builder) serviceConnectVal(v *ServiceConnect) *structs.ServiceConnect {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var proxy *structs.ServiceDefinitionConnectProxy
|
||||
if v.Proxy != nil {
|
||||
proxy = &structs.ServiceDefinitionConnectProxy{
|
||||
ExecMode: b.stringVal(v.Proxy.ExecMode),
|
||||
Command: v.Proxy.Command,
|
||||
Config: v.Proxy.Config,
|
||||
}
|
||||
}
|
||||
|
||||
return &structs.ServiceConnect{
|
||||
Proxy: proxy,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -84,6 +84,7 @@ func Parse(data string, format string) (c Config, err error) {
|
|||
"services",
|
||||
"services.checks",
|
||||
"watches",
|
||||
"service.connect.proxy.config.upstreams",
|
||||
})
|
||||
|
||||
// There is a difference of representation of some fields depending on
|
||||
|
@ -159,6 +160,7 @@ type Config struct {
|
|||
CheckUpdateInterval *string `json:"check_update_interval,omitempty" hcl:"check_update_interval" mapstructure:"check_update_interval"`
|
||||
Checks []CheckDefinition `json:"checks,omitempty" hcl:"checks" mapstructure:"checks"`
|
||||
ClientAddr *string `json:"client_addr,omitempty" hcl:"client_addr" mapstructure:"client_addr"`
|
||||
Connect Connect `json:"connect,omitempty" hcl:"connect" mapstructure:"connect"`
|
||||
DNS DNS `json:"dns_config,omitempty" hcl:"dns_config" mapstructure:"dns_config"`
|
||||
DNSDomain *string `json:"domain,omitempty" hcl:"domain" mapstructure:"domain"`
|
||||
DNSRecursors []string `json:"recursors,omitempty" hcl:"recursors" mapstructure:"recursors"`
|
||||
|
@ -324,6 +326,7 @@ type ServiceDefinition struct {
|
|||
Checks []CheckDefinition `json:"checks,omitempty" hcl:"checks" mapstructure:"checks"`
|
||||
Token *string `json:"token,omitempty" hcl:"token" mapstructure:"token"`
|
||||
EnableTagOverride *bool `json:"enable_tag_override,omitempty" hcl:"enable_tag_override" mapstructure:"enable_tag_override"`
|
||||
Connect *ServiceConnect `json:"connect,omitempty" hcl:"connect" mapstructure:"connect"`
|
||||
}
|
||||
|
||||
type CheckDefinition struct {
|
||||
|
@ -349,6 +352,58 @@ type CheckDefinition struct {
|
|||
DeregisterCriticalServiceAfter *string `json:"deregister_critical_service_after,omitempty" hcl:"deregister_critical_service_after" mapstructure:"deregister_critical_service_after"`
|
||||
}
|
||||
|
||||
// ServiceConnect is the connect block within a service registration
|
||||
type ServiceConnect struct {
|
||||
// TODO(banks) add way to specify that the app is connect-native
|
||||
// Proxy configures a connect proxy instance for the service
|
||||
Proxy *ServiceConnectProxy `json:"proxy,omitempty" hcl:"proxy" mapstructure:"proxy"`
|
||||
}
|
||||
|
||||
type ServiceConnectProxy struct {
|
||||
Command []string `json:"command,omitempty" hcl:"command" mapstructure:"command"`
|
||||
ExecMode *string `json:"exec_mode,omitempty" hcl:"exec_mode" mapstructure:"exec_mode"`
|
||||
Config map[string]interface{} `json:"config,omitempty" hcl:"config" mapstructure:"config"`
|
||||
}
|
||||
|
||||
// Connect is the agent-global connect configuration.
|
||||
type Connect struct {
|
||||
// Enabled opts the agent into connect. It should be set on all clients and
|
||||
// servers in a cluster for correct connect operation.
|
||||
Enabled *bool `json:"enabled,omitempty" hcl:"enabled" mapstructure:"enabled"`
|
||||
Proxy ConnectProxy `json:"proxy,omitempty" hcl:"proxy" mapstructure:"proxy"`
|
||||
ProxyDefaults ConnectProxyDefaults `json:"proxy_defaults,omitempty" hcl:"proxy_defaults" mapstructure:"proxy_defaults"`
|
||||
CAProvider *string `json:"ca_provider,omitempty" hcl:"ca_provider" mapstructure:"ca_provider"`
|
||||
CAConfig map[string]interface{} `json:"ca_config,omitempty" hcl:"ca_config" mapstructure:"ca_config"`
|
||||
}
|
||||
|
||||
// ConnectProxy is the agent-global connect proxy configuration.
|
||||
type ConnectProxy struct {
|
||||
// Consul will not execute managed proxies if its EUID is 0 (root).
|
||||
// If this is true, then Consul will execute proxies if Consul is
|
||||
// running as root. This is not recommended.
|
||||
AllowManagedRoot *bool `json:"allow_managed_root" hcl:"allow_managed_root" mapstructure:"allow_managed_root"`
|
||||
|
||||
// AllowManagedAPIRegistration enables managed proxy registration
|
||||
// via the agent HTTP API. If this is false, only file configurations
|
||||
// can be used.
|
||||
AllowManagedAPIRegistration *bool `json:"allow_managed_api_registration" hcl:"allow_managed_api_registration" mapstructure:"allow_managed_api_registration"`
|
||||
}
|
||||
|
||||
// ConnectProxyDefaults is the agent-global defaults for managed Connect proxies.
|
||||
type ConnectProxyDefaults struct {
|
||||
// ExecMode is used where a registration doesn't include an exec_mode.
|
||||
// Defaults to daemon.
|
||||
ExecMode *string `json:"exec_mode,omitempty" hcl:"exec_mode" mapstructure:"exec_mode"`
|
||||
// DaemonCommand is used to start proxy in exec_mode = daemon if not specified
|
||||
// at registration time.
|
||||
DaemonCommand []string `json:"daemon_command,omitempty" hcl:"daemon_command" mapstructure:"daemon_command"`
|
||||
// ScriptCommand is used to start proxy in exec_mode = script if not specified
|
||||
// at registration time.
|
||||
ScriptCommand []string `json:"script_command,omitempty" hcl:"script_command" mapstructure:"script_command"`
|
||||
// Config is merged into an Config specified at registration time.
|
||||
Config map[string]interface{} `json:"config,omitempty" hcl:"config" mapstructure:"config"`
|
||||
}
|
||||
|
||||
type DNS struct {
|
||||
AllowStale *bool `json:"allow_stale,omitempty" hcl:"allow_stale" mapstructure:"allow_stale"`
|
||||
ARecordLimit *int `json:"a_record_limit,omitempty" hcl:"a_record_limit" mapstructure:"a_record_limit"`
|
||||
|
@ -406,6 +461,8 @@ type Ports struct {
|
|||
SerfLAN *int `json:"serf_lan,omitempty" hcl:"serf_lan" mapstructure:"serf_lan"`
|
||||
SerfWAN *int `json:"serf_wan,omitempty" hcl:"serf_wan" mapstructure:"serf_wan"`
|
||||
Server *int `json:"server,omitempty" hcl:"server" mapstructure:"server"`
|
||||
ProxyMinPort *int `json:"proxy_min_port,omitempty" hcl:"proxy_min_port" mapstructure:"proxy_min_port"`
|
||||
ProxyMaxPort *int `json:"proxy_max_port,omitempty" hcl:"proxy_max_port" mapstructure:"proxy_max_port"`
|
||||
}
|
||||
|
||||
type UnixSocket struct {
|
||||
|
|
|
@ -85,6 +85,8 @@ func DefaultSource() Source {
|
|||
serf_lan = ` + strconv.Itoa(consul.DefaultLANSerfPort) + `
|
||||
serf_wan = ` + strconv.Itoa(consul.DefaultWANSerfPort) + `
|
||||
server = ` + strconv.Itoa(consul.DefaultRPCPort) + `
|
||||
proxy_min_port = 20000
|
||||
proxy_max_port = 20255
|
||||
}
|
||||
telemetry = {
|
||||
metrics_prefix = "consul"
|
||||
|
@ -108,6 +110,10 @@ func DevSource() Source {
|
|||
ui = true
|
||||
log_level = "DEBUG"
|
||||
server = true
|
||||
|
||||
connect = {
|
||||
enabled = true
|
||||
}
|
||||
performance = {
|
||||
raft_multiplier = 1
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"golang.org/x/time/rate"
|
||||
|
@ -304,177 +305,8 @@ type RuntimeConfig struct {
|
|||
// hcl: http_config { response_headers = map[string]string }
|
||||
HTTPResponseHeaders map[string]string
|
||||
|
||||
// TelemetryCirconus*: see https://github.com/circonus-labs/circonus-gometrics
|
||||
// for more details on the various configuration options.
|
||||
// Valid configuration combinations:
|
||||
// - CirconusAPIToken
|
||||
// metric management enabled (search for existing check or create a new one)
|
||||
// - CirconusSubmissionUrl
|
||||
// metric management disabled (use check with specified submission_url,
|
||||
// broker must be using a public SSL certificate)
|
||||
// - CirconusAPIToken + CirconusCheckSubmissionURL
|
||||
// metric management enabled (use check with specified submission_url)
|
||||
// - CirconusAPIToken + CirconusCheckID
|
||||
// metric management enabled (use check with specified id)
|
||||
|
||||
// TelemetryCirconusAPIApp is an app name associated with API token.
|
||||
// Default: "consul"
|
||||
//
|
||||
// hcl: telemetry { circonus_api_app = string }
|
||||
TelemetryCirconusAPIApp string
|
||||
|
||||
// TelemetryCirconusAPIToken is a valid API Token used to create/manage check. If provided,
|
||||
// metric management is enabled.
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_api_token = string }
|
||||
TelemetryCirconusAPIToken string
|
||||
|
||||
// TelemetryCirconusAPIURL is the base URL to use for contacting the Circonus API.
|
||||
// Default: "https://api.circonus.com/v2"
|
||||
//
|
||||
// hcl: telemetry { circonus_api_url = string }
|
||||
TelemetryCirconusAPIURL string
|
||||
|
||||
// TelemetryCirconusBrokerID is an explicit broker to use when creating a new check. The numeric portion
|
||||
// of broker._cid. If metric management is enabled and neither a Submission URL nor Check ID
|
||||
// is provided, an attempt will be made to search for an existing check using Instance ID and
|
||||
// Search Tag. If one is not found, a new HTTPTRAP check will be created.
|
||||
// Default: use Select Tag if provided, otherwise, a random Enterprise Broker associated
|
||||
// with the specified API token or the default Circonus Broker.
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_broker_id = string }
|
||||
TelemetryCirconusBrokerID string
|
||||
|
||||
// TelemetryCirconusBrokerSelectTag is a special tag which will be used to select a broker when
|
||||
// a Broker ID is not provided. The best use of this is to as a hint for which broker
|
||||
// should be used based on *where* this particular instance is running.
|
||||
// (e.g. a specific geo location or datacenter, dc:sfo)
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_broker_select_tag = string }
|
||||
TelemetryCirconusBrokerSelectTag string
|
||||
|
||||
// TelemetryCirconusCheckDisplayName is the name for the check which will be displayed in the Circonus UI.
|
||||
// Default: value of CirconusCheckInstanceID
|
||||
//
|
||||
// hcl: telemetry { circonus_check_display_name = string }
|
||||
TelemetryCirconusCheckDisplayName string
|
||||
|
||||
// TelemetryCirconusCheckForceMetricActivation will force enabling metrics, as they are encountered,
|
||||
// if the metric already exists and is NOT active. If check management is enabled, the default
|
||||
// behavior is to add new metrics as they are encountered. If the metric already exists in the
|
||||
// check, it will *NOT* be activated. This setting overrides that behavior.
|
||||
// Default: "false"
|
||||
//
|
||||
// hcl: telemetry { circonus_check_metrics_activation = (true|false)
|
||||
TelemetryCirconusCheckForceMetricActivation string
|
||||
|
||||
// TelemetryCirconusCheckID is the check id (not check bundle id) from a previously created
|
||||
// HTTPTRAP check. The numeric portion of the check._cid field.
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_check_id = string }
|
||||
TelemetryCirconusCheckID string
|
||||
|
||||
// TelemetryCirconusCheckInstanceID serves to uniquely identify the metrics coming from this "instance".
|
||||
// It can be used to maintain metric continuity with transient or ephemeral instances as
|
||||
// they move around within an infrastructure.
|
||||
// Default: hostname:app
|
||||
//
|
||||
// hcl: telemetry { circonus_check_instance_id = string }
|
||||
TelemetryCirconusCheckInstanceID string
|
||||
|
||||
// TelemetryCirconusCheckSearchTag is a special tag which, when coupled with the instance id, helps to
|
||||
// narrow down the search results when neither a Submission URL or Check ID is provided.
|
||||
// Default: service:app (e.g. service:consul)
|
||||
//
|
||||
// hcl: telemetry { circonus_check_search_tag = string }
|
||||
TelemetryCirconusCheckSearchTag string
|
||||
|
||||
// TelemetryCirconusCheckSearchTag is a special tag which, when coupled with the instance id, helps to
|
||||
// narrow down the search results when neither a Submission URL or Check ID is provided.
|
||||
// Default: service:app (e.g. service:consul)
|
||||
//
|
||||
// hcl: telemetry { circonus_check_tags = string }
|
||||
TelemetryCirconusCheckTags string
|
||||
|
||||
// TelemetryCirconusSubmissionInterval is the interval at which metrics are submitted to Circonus.
|
||||
// Default: 10s
|
||||
//
|
||||
// hcl: telemetry { circonus_submission_interval = "duration" }
|
||||
TelemetryCirconusSubmissionInterval string
|
||||
|
||||
// TelemetryCirconusCheckSubmissionURL is the check.config.submission_url field from a
|
||||
// previously created HTTPTRAP check.
|
||||
// Default: none
|
||||
//
|
||||
// hcl: telemetry { circonus_submission_url = string }
|
||||
TelemetryCirconusSubmissionURL string
|
||||
|
||||
// DisableHostname will disable hostname prefixing for all metrics.
|
||||
//
|
||||
// hcl: telemetry { disable_hostname = (true|false)
|
||||
TelemetryDisableHostname bool
|
||||
|
||||
// TelemetryDogStatsdAddr is the address of a dogstatsd instance. If provided,
|
||||
// metrics will be sent to that instance
|
||||
//
|
||||
// hcl: telemetry { dogstatsd_addr = string }
|
||||
TelemetryDogstatsdAddr string
|
||||
|
||||
// TelemetryDogStatsdTags are the global tags that should be sent with each packet to dogstatsd
|
||||
// It is a list of strings, where each string looks like "my_tag_name:my_tag_value"
|
||||
//
|
||||
// hcl: telemetry { dogstatsd_tags = []string }
|
||||
TelemetryDogstatsdTags []string
|
||||
|
||||
// PrometheusRetentionTime is the retention time for prometheus metrics if greater than 0.
|
||||
// A value of 0 disable Prometheus support. Regarding Prometheus, it is considered a good
|
||||
// practice to put large values here (such as a few days), and at least the interval between
|
||||
// prometheus requests.
|
||||
//
|
||||
// hcl: telemetry { prometheus_retention_time = "duration" }
|
||||
TelemetryPrometheusRetentionTime time.Duration
|
||||
|
||||
// TelemetryFilterDefault is the default for whether to allow a metric that's not
|
||||
// covered by the filter.
|
||||
//
|
||||
// hcl: telemetry { filter_default = (true|false) }
|
||||
TelemetryFilterDefault bool
|
||||
|
||||
// TelemetryAllowedPrefixes is a list of filter rules to apply for allowing metrics
|
||||
// by prefix. Use the 'prefix_filter' option and prefix rules with '+' to be
|
||||
// included.
|
||||
//
|
||||
// hcl: telemetry { prefix_filter = []string{"+<expr>", "+<expr>", ...} }
|
||||
TelemetryAllowedPrefixes []string
|
||||
|
||||
// TelemetryBlockedPrefixes is a list of filter rules to apply for blocking metrics
|
||||
// by prefix. Use the 'prefix_filter' option and prefix rules with '-' to be
|
||||
// excluded.
|
||||
//
|
||||
// hcl: telemetry { prefix_filter = []string{"-<expr>", "-<expr>", ...} }
|
||||
TelemetryBlockedPrefixes []string
|
||||
|
||||
// TelemetryMetricsPrefix is the prefix used to write stats values to.
|
||||
// Default: "consul."
|
||||
//
|
||||
// hcl: telemetry { metrics_prefix = string }
|
||||
TelemetryMetricsPrefix string
|
||||
|
||||
// TelemetryStatsdAddr is the address of a statsd instance. If provided,
|
||||
// metrics will be sent to that instance.
|
||||
//
|
||||
// hcl: telemetry { statsd_addr = string }
|
||||
TelemetryStatsdAddr string
|
||||
|
||||
// TelemetryStatsiteAddr is the address of a statsite instance. If provided,
|
||||
// metrics will be streamed to that instance.
|
||||
//
|
||||
// hcl: telemetry { statsite_addr = string }
|
||||
TelemetryStatsiteAddr string
|
||||
// Embed Telemetry Config
|
||||
Telemetry lib.TelemetryConfig
|
||||
|
||||
// Datacenter is the datacenter this node is in. Defaults to "dc1".
|
||||
//
|
||||
|
@ -621,6 +453,61 @@ type RuntimeConfig struct {
|
|||
// flag: -client string
|
||||
ClientAddrs []*net.IPAddr
|
||||
|
||||
// ConnectEnabled opts the agent into connect. It should be set on all clients
|
||||
// and servers in a cluster for correct connect operation.
|
||||
ConnectEnabled bool
|
||||
|
||||
// ConnectProxyBindMinPort is the inclusive start of the range of ports
|
||||
// allocated to the agent for starting proxy listeners on where no explicit
|
||||
// port is specified.
|
||||
ConnectProxyBindMinPort int
|
||||
|
||||
// ConnectProxyBindMaxPort is the inclusive end of the range of ports
|
||||
// allocated to the agent for starting proxy listeners on where no explicit
|
||||
// port is specified.
|
||||
ConnectProxyBindMaxPort int
|
||||
|
||||
// ConnectProxyAllowManagedRoot is true if Consul can execute managed
|
||||
// proxies when running as root (EUID == 0).
|
||||
ConnectProxyAllowManagedRoot bool
|
||||
|
||||
// ConnectProxyAllowManagedAPIRegistration enables managed proxy registration
|
||||
// via the agent HTTP API. If this is false, only file configurations
|
||||
// can be used.
|
||||
ConnectProxyAllowManagedAPIRegistration bool
|
||||
|
||||
// ConnectProxyDefaultExecMode is used where a registration doesn't include an
|
||||
// exec_mode. Defaults to daemon.
|
||||
ConnectProxyDefaultExecMode string
|
||||
|
||||
// ConnectProxyDefaultDaemonCommand is used to start proxy in exec_mode =
|
||||
// daemon if not specified at registration time.
|
||||
ConnectProxyDefaultDaemonCommand []string
|
||||
|
||||
// ConnectProxyDefaultScriptCommand is used to start proxy in exec_mode =
|
||||
// script if not specified at registration time.
|
||||
ConnectProxyDefaultScriptCommand []string
|
||||
|
||||
// ConnectProxyDefaultConfig is merged with any config specified at
|
||||
// registration time to allow global control of defaults.
|
||||
ConnectProxyDefaultConfig map[string]interface{}
|
||||
|
||||
// ConnectCAProvider is the type of CA provider to use with Connect.
|
||||
ConnectCAProvider string
|
||||
|
||||
// ConnectCAConfig is the config to use for the CA provider.
|
||||
ConnectCAConfig map[string]interface{}
|
||||
|
||||
// ConnectTestDisableManagedProxies is not exposed to public config but us
|
||||
// used by TestAgent to prevent self-executing the test binary in the
|
||||
// background if a managed proxy is created for a test. The only place we
|
||||
// actually want to test processes really being spun up and managed is in
|
||||
// `agent/proxy` and it does it at a lower level. Note that this still allows
|
||||
// registering managed proxies via API and other methods, and still creates
|
||||
// all the agent state for them, just doesn't actually start external
|
||||
// processes up.
|
||||
ConnectTestDisableManagedProxies bool
|
||||
|
||||
// DNSAddrs contains the list of TCP and UDP addresses the DNS server will
|
||||
// bind to. If the DNS endpoint is disabled (ports.dns <= 0) the list is
|
||||
// empty.
|
||||
|
|
|
@ -19,10 +19,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type configTest struct {
|
||||
|
@ -31,6 +32,7 @@ type configTest struct {
|
|||
pre, post func()
|
||||
json, jsontail []string
|
||||
hcl, hcltail []string
|
||||
skipformat bool
|
||||
privatev4 func() ([]*net.IPAddr, error)
|
||||
publicv6 func() ([]*net.IPAddr, error)
|
||||
patch func(rt *RuntimeConfig)
|
||||
|
@ -263,6 +265,7 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
|||
rt.AdvertiseAddrLAN = ipAddr("127.0.0.1")
|
||||
rt.AdvertiseAddrWAN = ipAddr("127.0.0.1")
|
||||
rt.BindAddr = ipAddr("127.0.0.1")
|
||||
rt.ConnectEnabled = true
|
||||
rt.DevMode = true
|
||||
rt.DisableAnonymousSignature = true
|
||||
rt.DisableKeyringFile = true
|
||||
|
@ -1850,8 +1853,8 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
|||
`},
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.TelemetryAllowedPrefixes = []string{"foo"}
|
||||
rt.TelemetryBlockedPrefixes = []string{"bar"}
|
||||
rt.Telemetry.AllowedPrefixes = []string{"foo"}
|
||||
rt.Telemetry.BlockedPrefixes = []string{"bar"}
|
||||
},
|
||||
warns: []string{`Filter rule must begin with either '+' or '-': "nix"`},
|
||||
},
|
||||
|
@ -2068,6 +2071,127 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
|
|||
rt.DataDir = dataDir
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
desc: "HCL service managed proxy 'upstreams'",
|
||||
args: []string{
|
||||
`-data-dir=` + dataDir,
|
||||
},
|
||||
hcl: []string{
|
||||
`service {
|
||||
name = "web"
|
||||
port = 8080
|
||||
connect {
|
||||
proxy {
|
||||
config {
|
||||
upstreams {
|
||||
local_bind_port = 1234
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
skipformat: true, // skipping JSON cause we get slightly diff types (okay)
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.Services = []*structs.ServiceDefinition{
|
||||
&structs.ServiceDefinition{
|
||||
Name: "web",
|
||||
Port: 8080,
|
||||
Connect: &structs.ServiceConnect{
|
||||
Proxy: &structs.ServiceDefinitionConnectProxy{
|
||||
Config: map[string]interface{}{
|
||||
"upstreams": []map[string]interface{}{
|
||||
map[string]interface{}{
|
||||
"local_bind_port": 1234,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "JSON service managed proxy 'upstreams'",
|
||||
args: []string{
|
||||
`-data-dir=` + dataDir,
|
||||
},
|
||||
json: []string{
|
||||
`{
|
||||
"service": {
|
||||
"name": "web",
|
||||
"port": 8080,
|
||||
"connect": {
|
||||
"proxy": {
|
||||
"config": {
|
||||
"upstreams": [{
|
||||
"local_bind_port": 1234
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
skipformat: true, // skipping HCL cause we get slightly diff types (okay)
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.Services = []*structs.ServiceDefinition{
|
||||
&structs.ServiceDefinition{
|
||||
Name: "web",
|
||||
Port: 8080,
|
||||
Connect: &structs.ServiceConnect{
|
||||
Proxy: &structs.ServiceDefinitionConnectProxy{
|
||||
Config: map[string]interface{}{
|
||||
"upstreams": []interface{}{
|
||||
map[string]interface{}{
|
||||
"local_bind_port": float64(1234),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
desc: "enabling Connect allow_managed_root",
|
||||
args: []string{
|
||||
`-data-dir=` + dataDir,
|
||||
},
|
||||
json: []string{
|
||||
`{ "connect": { "proxy": { "allow_managed_root": true } } }`,
|
||||
},
|
||||
hcl: []string{
|
||||
`connect { proxy { allow_managed_root = true } }`,
|
||||
},
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.ConnectProxyAllowManagedRoot = true
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
desc: "enabling Connect allow_managed_api_registration",
|
||||
args: []string{
|
||||
`-data-dir=` + dataDir,
|
||||
},
|
||||
json: []string{
|
||||
`{ "connect": { "proxy": { "allow_managed_api_registration": true } } }`,
|
||||
},
|
||||
hcl: []string{
|
||||
`connect { proxy { allow_managed_api_registration = true } }`,
|
||||
},
|
||||
patch: func(rt *RuntimeConfig) {
|
||||
rt.DataDir = dataDir
|
||||
rt.ConnectProxyAllowManagedAPIRegistration = true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testConfig(t, tests, dataDir)
|
||||
|
@ -2089,7 +2213,7 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
|
|||
|
||||
// json and hcl sources need to be in sync
|
||||
// to make sure we're generating the same config
|
||||
if len(tt.json) != len(tt.hcl) {
|
||||
if len(tt.json) != len(tt.hcl) && !tt.skipformat {
|
||||
t.Fatal(tt.desc, ": JSON and HCL test case out of sync")
|
||||
}
|
||||
|
||||
|
@ -2099,6 +2223,12 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
|
|||
srcs, tails = tt.hcl, tt.hcltail
|
||||
}
|
||||
|
||||
// If we're skipping a format and the current format is empty,
|
||||
// then skip it!
|
||||
if tt.skipformat && len(srcs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// build the description
|
||||
var desc []string
|
||||
if !flagsOnly {
|
||||
|
@ -2353,6 +2483,23 @@ func TestFullConfig(t *testing.T) {
|
|||
],
|
||||
"check_update_interval": "16507s",
|
||||
"client_addr": "93.83.18.19",
|
||||
"connect": {
|
||||
"ca_provider": "consul",
|
||||
"ca_config": {
|
||||
"RotationPeriod": "90h"
|
||||
},
|
||||
"enabled": true,
|
||||
"proxy_defaults": {
|
||||
"exec_mode": "script",
|
||||
"daemon_command": ["consul", "connect", "proxy"],
|
||||
"script_command": ["proxyctl.sh"],
|
||||
"config": {
|
||||
"foo": "bar",
|
||||
"connect_timeout_ms": 1000,
|
||||
"pedantic_mode": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"data_dir": "` + dataDir + `",
|
||||
"datacenter": "rzo029wg",
|
||||
"disable_anonymous_signature": true,
|
||||
|
@ -2417,7 +2564,9 @@ func TestFullConfig(t *testing.T) {
|
|||
"dns": 7001,
|
||||
"http": 7999,
|
||||
"https": 15127,
|
||||
"server": 3757
|
||||
"server": 3757,
|
||||
"proxy_min_port": 2000,
|
||||
"proxy_max_port": 3000
|
||||
},
|
||||
"protocol": 30793,
|
||||
"raft_protocol": 19016,
|
||||
|
@ -2613,7 +2762,16 @@ func TestFullConfig(t *testing.T) {
|
|||
"ttl": "11222s",
|
||||
"deregister_critical_service_after": "68482s"
|
||||
}
|
||||
]
|
||||
],
|
||||
"connect": {
|
||||
"proxy": {
|
||||
"exec_mode": "daemon",
|
||||
"command": ["awesome-proxy"],
|
||||
"config": {
|
||||
"foo": "qux"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"session_ttl_min": "26627s",
|
||||
|
@ -2786,6 +2944,25 @@ func TestFullConfig(t *testing.T) {
|
|||
]
|
||||
check_update_interval = "16507s"
|
||||
client_addr = "93.83.18.19"
|
||||
connect {
|
||||
ca_provider = "consul"
|
||||
ca_config {
|
||||
"RotationPeriod" = "90h"
|
||||
}
|
||||
enabled = true
|
||||
proxy_defaults {
|
||||
exec_mode = "script"
|
||||
daemon_command = ["consul", "connect", "proxy"]
|
||||
script_command = ["proxyctl.sh"]
|
||||
config = {
|
||||
foo = "bar"
|
||||
# hack float since json parses numbers as float and we have to
|
||||
# assert against the same thing
|
||||
connect_timeout_ms = 1000.0
|
||||
pedantic_mode = true
|
||||
}
|
||||
}
|
||||
}
|
||||
data_dir = "` + dataDir + `"
|
||||
datacenter = "rzo029wg"
|
||||
disable_anonymous_signature = true
|
||||
|
@ -2851,6 +3028,8 @@ func TestFullConfig(t *testing.T) {
|
|||
http = 7999,
|
||||
https = 15127
|
||||
server = 3757
|
||||
proxy_min_port = 2000
|
||||
proxy_max_port = 3000
|
||||
}
|
||||
protocol = 30793
|
||||
raft_protocol = 19016
|
||||
|
@ -3047,6 +3226,15 @@ func TestFullConfig(t *testing.T) {
|
|||
deregister_critical_service_after = "68482s"
|
||||
}
|
||||
]
|
||||
connect {
|
||||
proxy {
|
||||
exec_mode = "daemon"
|
||||
command = ["awesome-proxy"]
|
||||
config = {
|
||||
foo = "qux"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
session_ttl_min = "26627s"
|
||||
|
@ -3357,6 +3545,23 @@ func TestFullConfig(t *testing.T) {
|
|||
},
|
||||
CheckUpdateInterval: 16507 * time.Second,
|
||||
ClientAddrs: []*net.IPAddr{ipAddr("93.83.18.19")},
|
||||
ConnectEnabled: true,
|
||||
ConnectProxyBindMinPort: 2000,
|
||||
ConnectProxyBindMaxPort: 3000,
|
||||
ConnectCAProvider: "consul",
|
||||
ConnectCAConfig: map[string]interface{}{
|
||||
"RotationPeriod": "90h",
|
||||
},
|
||||
ConnectProxyAllowManagedRoot: false,
|
||||
ConnectProxyAllowManagedAPIRegistration: false,
|
||||
ConnectProxyDefaultExecMode: "script",
|
||||
ConnectProxyDefaultDaemonCommand: []string{"consul", "connect", "proxy"},
|
||||
ConnectProxyDefaultScriptCommand: []string{"proxyctl.sh"},
|
||||
ConnectProxyDefaultConfig: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"connect_timeout_ms": float64(1000),
|
||||
"pedantic_mode": true,
|
||||
},
|
||||
DNSAddrs: []net.Addr{tcpAddr("93.95.95.81:7001"), udpAddr("93.95.95.81:7001")},
|
||||
DNSARecordLimit: 29907,
|
||||
DNSAllowStale: true,
|
||||
|
@ -3530,6 +3735,15 @@ func TestFullConfig(t *testing.T) {
|
|||
DeregisterCriticalServiceAfter: 68482 * time.Second,
|
||||
},
|
||||
},
|
||||
Connect: &structs.ServiceConnect{
|
||||
Proxy: &structs.ServiceDefinitionConnectProxy{
|
||||
ExecMode: "daemon",
|
||||
Command: []string{"awesome-proxy"},
|
||||
Config: map[string]interface{}{
|
||||
"foo": "qux",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "dLOXpSCI",
|
||||
|
@ -3616,29 +3830,31 @@ func TestFullConfig(t *testing.T) {
|
|||
StartJoinAddrsLAN: []string{"LR3hGDoG", "MwVpZ4Up"},
|
||||
StartJoinAddrsWAN: []string{"EbFSc3nA", "kwXTh623"},
|
||||
SyslogFacility: "hHv79Uia",
|
||||
TelemetryCirconusAPIApp: "p4QOTe9j",
|
||||
TelemetryCirconusAPIToken: "E3j35V23",
|
||||
TelemetryCirconusAPIURL: "mEMjHpGg",
|
||||
TelemetryCirconusBrokerID: "BHlxUhed",
|
||||
TelemetryCirconusBrokerSelectTag: "13xy1gHm",
|
||||
TelemetryCirconusCheckDisplayName: "DRSlQR6n",
|
||||
TelemetryCirconusCheckForceMetricActivation: "Ua5FGVYf",
|
||||
TelemetryCirconusCheckID: "kGorutad",
|
||||
TelemetryCirconusCheckInstanceID: "rwoOL6R4",
|
||||
TelemetryCirconusCheckSearchTag: "ovT4hT4f",
|
||||
TelemetryCirconusCheckTags: "prvO4uBl",
|
||||
TelemetryCirconusSubmissionInterval: "DolzaflP",
|
||||
TelemetryCirconusSubmissionURL: "gTcbS93G",
|
||||
TelemetryDisableHostname: true,
|
||||
TelemetryDogstatsdAddr: "0wSndumK",
|
||||
TelemetryDogstatsdTags: []string{"3N81zSUB", "Xtj8AnXZ"},
|
||||
TelemetryFilterDefault: true,
|
||||
TelemetryAllowedPrefixes: []string{"oJotS8XJ"},
|
||||
TelemetryBlockedPrefixes: []string{"cazlEhGn"},
|
||||
TelemetryMetricsPrefix: "ftO6DySn",
|
||||
TelemetryPrometheusRetentionTime: 15 * time.Second,
|
||||
TelemetryStatsdAddr: "drce87cy",
|
||||
TelemetryStatsiteAddr: "HpFwKB8R",
|
||||
Telemetry: lib.TelemetryConfig{
|
||||
CirconusAPIApp: "p4QOTe9j",
|
||||
CirconusAPIToken: "E3j35V23",
|
||||
CirconusAPIURL: "mEMjHpGg",
|
||||
CirconusBrokerID: "BHlxUhed",
|
||||
CirconusBrokerSelectTag: "13xy1gHm",
|
||||
CirconusCheckDisplayName: "DRSlQR6n",
|
||||
CirconusCheckForceMetricActivation: "Ua5FGVYf",
|
||||
CirconusCheckID: "kGorutad",
|
||||
CirconusCheckInstanceID: "rwoOL6R4",
|
||||
CirconusCheckSearchTag: "ovT4hT4f",
|
||||
CirconusCheckTags: "prvO4uBl",
|
||||
CirconusSubmissionInterval: "DolzaflP",
|
||||
CirconusSubmissionURL: "gTcbS93G",
|
||||
DisableHostname: true,
|
||||
DogstatsdAddr: "0wSndumK",
|
||||
DogstatsdTags: []string{"3N81zSUB", "Xtj8AnXZ"},
|
||||
FilterDefault: true,
|
||||
AllowedPrefixes: []string{"oJotS8XJ"},
|
||||
BlockedPrefixes: []string{"cazlEhGn"},
|
||||
MetricsPrefix: "ftO6DySn",
|
||||
PrometheusRetentionTime: 15 * time.Second,
|
||||
StatsdAddr: "drce87cy",
|
||||
StatsiteAddr: "HpFwKB8R",
|
||||
},
|
||||
TLSCipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
|
||||
TLSMinVersion: "pAOWafkR",
|
||||
TLSPreferServerCipherSuites: true,
|
||||
|
@ -4019,6 +4235,18 @@ func TestSanitize(t *testing.T) {
|
|||
}
|
||||
],
|
||||
"ClientAddrs": [],
|
||||
"ConnectCAConfig": {},
|
||||
"ConnectCAProvider": "",
|
||||
"ConnectEnabled": false,
|
||||
"ConnectProxyAllowManagedAPIRegistration": false,
|
||||
"ConnectProxyAllowManagedRoot": false,
|
||||
"ConnectProxyBindMaxPort": 0,
|
||||
"ConnectProxyBindMinPort": 0,
|
||||
"ConnectProxyDefaultConfig": {},
|
||||
"ConnectProxyDefaultDaemonCommand": [],
|
||||
"ConnectProxyDefaultExecMode": "",
|
||||
"ConnectProxyDefaultScriptCommand": [],
|
||||
"ConnectTestDisableManagedProxies": false,
|
||||
"ConsulCoordinateUpdateBatchSize": 0,
|
||||
"ConsulCoordinateUpdateMaxBatches": 0,
|
||||
"ConsulCoordinateUpdatePeriod": "15s",
|
||||
|
@ -4150,11 +4378,14 @@ func TestSanitize(t *testing.T) {
|
|||
"Timeout": "0s"
|
||||
},
|
||||
"Checks": [],
|
||||
"Connect": null,
|
||||
"EnableTagOverride": false,
|
||||
"ID": "",
|
||||
"Kind": "",
|
||||
"Meta": {},
|
||||
"Name": "foo",
|
||||
"Port": 0,
|
||||
"ProxyDestination": "",
|
||||
"Tags": [],
|
||||
"Token": "hidden"
|
||||
}
|
||||
|
@ -4170,29 +4401,31 @@ func TestSanitize(t *testing.T) {
|
|||
"TLSMinVersion": "",
|
||||
"TLSPreferServerCipherSuites": false,
|
||||
"TaggedAddresses": {},
|
||||
"TelemetryAllowedPrefixes": [],
|
||||
"TelemetryBlockedPrefixes": [],
|
||||
"TelemetryCirconusAPIApp": "",
|
||||
"TelemetryCirconusAPIToken": "hidden",
|
||||
"TelemetryCirconusAPIURL": "",
|
||||
"TelemetryCirconusBrokerID": "",
|
||||
"TelemetryCirconusBrokerSelectTag": "",
|
||||
"TelemetryCirconusCheckDisplayName": "",
|
||||
"TelemetryCirconusCheckForceMetricActivation": "",
|
||||
"TelemetryCirconusCheckID": "",
|
||||
"TelemetryCirconusCheckInstanceID": "",
|
||||
"TelemetryCirconusCheckSearchTag": "",
|
||||
"TelemetryCirconusCheckTags": "",
|
||||
"TelemetryCirconusSubmissionInterval": "",
|
||||
"TelemetryCirconusSubmissionURL": "",
|
||||
"TelemetryDisableHostname": false,
|
||||
"TelemetryDogstatsdAddr": "",
|
||||
"TelemetryDogstatsdTags": [],
|
||||
"TelemetryFilterDefault": false,
|
||||
"TelemetryMetricsPrefix": "",
|
||||
"TelemetryPrometheusRetentionTime": "0s",
|
||||
"TelemetryStatsdAddr": "",
|
||||
"TelemetryStatsiteAddr": "",
|
||||
"Telemetry":{
|
||||
"AllowedPrefixes": [],
|
||||
"BlockedPrefixes": [],
|
||||
"CirconusAPIApp": "",
|
||||
"CirconusAPIToken": "hidden",
|
||||
"CirconusAPIURL": "",
|
||||
"CirconusBrokerID": "",
|
||||
"CirconusBrokerSelectTag": "",
|
||||
"CirconusCheckDisplayName": "",
|
||||
"CirconusCheckForceMetricActivation": "",
|
||||
"CirconusCheckID": "",
|
||||
"CirconusCheckInstanceID": "",
|
||||
"CirconusCheckSearchTag": "",
|
||||
"CirconusCheckTags": "",
|
||||
"CirconusSubmissionInterval": "",
|
||||
"CirconusSubmissionURL": "",
|
||||
"DisableHostname": false,
|
||||
"DogstatsdAddr": "",
|
||||
"DogstatsdTags": [],
|
||||
"FilterDefault": false,
|
||||
"MetricsPrefix": "",
|
||||
"PrometheusRetentionTime": "0s",
|
||||
"StatsdAddr": "",
|
||||
"StatsiteAddr": ""
|
||||
},
|
||||
"TranslateWANAddrs": false,
|
||||
"UIDir": "",
|
||||
"UnixSocketGroup": "",
|
||||
|
@ -4212,11 +4445,7 @@ func TestSanitize(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := string(b), rtJSON; got != want {
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(want, got, false)
|
||||
t.Fatal(dmp.DiffPrettyText(diffs))
|
||||
}
|
||||
require.JSONEq(t, rtJSON, string(b))
|
||||
}
|
||||
|
||||
func splitIPPort(hostport string) (net.IP, int) {
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
// Provider is the interface for Consul to interact with
|
||||
// an external CA that provides leaf certificate signing for
|
||||
// given SpiffeIDServices.
|
||||
type Provider interface {
|
||||
// Active root returns the currently active root CA for this
|
||||
// provider. This should be a parent of the certificate returned by
|
||||
// ActiveIntermediate()
|
||||
ActiveRoot() (string, error)
|
||||
|
||||
// ActiveIntermediate returns the current signing cert used by this provider
|
||||
// for generating SPIFFE leaf certs. Note that this must not change except
|
||||
// when Consul requests the change via GenerateIntermediate. Changing the
|
||||
// signing cert will break Consul's assumptions about which validation paths
|
||||
// are active.
|
||||
ActiveIntermediate() (string, error)
|
||||
|
||||
// GenerateIntermediate returns a new intermediate signing cert and sets it to
|
||||
// the active intermediate. If multiple intermediates are needed to complete
|
||||
// the chain from the signing certificate back to the active root, they should
|
||||
// all by bundled here.
|
||||
GenerateIntermediate() (string, error)
|
||||
|
||||
// Sign signs a leaf certificate used by Connect proxies from a CSR. The PEM
|
||||
// returned should include only the leaf certificate as all Intermediates
|
||||
// needed to validate it will be added by Consul based on the active
|
||||
// intemediate and any cross-signed intermediates managed by Consul.
|
||||
Sign(*x509.CertificateRequest) (string, error)
|
||||
|
||||
// CrossSignCA must accept a CA certificate from another CA provider
|
||||
// and cross sign it exactly as it is such that it forms a chain back the the
|
||||
// CAProvider's current root. Specifically, the Distinguished Name, Subject
|
||||
// Alternative Name, SubjectKeyID and other relevant extensions must be kept.
|
||||
// The resulting certificate must have a distinct Serial Number and the
|
||||
// AuthorityKeyID set to the CAProvider's current signing key as well as the
|
||||
// Issuer related fields changed as necessary. The resulting certificate is
|
||||
// returned as a PEM formatted string.
|
||||
CrossSignCA(*x509.Certificate) (string, error)
|
||||
|
||||
// Cleanup performs any necessary cleanup that should happen when the provider
|
||||
// is shut down permanently, such as removing a temporary PKI backend in Vault
|
||||
// created for an intermediate CA.
|
||||
Cleanup() error
|
||||
}
|
|
@ -0,0 +1,379 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
type ConsulProvider struct {
|
||||
config *structs.ConsulCAProviderConfig
|
||||
id string
|
||||
delegate ConsulProviderStateDelegate
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
type ConsulProviderStateDelegate interface {
|
||||
State() *state.Store
|
||||
ApplyCARequest(*structs.CARequest) error
|
||||
}
|
||||
|
||||
// NewConsulProvider returns a new instance of the Consul CA provider,
|
||||
// bootstrapping its state in the state store necessary
|
||||
func NewConsulProvider(rawConfig map[string]interface{}, delegate ConsulProviderStateDelegate) (*ConsulProvider, error) {
|
||||
conf, err := ParseConsulCAConfig(rawConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider := &ConsulProvider{
|
||||
config: conf,
|
||||
delegate: delegate,
|
||||
id: fmt.Sprintf("%s,%s", conf.PrivateKey, conf.RootCert),
|
||||
}
|
||||
|
||||
// Check if this configuration of the provider has already been
|
||||
// initialized in the state store.
|
||||
state := delegate.State()
|
||||
_, providerState, err := state.CAProviderState(provider.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exit early if the state store has already been populated for this config.
|
||||
if providerState != nil {
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
newState := structs.CAConsulProviderState{
|
||||
ID: provider.id,
|
||||
}
|
||||
|
||||
// Write the initial provider state to get the index to use for the
|
||||
// CA serial number.
|
||||
{
|
||||
args := &structs.CARequest{
|
||||
Op: structs.CAOpSetProviderState,
|
||||
ProviderState: &newState,
|
||||
}
|
||||
if err := delegate.ApplyCARequest(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
idx, _, err := state.CAProviderState(provider.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a private key if needed
|
||||
if conf.PrivateKey == "" {
|
||||
_, pk, err := connect.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newState.PrivateKey = pk
|
||||
} else {
|
||||
newState.PrivateKey = conf.PrivateKey
|
||||
}
|
||||
|
||||
// Generate the root CA if necessary
|
||||
if conf.RootCert == "" {
|
||||
ca, err := provider.generateCA(newState.PrivateKey, idx+1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating CA: %v", err)
|
||||
}
|
||||
newState.RootCert = ca
|
||||
} else {
|
||||
newState.RootCert = conf.RootCert
|
||||
}
|
||||
|
||||
// Write the provider state
|
||||
args := &structs.CARequest{
|
||||
Op: structs.CAOpSetProviderState,
|
||||
ProviderState: &newState,
|
||||
}
|
||||
if err := delegate.ApplyCARequest(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// Return the active root CA and generate a new one if needed
|
||||
func (c *ConsulProvider) ActiveRoot() (string, error) {
|
||||
state := c.delegate.State()
|
||||
_, providerState, err := state.CAProviderState(c.id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return providerState.RootCert, nil
|
||||
}
|
||||
|
||||
// We aren't maintaining separate root/intermediate CAs for the builtin
|
||||
// provider, so just return the root.
|
||||
func (c *ConsulProvider) ActiveIntermediate() (string, error) {
|
||||
return c.ActiveRoot()
|
||||
}
|
||||
|
||||
// We aren't maintaining separate root/intermediate CAs for the builtin
|
||||
// provider, so just return the root.
|
||||
func (c *ConsulProvider) GenerateIntermediate() (string, error) {
|
||||
return c.ActiveIntermediate()
|
||||
}
|
||||
|
||||
// Remove the state store entry for this provider instance.
|
||||
func (c *ConsulProvider) Cleanup() error {
|
||||
args := &structs.CARequest{
|
||||
Op: structs.CAOpDeleteProviderState,
|
||||
ProviderState: &structs.CAConsulProviderState{ID: c.id},
|
||||
}
|
||||
if err := c.delegate.ApplyCARequest(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign returns a new certificate valid for the given SpiffeIDService
|
||||
// using the current CA.
|
||||
func (c *ConsulProvider) Sign(csr *x509.CertificateRequest) (string, error) {
|
||||
// Lock during the signing so we don't use the same index twice
|
||||
// for different cert serial numbers.
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Get the provider state
|
||||
state := c.delegate.State()
|
||||
idx, providerState, err := state.CAProviderState(c.id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the keyId for the cert from the signing private key.
|
||||
signer, err := connect.ParseSigner(providerState.PrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if signer == nil {
|
||||
return "", fmt.Errorf("error signing cert: Consul CA not initialized yet")
|
||||
}
|
||||
keyId, err := connect.KeyId(signer.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse the SPIFFE ID
|
||||
spiffeId, err := connect.ParseCertURI(csr.URIs[0])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
serviceId, ok := spiffeId.(*connect.SpiffeIDService)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("SPIFFE ID in CSR must be a service ID")
|
||||
}
|
||||
|
||||
// Parse the CA cert
|
||||
caCert, err := connect.ParseCert(providerState.RootCert)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing CA cert: %s", err)
|
||||
}
|
||||
|
||||
// Cert template for generation
|
||||
sn := &big.Int{}
|
||||
sn.SetUint64(idx + 1)
|
||||
// Sign the certificate valid from 1 minute in the past, this helps it be
|
||||
// accepted right away even when nodes are not in close time sync accross the
|
||||
// cluster. A minute is more than enough for typical DC clock drift.
|
||||
effectiveNow := time.Now().Add(-1 * time.Minute)
|
||||
template := x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{CommonName: serviceId.Service},
|
||||
URIs: csr.URIs,
|
||||
Signature: csr.Signature,
|
||||
SignatureAlgorithm: csr.SignatureAlgorithm,
|
||||
PublicKeyAlgorithm: csr.PublicKeyAlgorithm,
|
||||
PublicKey: csr.PublicKey,
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDataEncipherment |
|
||||
x509.KeyUsageKeyAgreement |
|
||||
x509.KeyUsageDigitalSignature |
|
||||
x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
// todo(kyhavlov): add a way to set the cert lifetime here from the CA config
|
||||
NotAfter: effectiveNow.Add(3 * 24 * time.Hour),
|
||||
NotBefore: effectiveNow,
|
||||
AuthorityKeyId: keyId,
|
||||
SubjectKeyId: keyId,
|
||||
}
|
||||
|
||||
// Create the certificate, PEM encode it and return that value.
|
||||
var buf bytes.Buffer
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, caCert, csr.PublicKey, signer)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating certificate: %s", err)
|
||||
}
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding certificate: %s", err)
|
||||
}
|
||||
|
||||
err = c.incrementProviderIndex(providerState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Set the response
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// CrossSignCA returns the given CA cert signed by the current active root.
|
||||
func (c *ConsulProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Get the provider state
|
||||
state := c.delegate.State()
|
||||
idx, providerState, err := state.CAProviderState(c.id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privKey, err := connect.ParseSigner(providerState.PrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing private key %q: %s", providerState.PrivateKey, err)
|
||||
}
|
||||
|
||||
rootCA, err := connect.ParseCert(providerState.RootCert)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyId, err := connect.KeyId(privKey.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the cross-signing template from the existing root CA
|
||||
serialNum := &big.Int{}
|
||||
serialNum.SetUint64(idx + 1)
|
||||
template := *cert
|
||||
template.SerialNumber = serialNum
|
||||
template.SignatureAlgorithm = rootCA.SignatureAlgorithm
|
||||
template.AuthorityKeyId = keyId
|
||||
|
||||
// Sign the certificate valid from 1 minute in the past, this helps it be
|
||||
// accepted right away even when nodes are not in close time sync accross the
|
||||
// cluster. A minute is more than enough for typical DC clock drift.
|
||||
effectiveNow := time.Now().Add(-1 * time.Minute)
|
||||
template.NotBefore = effectiveNow
|
||||
// This cross-signed cert is only needed during rotation, and only while old
|
||||
// leaf certs are still in use. They expire within 3 days currently so 7 is
|
||||
// safe. TODO(banks): make this be based on leaf expiry time when that is
|
||||
// configurable.
|
||||
template.NotAfter = effectiveNow.Add(7 * 24 * time.Hour)
|
||||
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, rootCA, cert.PublicKey, privKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
err = c.incrementProviderIndex(providerState)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// incrementProviderIndex does a write to increment the provider state store table index
|
||||
// used for serial numbers when generating certificates.
|
||||
func (c *ConsulProvider) incrementProviderIndex(providerState *structs.CAConsulProviderState) error {
|
||||
newState := *providerState
|
||||
args := &structs.CARequest{
|
||||
Op: structs.CAOpSetProviderState,
|
||||
ProviderState: &newState,
|
||||
}
|
||||
if err := c.delegate.ApplyCARequest(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateCA makes a new root CA using the current private key
|
||||
func (c *ConsulProvider) generateCA(privateKey string, sn uint64) (string, error) {
|
||||
state := c.delegate.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privKey, err := connect.ParseSigner(privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing private key %q: %s", privateKey, err)
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("Consul CA %d", sn)
|
||||
|
||||
// The URI (SPIFFE compatible) for the cert
|
||||
id := connect.SpiffeIDSigningForCluster(config)
|
||||
keyId, err := connect.KeyId(privKey.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the CA cert
|
||||
serialNum := &big.Int{}
|
||||
serialNum.SetUint64(sn)
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNum,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
URIs: []*url.URL{id.URI()},
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageCertSign |
|
||||
x509.KeyUsageCRLSign |
|
||||
x509.KeyUsageDigitalSignature,
|
||||
IsCA: true,
|
||||
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
|
||||
NotBefore: time.Now(),
|
||||
AuthorityKeyId: keyId,
|
||||
SubjectKeyId: keyId,
|
||||
}
|
||||
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, &template, privKey.Public(), privKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func ParseConsulCAConfig(raw map[string]interface{}) (*structs.ConsulCAProviderConfig, error) {
|
||||
var config structs.ConsulCAProviderConfig
|
||||
decodeConf := &mapstructure.DecoderConfig{
|
||||
DecodeHook: ParseDurationFunc(),
|
||||
ErrorUnused: true,
|
||||
Result: &config,
|
||||
WeaklyTypedInput: true,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := decoder.Decode(raw); err != nil {
|
||||
return nil, fmt.Errorf("error decoding config: %s", err)
|
||||
}
|
||||
|
||||
if config.PrivateKey == "" && config.RootCert != "" {
|
||||
return nil, fmt.Errorf("must provide a private key when providing a root cert")
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// ParseDurationFunc is a mapstructure hook for decoding a string or
|
||||
// []uint8 into a time.Duration value.
|
||||
func ParseDurationFunc() mapstructure.DecodeHookFunc {
|
||||
return func(
|
||||
f reflect.Type,
|
||||
t reflect.Type,
|
||||
data interface{}) (interface{}, error) {
|
||||
var v time.Duration
|
||||
if t != reflect.TypeOf(v) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case f.Kind() == reflect.String:
|
||||
if dur, err := time.ParseDuration(data.(string)); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
v = dur
|
||||
}
|
||||
return v, nil
|
||||
case f == reflect.SliceOf(reflect.TypeOf(uint8(0))):
|
||||
s := Uint8ToString(data.([]uint8))
|
||||
if dur, err := time.ParseDuration(s); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
v = dur
|
||||
}
|
||||
return v, nil
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Uint8ToString(bs []uint8) string {
|
||||
b := make([]byte, len(bs))
|
||||
for i, v := range bs {
|
||||
b[i] = byte(v)
|
||||
}
|
||||
return string(b)
|
||||
}
|
|
@ -0,0 +1,266 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type consulCAMockDelegate struct {
|
||||
state *state.Store
|
||||
}
|
||||
|
||||
func (c *consulCAMockDelegate) State() *state.Store {
|
||||
return c.state
|
||||
}
|
||||
|
||||
func (c *consulCAMockDelegate) ApplyCARequest(req *structs.CARequest) error {
|
||||
idx, _, err := c.state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch req.Op {
|
||||
case structs.CAOpSetProviderState:
|
||||
_, err := c.state.CASetProviderState(idx+1, req.ProviderState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
case structs.CAOpDeleteProviderState:
|
||||
if err := c.state.CADeleteProviderState(req.ProviderState.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Invalid CA operation '%s'", req.Op)
|
||||
}
|
||||
}
|
||||
|
||||
func newMockDelegate(t *testing.T, conf *structs.CAConfiguration) *consulCAMockDelegate {
|
||||
s, err := state.NewStateStore(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatalf("missing state store")
|
||||
}
|
||||
if err := s.CASetConfig(conf.RaftIndex.CreateIndex, conf); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
return &consulCAMockDelegate{s}
|
||||
}
|
||||
|
||||
func testConsulCAConfig() *structs.CAConfiguration {
|
||||
return &structs.CAConfiguration{
|
||||
ClusterID: "asdf",
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsulCAProvider_Bootstrap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
conf := testConsulCAConfig()
|
||||
delegate := newMockDelegate(t, conf)
|
||||
|
||||
provider, err := NewConsulProvider(conf.Config, delegate)
|
||||
assert.NoError(err)
|
||||
|
||||
root, err := provider.ActiveRoot()
|
||||
assert.NoError(err)
|
||||
|
||||
// Intermediate should be the same cert.
|
||||
inter, err := provider.ActiveIntermediate()
|
||||
assert.NoError(err)
|
||||
assert.Equal(root, inter)
|
||||
|
||||
// Should be a valid cert
|
||||
parsed, err := connect.ParseCert(root)
|
||||
assert.NoError(err)
|
||||
assert.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", conf.ClusterID))
|
||||
}
|
||||
|
||||
func TestConsulCAProvider_Bootstrap_WithCert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Make sure setting a custom private key/root cert works.
|
||||
assert := assert.New(t)
|
||||
rootCA := connect.TestCA(t, nil)
|
||||
conf := testConsulCAConfig()
|
||||
conf.Config = map[string]interface{}{
|
||||
"PrivateKey": rootCA.SigningKey,
|
||||
"RootCert": rootCA.RootCert,
|
||||
}
|
||||
delegate := newMockDelegate(t, conf)
|
||||
|
||||
provider, err := NewConsulProvider(conf.Config, delegate)
|
||||
assert.NoError(err)
|
||||
|
||||
root, err := provider.ActiveRoot()
|
||||
assert.NoError(err)
|
||||
assert.Equal(root, rootCA.RootCert)
|
||||
}
|
||||
|
||||
func TestConsulCAProvider_SignLeaf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
conf := testConsulCAConfig()
|
||||
delegate := newMockDelegate(t, conf)
|
||||
|
||||
provider, err := NewConsulProvider(conf.Config, delegate)
|
||||
assert.NoError(err)
|
||||
|
||||
spiffeService := &connect.SpiffeIDService{
|
||||
Host: "node1",
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: "foo",
|
||||
}
|
||||
|
||||
// Generate a leaf cert for the service.
|
||||
{
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
csr, err := connect.ParseCSR(raw)
|
||||
assert.NoError(err)
|
||||
|
||||
cert, err := provider.Sign(csr)
|
||||
assert.NoError(err)
|
||||
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
assert.NoError(err)
|
||||
assert.Equal(parsed.URIs[0], spiffeService.URI())
|
||||
assert.Equal(parsed.Subject.CommonName, "foo")
|
||||
assert.Equal(uint64(2), parsed.SerialNumber.Uint64())
|
||||
|
||||
// Ensure the cert is valid now and expires within the correct limit.
|
||||
assert.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
|
||||
assert.True(parsed.NotBefore.Before(time.Now()))
|
||||
}
|
||||
|
||||
// Generate a new cert for another service and make sure
|
||||
// the serial number is incremented.
|
||||
spiffeService.Service = "bar"
|
||||
{
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
csr, err := connect.ParseCSR(raw)
|
||||
assert.NoError(err)
|
||||
|
||||
cert, err := provider.Sign(csr)
|
||||
assert.NoError(err)
|
||||
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
assert.NoError(err)
|
||||
assert.Equal(parsed.URIs[0], spiffeService.URI())
|
||||
assert.Equal(parsed.Subject.CommonName, "bar")
|
||||
assert.Equal(parsed.SerialNumber.Uint64(), uint64(2))
|
||||
|
||||
// Ensure the cert is valid now and expires within the correct limit.
|
||||
assert.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
|
||||
assert.True(parsed.NotBefore.Before(time.Now()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsulCAProvider_CrossSignCA(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf1 := testConsulCAConfig()
|
||||
delegate1 := newMockDelegate(t, conf1)
|
||||
provider1, err := NewConsulProvider(conf1.Config, delegate1)
|
||||
require.NoError(t, err)
|
||||
|
||||
conf2 := testConsulCAConfig()
|
||||
conf2.CreateIndex = 10
|
||||
delegate2 := newMockDelegate(t, conf2)
|
||||
provider2, err := NewConsulProvider(conf2.Config, delegate2)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCrossSignProviders(t, provider1, provider2)
|
||||
}
|
||||
|
||||
func testCrossSignProviders(t *testing.T, provider1, provider2 Provider) {
|
||||
require := require.New(t)
|
||||
|
||||
// Get the root from the new provider to be cross-signed.
|
||||
newRootPEM, err := provider2.ActiveRoot()
|
||||
require.NoError(err)
|
||||
newRoot, err := connect.ParseCert(newRootPEM)
|
||||
require.NoError(err)
|
||||
oldSubject := newRoot.Subject.CommonName
|
||||
|
||||
newInterPEM, err := provider2.ActiveIntermediate()
|
||||
require.NoError(err)
|
||||
newIntermediate, err := connect.ParseCert(newInterPEM)
|
||||
require.NoError(err)
|
||||
|
||||
// Have provider1 cross sign our new root cert.
|
||||
xcPEM, err := provider1.CrossSignCA(newRoot)
|
||||
require.NoError(err)
|
||||
xc, err := connect.ParseCert(xcPEM)
|
||||
require.NoError(err)
|
||||
|
||||
oldRootPEM, err := provider1.ActiveRoot()
|
||||
require.NoError(err)
|
||||
oldRoot, err := connect.ParseCert(oldRootPEM)
|
||||
require.NoError(err)
|
||||
|
||||
// AuthorityKeyID should now be the signing root's, SubjectKeyId should be kept.
|
||||
require.Equal(oldRoot.AuthorityKeyId, xc.AuthorityKeyId)
|
||||
require.Equal(newRoot.SubjectKeyId, xc.SubjectKeyId)
|
||||
|
||||
// Subject name should not have changed.
|
||||
require.Equal(oldSubject, xc.Subject.CommonName)
|
||||
|
||||
// Issuer should be the signing root.
|
||||
require.Equal(oldRoot.Issuer.CommonName, xc.Issuer.CommonName)
|
||||
|
||||
// Get a leaf cert so we can verify against the cross-signed cert.
|
||||
spiffeService := &connect.SpiffeIDService{
|
||||
Host: "node1",
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: "foo",
|
||||
}
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
leafCsr, err := connect.ParseCSR(raw)
|
||||
require.NoError(err)
|
||||
|
||||
leafPEM, err := provider2.Sign(leafCsr)
|
||||
require.NoError(err)
|
||||
|
||||
cert, err := connect.ParseCert(leafPEM)
|
||||
require.NoError(err)
|
||||
|
||||
// Check that the leaf signed by the new cert can be verified by either root
|
||||
// certificate by using the new intermediate + cross-signed cert.
|
||||
intermediatePool := x509.NewCertPool()
|
||||
intermediatePool.AddCert(newIntermediate)
|
||||
intermediatePool.AddCert(xc)
|
||||
|
||||
for _, root := range []*x509.Certificate{oldRoot, newRoot} {
|
||||
rootPool := x509.NewCertPool()
|
||||
rootPool.AddCert(root)
|
||||
|
||||
_, err = cert.Verify(x509.VerifyOptions{
|
||||
Intermediates: intermediatePool,
|
||||
Roots: rootPool,
|
||||
})
|
||||
require.NoError(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,322 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const VaultCALeafCertRole = "leaf-cert"
|
||||
|
||||
var ErrBackendNotMounted = fmt.Errorf("backend not mounted")
|
||||
var ErrBackendNotInitialized = fmt.Errorf("backend not initialized")
|
||||
|
||||
type VaultProvider struct {
|
||||
config *structs.VaultCAProviderConfig
|
||||
client *vaultapi.Client
|
||||
clusterId string
|
||||
}
|
||||
|
||||
// NewVaultProvider returns a vault provider with its root and intermediate PKI
|
||||
// backends mounted and initialized. If the root backend is not set up already,
|
||||
// it will be mounted/generated as needed, but any existing state will not be
|
||||
// overwritten.
|
||||
func NewVaultProvider(rawConfig map[string]interface{}, clusterId string) (*VaultProvider, error) {
|
||||
conf, err := ParseVaultCAConfig(rawConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// todo(kyhavlov): figure out the right way to pass the TLS config
|
||||
clientConf := &vaultapi.Config{
|
||||
Address: conf.Address,
|
||||
}
|
||||
client, err := vaultapi.NewClient(clientConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.SetToken(conf.Token)
|
||||
|
||||
provider := &VaultProvider{
|
||||
config: conf,
|
||||
client: client,
|
||||
clusterId: clusterId,
|
||||
}
|
||||
|
||||
// Set up the root PKI backend if necessary.
|
||||
_, err = provider.ActiveRoot()
|
||||
switch err {
|
||||
case ErrBackendNotMounted:
|
||||
err := client.Sys().Mount(conf.RootPKIPath, &vaultapi.MountInput{
|
||||
Type: "pki",
|
||||
Description: "root CA backend for Consul Connect",
|
||||
Config: vaultapi.MountConfigInput{
|
||||
MaxLeaseTTL: "8760h",
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fallthrough
|
||||
case ErrBackendNotInitialized:
|
||||
spiffeID := connect.SpiffeIDSigning{ClusterID: clusterId, Domain: "consul"}
|
||||
uuid, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = client.Logical().Write(conf.RootPKIPath+"root/generate/internal", map[string]interface{}{
|
||||
"common_name": fmt.Sprintf("Vault CA Root Authority %s", uuid),
|
||||
"uri_sans": spiffeID.URI().String(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Set up the intermediate backend.
|
||||
if _, err := provider.GenerateIntermediate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (v *VaultProvider) ActiveRoot() (string, error) {
|
||||
return v.getCA(v.config.RootPKIPath)
|
||||
}
|
||||
|
||||
func (v *VaultProvider) ActiveIntermediate() (string, error) {
|
||||
return v.getCA(v.config.IntermediatePKIPath)
|
||||
}
|
||||
|
||||
// getCA returns the raw CA cert for the given endpoint if there is one.
|
||||
// We have to use the raw NewRequest call here instead of Logical().Read
|
||||
// because the endpoint only returns the raw PEM contents of the CA cert
|
||||
// and not the typical format of the secrets endpoints.
|
||||
func (v *VaultProvider) getCA(path string) (string, error) {
|
||||
req := v.client.NewRequest("GET", "/v1/"+path+"/ca/pem")
|
||||
resp, err := v.client.RawRequest(req)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if resp != nil && resp.StatusCode == http.StatusNotFound {
|
||||
return "", ErrBackendNotMounted
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bytes, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
root := string(bytes)
|
||||
if root == "" {
|
||||
return "", ErrBackendNotInitialized
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// GenerateIntermediate mounts the configured intermediate PKI backend if
|
||||
// necessary, then generates and signs a new CA CSR using the root PKI backend
|
||||
// and updates the intermediate backend to use that new certificate.
|
||||
func (v *VaultProvider) GenerateIntermediate() (string, error) {
|
||||
mounts, err := v.client.Sys().ListMounts()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Mount the backend if it isn't mounted already.
|
||||
if _, ok := mounts[v.config.IntermediatePKIPath]; !ok {
|
||||
err := v.client.Sys().Mount(v.config.IntermediatePKIPath, &vaultapi.MountInput{
|
||||
Type: "pki",
|
||||
Description: "intermediate CA backend for Consul Connect",
|
||||
Config: vaultapi.MountConfigInput{
|
||||
MaxLeaseTTL: "2160h",
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Create the role for issuing leaf certs if it doesn't exist yet
|
||||
rolePath := v.config.IntermediatePKIPath + "roles/" + VaultCALeafCertRole
|
||||
role, err := v.client.Logical().Read(rolePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
spiffeID := connect.SpiffeIDSigning{ClusterID: v.clusterId, Domain: "consul"}
|
||||
if role == nil {
|
||||
_, err := v.client.Logical().Write(rolePath, map[string]interface{}{
|
||||
"allow_any_name": true,
|
||||
"allowed_uri_sans": "spiffe://*",
|
||||
"key_type": "any",
|
||||
"max_ttl": "72h",
|
||||
"require_cn": false,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new intermediate CSR for the root to sign.
|
||||
csr, err := v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/generate/internal", map[string]interface{}{
|
||||
"common_name": "Vault CA Intermediate Authority",
|
||||
"uri_sans": spiffeID.URI().String(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if csr == nil || csr.Data["csr"] == "" {
|
||||
return "", fmt.Errorf("got empty value when generating intermediate CSR")
|
||||
}
|
||||
|
||||
// Sign the CSR with the root backend.
|
||||
intermediate, err := v.client.Logical().Write(v.config.RootPKIPath+"root/sign-intermediate", map[string]interface{}{
|
||||
"csr": csr.Data["csr"],
|
||||
"format": "pem_bundle",
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if intermediate == nil || intermediate.Data["certificate"] == "" {
|
||||
return "", fmt.Errorf("got empty value when generating intermediate certificate")
|
||||
}
|
||||
|
||||
// Set the intermediate backend to use the new certificate.
|
||||
_, err = v.client.Logical().Write(v.config.IntermediatePKIPath+"intermediate/set-signed", map[string]interface{}{
|
||||
"certificate": intermediate.Data["certificate"],
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return v.ActiveIntermediate()
|
||||
}
|
||||
|
||||
// Sign calls the configured role in the intermediate PKI backend to issue
|
||||
// a new leaf certificate based on the provided CSR, with the issuing
|
||||
// intermediate CA cert attached.
|
||||
func (v *VaultProvider) Sign(csr *x509.CertificateRequest) (string, error) {
|
||||
var pemBuf bytes.Buffer
|
||||
if err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Use the leaf cert role to sign a new cert for this CSR.
|
||||
response, err := v.client.Logical().Write(v.config.IntermediatePKIPath+"sign/"+VaultCALeafCertRole, map[string]interface{}{
|
||||
"csr": pemBuf.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error issuing cert: %v", err)
|
||||
}
|
||||
if response == nil || response.Data["certificate"] == "" || response.Data["issuing_ca"] == "" {
|
||||
return "", fmt.Errorf("certificate info returned from Vault was blank")
|
||||
}
|
||||
|
||||
cert, ok := response.Data["certificate"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("certificate was not a string")
|
||||
}
|
||||
ca, ok := response.Data["issuing_ca"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("issuing_ca was not a string")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s\n%s", cert, ca), nil
|
||||
}
|
||||
|
||||
// CrossSignCA takes a CA certificate and cross-signs it to form a trust chain
|
||||
// back to our active root.
|
||||
func (v *VaultProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
|
||||
var pemBuf bytes.Buffer
|
||||
err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Have the root PKI backend sign this cert.
|
||||
response, err := v.client.Logical().Write(v.config.RootPKIPath+"root/sign-self-issued", map[string]interface{}{
|
||||
"certificate": pemBuf.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error having Vault cross-sign cert: %v", err)
|
||||
}
|
||||
if response == nil || response.Data["certificate"] == "" {
|
||||
return "", fmt.Errorf("certificate info returned from Vault was blank")
|
||||
}
|
||||
|
||||
xcCert, ok := response.Data["certificate"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("certificate was not a string")
|
||||
}
|
||||
|
||||
return xcCert, nil
|
||||
}
|
||||
|
||||
// Cleanup unmounts the configured intermediate PKI backend. It's fine to tear
|
||||
// this down and recreate it on small config changes because the intermediate
|
||||
// certs get bundled with the leaf certs, so there's no cost to the CA changing.
|
||||
func (v *VaultProvider) Cleanup() error {
|
||||
return v.client.Sys().Unmount(v.config.IntermediatePKIPath)
|
||||
}
|
||||
|
||||
func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) {
|
||||
var config structs.VaultCAProviderConfig
|
||||
|
||||
decodeConf := &mapstructure.DecoderConfig{
|
||||
ErrorUnused: true,
|
||||
Result: &config,
|
||||
WeaklyTypedInput: true,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(decodeConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := decoder.Decode(raw); err != nil {
|
||||
return nil, fmt.Errorf("error decoding config: %s", err)
|
||||
}
|
||||
|
||||
if config.Token == "" {
|
||||
return nil, fmt.Errorf("must provide a Vault token")
|
||||
}
|
||||
|
||||
if config.RootPKIPath == "" {
|
||||
return nil, fmt.Errorf("must provide a valid path to a root PKI backend")
|
||||
}
|
||||
if !strings.HasSuffix(config.RootPKIPath, "/") {
|
||||
config.RootPKIPath += "/"
|
||||
}
|
||||
|
||||
if config.IntermediatePKIPath == "" {
|
||||
return nil, fmt.Errorf("must provide a valid path for the intermediate PKI backend")
|
||||
}
|
||||
if !strings.HasSuffix(config.IntermediatePKIPath, "/") {
|
||||
config.IntermediatePKIPath += "/"
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/pki"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testVaultCluster(t *testing.T) (*VaultProvider, *vault.Core, net.Listener) {
|
||||
if err := vault.AddTestLogicalBackend("pki", pki.Factory); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
core, _, token := vault.TestCoreUnsealedRaw(t)
|
||||
|
||||
ln, addr := vaulthttp.TestServer(t, core)
|
||||
|
||||
provider, err := NewVaultProvider(map[string]interface{}{
|
||||
"Address": addr,
|
||||
"Token": token,
|
||||
"RootPKIPath": "pki-root/",
|
||||
"IntermediatePKIPath": "pki-intermediate/",
|
||||
}, "asdf")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return provider, core, ln
|
||||
}
|
||||
|
||||
func TestVaultCAProvider_Bootstrap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
provider, core, listener := testVaultCluster(t)
|
||||
defer core.Shutdown()
|
||||
defer listener.Close()
|
||||
client, err := vaultapi.NewClient(&vaultapi.Config{
|
||||
Address: "http://" + listener.Addr().String(),
|
||||
})
|
||||
require.NoError(err)
|
||||
client.SetToken(provider.config.Token)
|
||||
|
||||
cases := []struct {
|
||||
certFunc func() (string, error)
|
||||
backendPath string
|
||||
}{
|
||||
{
|
||||
certFunc: provider.ActiveRoot,
|
||||
backendPath: "pki-root/",
|
||||
},
|
||||
{
|
||||
certFunc: provider.ActiveIntermediate,
|
||||
backendPath: "pki-intermediate/",
|
||||
},
|
||||
}
|
||||
|
||||
// Verify the root and intermediate certs match the ones in the vault backends
|
||||
for _, tc := range cases {
|
||||
cert, err := tc.certFunc()
|
||||
require.NoError(err)
|
||||
req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem")
|
||||
resp, err := client.RawRequest(req)
|
||||
require.NoError(err)
|
||||
bytes, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(err)
|
||||
require.Equal(cert, string(bytes))
|
||||
|
||||
// Should be a valid CA cert
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
require.NoError(err)
|
||||
require.True(parsed.IsCA)
|
||||
require.Len(parsed.URIs, 1)
|
||||
require.Equal(parsed.URIs[0].String(), fmt.Sprintf("spiffe://%s.consul", provider.clusterId))
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultCAProvider_SignLeaf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
provider, core, listener := testVaultCluster(t)
|
||||
defer core.Shutdown()
|
||||
defer listener.Close()
|
||||
client, err := vaultapi.NewClient(&vaultapi.Config{
|
||||
Address: "http://" + listener.Addr().String(),
|
||||
})
|
||||
require.NoError(err)
|
||||
client.SetToken(provider.config.Token)
|
||||
|
||||
spiffeService := &connect.SpiffeIDService{
|
||||
Host: "node1",
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: "foo",
|
||||
}
|
||||
|
||||
// Generate a leaf cert for the service.
|
||||
var firstSerial uint64
|
||||
{
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
csr, err := connect.ParseCSR(raw)
|
||||
require.NoError(err)
|
||||
|
||||
cert, err := provider.Sign(csr)
|
||||
require.NoError(err)
|
||||
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
require.NoError(err)
|
||||
require.Equal(parsed.URIs[0], spiffeService.URI())
|
||||
firstSerial = parsed.SerialNumber.Uint64()
|
||||
|
||||
// Ensure the cert is valid now and expires within the correct limit.
|
||||
require.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
|
||||
require.True(parsed.NotBefore.Before(time.Now()))
|
||||
}
|
||||
|
||||
// Generate a new cert for another service and make sure
|
||||
// the serial number is unique.
|
||||
spiffeService.Service = "bar"
|
||||
{
|
||||
raw, _ := connect.TestCSR(t, spiffeService)
|
||||
|
||||
csr, err := connect.ParseCSR(raw)
|
||||
require.NoError(err)
|
||||
|
||||
cert, err := provider.Sign(csr)
|
||||
require.NoError(err)
|
||||
|
||||
parsed, err := connect.ParseCert(cert)
|
||||
require.NoError(err)
|
||||
require.Equal(parsed.URIs[0], spiffeService.URI())
|
||||
require.NotEqual(firstSerial, parsed.SerialNumber.Uint64())
|
||||
|
||||
// Ensure the cert is valid now and expires within the correct limit.
|
||||
require.True(parsed.NotAfter.Sub(time.Now()) < 3*24*time.Hour)
|
||||
require.True(parsed.NotBefore.Before(time.Now()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultCAProvider_CrossSignCA(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider1, core1, listener1 := testVaultCluster(t)
|
||||
defer core1.Shutdown()
|
||||
defer listener1.Close()
|
||||
|
||||
provider2, core2, listener2 := testVaultCluster(t)
|
||||
defer core2.Shutdown()
|
||||
defer listener2.Close()
|
||||
|
||||
testCrossSignProviders(t, provider1, provider2)
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// CreateCSR returns a CSR to sign the given service along with the PEM-encoded
|
||||
// private key for this certificate.
|
||||
func CreateCSR(uri CertURI, privateKey crypto.Signer) (string, error) {
|
||||
template := &x509.CertificateRequest{
|
||||
URIs: []*url.URL{uri.URI()},
|
||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
}
|
||||
|
||||
// Create the CSR itself
|
||||
var csrBuf bytes.Buffer
|
||||
bs, err := x509.CreateCertificateRequest(rand.Reader, template, privateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = pem.Encode(&csrBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: bs})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return csrBuf.String(), nil
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GeneratePrivateKey generates a new Private key
|
||||
func GeneratePrivateKey() (crypto.Signer, string, error) {
|
||||
var pk *ecdsa.PrivateKey
|
||||
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error generating private key: %s", err)
|
||||
}
|
||||
|
||||
bs, err := x509.MarshalECPrivateKey(pk)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error generating private key: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: bs})
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return pk, buf.String(), nil
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseCert parses the x509 certificate from a PEM-encoded value.
|
||||
func ParseCert(pemValue string) (*x509.Certificate, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE type")
|
||||
}
|
||||
|
||||
return x509.ParseCertificate(block.Bytes)
|
||||
}
|
||||
|
||||
// CalculateCertFingerprint parses the x509 certificate from a PEM-encoded value
|
||||
// and calculates the SHA-1 fingerprint.
|
||||
func CalculateCertFingerprint(pemValue string) (string, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return "", fmt.Errorf("first PEM-block should be CERTIFICATE type")
|
||||
}
|
||||
|
||||
hash := sha1.Sum(block.Bytes)
|
||||
return HexString(hash[:]), nil
|
||||
}
|
||||
|
||||
// ParseSigner parses a crypto.Signer from a PEM-encoded key. The private key
|
||||
// is expected to be the first block in the PEM value.
|
||||
func ParseSigner(pemValue string) (crypto.Signer, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
|
||||
case "PRIVATE KEY":
|
||||
signer, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pk, ok := signer.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("private key is not a valid format")
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown PEM block type for signing key: %s", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseCSR parses a CSR from a PEM-encoded value. The certificate request
|
||||
// must be the the first block in the PEM value.
|
||||
func ParseCSR(pemValue string) (*x509.CertificateRequest, error) {
|
||||
// The _ result below is not an error but the remaining PEM bytes.
|
||||
block, _ := pem.Decode([]byte(pemValue))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM-encoded data found")
|
||||
}
|
||||
|
||||
if block.Type != "CERTIFICATE REQUEST" {
|
||||
return nil, fmt.Errorf("first PEM-block should be CERTIFICATE REQUEST type")
|
||||
}
|
||||
|
||||
return x509.ParseCertificateRequest(block.Bytes)
|
||||
}
|
||||
|
||||
// KeyId returns a x509 KeyId from the given signing key. The key must be
|
||||
// an *ecdsa.PublicKey currently, but may support more types in the future.
|
||||
func KeyId(raw interface{}) ([]byte, error) {
|
||||
switch raw.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid key type: %T", raw)
|
||||
}
|
||||
|
||||
// This is not standard; RFC allows any unique identifier as long as they
|
||||
// match in subject/authority chains but suggests specific hashing of DER
|
||||
// bytes of public key including DER tags.
|
||||
bs, err := x509.MarshalPKIXPublicKey(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// String formatted
|
||||
kID := sha256.Sum256(bs)
|
||||
return []byte(strings.Replace(fmt.Sprintf("% x", kID), " ", ":", -1)), nil
|
||||
}
|
||||
|
||||
// HexString returns a standard colon-separated hex value for the input
|
||||
// byte slice. This should be used with cert serial numbers and so on.
|
||||
func HexString(input []byte) string {
|
||||
return strings.Replace(fmt.Sprintf("% x", input), " ", ":", -1)
|
||||
}
|
|
@ -0,0 +1,332 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
)
|
||||
|
||||
// TestClusterID is the Consul cluster ID for testing.
|
||||
const TestClusterID = "11111111-2222-3333-4444-555555555555"
|
||||
|
||||
// testCACounter is just an atomically incremented counter for creating
|
||||
// unique names for the CA certs.
|
||||
var testCACounter uint64
|
||||
|
||||
// TestCA creates a test CA certificate and signing key and returns it
|
||||
// in the CARoot structure format. The returned CA will be set as Active = true.
|
||||
//
|
||||
// If xc is non-nil, then the returned certificate will have a signing cert
|
||||
// that is cross-signed with the previous cert, and this will be set as
|
||||
// SigningCert.
|
||||
func TestCA(t testing.T, xc *structs.CARoot) *structs.CARoot {
|
||||
var result structs.CARoot
|
||||
result.Active = true
|
||||
result.Name = fmt.Sprintf("Test CA %d", atomic.AddUint64(&testCACounter, 1))
|
||||
|
||||
// Create the private key we'll use for this CA cert.
|
||||
signer, keyPEM := testPrivateKey(t)
|
||||
result.SigningKey = keyPEM
|
||||
|
||||
// The serial number for the cert
|
||||
sn, err := testSerialNumber()
|
||||
if err != nil {
|
||||
t.Fatalf("error generating serial number: %s", err)
|
||||
}
|
||||
|
||||
// The URI (SPIFFE compatible) for the cert
|
||||
id := &SpiffeIDSigning{ClusterID: TestClusterID, Domain: "consul"}
|
||||
|
||||
// Create the CA cert
|
||||
template := x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{CommonName: result.Name},
|
||||
URIs: []*url.URL{id.URI()},
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageCertSign |
|
||||
x509.KeyUsageCRLSign |
|
||||
x509.KeyUsageDigitalSignature,
|
||||
IsCA: true,
|
||||
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
|
||||
NotBefore: time.Now(),
|
||||
AuthorityKeyId: testKeyID(t, signer.Public()),
|
||||
SubjectKeyId: testKeyID(t, signer.Public()),
|
||||
}
|
||||
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, &template, signer.Public(), signer)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding private key: %s", err)
|
||||
}
|
||||
result.RootCert = buf.String()
|
||||
result.ID, err = CalculateCertFingerprint(result.RootCert)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating CA ID fingerprint: %s", err)
|
||||
}
|
||||
|
||||
// If there is a prior CA to cross-sign with, then we need to create that
|
||||
// and set it as the signing cert.
|
||||
if xc != nil {
|
||||
xccert, err := ParseCert(xc.RootCert)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing CA cert: %s", err)
|
||||
}
|
||||
xcsigner, err := ParseSigner(xc.SigningKey)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing signing key: %s", err)
|
||||
}
|
||||
|
||||
// Set the authority key to be the previous one.
|
||||
// NOTE(mitchellh): From Paul Banks: if we have to cross-sign a cert
|
||||
// that came from outside (e.g. vault) we can't rely on them using the
|
||||
// same KeyID hashing algo we do so we'd need to actually copy this
|
||||
// from the xc cert's subjectKeyIdentifier extension.
|
||||
template.AuthorityKeyId = testKeyID(t, xcsigner.Public())
|
||||
|
||||
// Create the new certificate where the parent is the previous
|
||||
// CA, the public key is the new public key, and the signing private
|
||||
// key is the old private key.
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, xccert, signer.Public(), xcsigner)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating CA certificate: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding private key: %s", err)
|
||||
}
|
||||
result.SigningCert = buf.String()
|
||||
}
|
||||
|
||||
return &result
|
||||
}
|
||||
|
||||
// TestLeaf returns a valid leaf certificate and it's private key for the named
|
||||
// service with the given CA Root.
|
||||
func TestLeaf(t testing.T, service string, root *structs.CARoot) (string, string) {
|
||||
// Parse the CA cert and signing key from the root
|
||||
cert := root.SigningCert
|
||||
if cert == "" {
|
||||
cert = root.RootCert
|
||||
}
|
||||
caCert, err := ParseCert(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing CA cert: %s", err)
|
||||
}
|
||||
caSigner, err := ParseSigner(root.SigningKey)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing signing key: %s", err)
|
||||
}
|
||||
|
||||
// Build the SPIFFE ID
|
||||
spiffeId := &SpiffeIDService{
|
||||
Host: fmt.Sprintf("%s.consul", TestClusterID),
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: service,
|
||||
}
|
||||
|
||||
// The serial number for the cert
|
||||
sn, err := testSerialNumber()
|
||||
if err != nil {
|
||||
t.Fatalf("error generating serial number: %s", err)
|
||||
}
|
||||
|
||||
// Generate fresh private key
|
||||
pkSigner, pkPEM, err := GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key: %s", err)
|
||||
}
|
||||
|
||||
// Cert template for generation
|
||||
template := x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{CommonName: service},
|
||||
URIs: []*url.URL{spiffeId.URI()},
|
||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDataEncipherment |
|
||||
x509.KeyUsageKeyAgreement |
|
||||
x509.KeyUsageDigitalSignature |
|
||||
x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
|
||||
NotBefore: time.Now(),
|
||||
AuthorityKeyId: testKeyID(t, caSigner.Public()),
|
||||
SubjectKeyId: testKeyID(t, pkSigner.Public()),
|
||||
}
|
||||
|
||||
// Create the certificate, PEM encode it and return that value.
|
||||
var buf bytes.Buffer
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, caCert, pkSigner.Public(), caSigner)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating certificate: %s", err)
|
||||
}
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return buf.String(), pkPEM
|
||||
}
|
||||
|
||||
// TestCSR returns a CSR to sign the given service along with the PEM-encoded
|
||||
// private key for this certificate.
|
||||
func TestCSR(t testing.T, uri CertURI) (string, string) {
|
||||
template := &x509.CertificateRequest{
|
||||
URIs: []*url.URL{uri.URI()},
|
||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
}
|
||||
|
||||
// Create the private key we'll use
|
||||
signer, pkPEM := testPrivateKey(t)
|
||||
|
||||
// Create the CSR itself
|
||||
var csrBuf bytes.Buffer
|
||||
bs, err := x509.CreateCertificateRequest(rand.Reader, template, signer)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating CSR: %s", err)
|
||||
}
|
||||
|
||||
err = pem.Encode(&csrBuf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding CSR: %s", err)
|
||||
}
|
||||
|
||||
return csrBuf.String(), pkPEM
|
||||
}
|
||||
|
||||
// testKeyID returns a KeyID from the given public key. This just calls
|
||||
// KeyId but handles errors for tests.
|
||||
func testKeyID(t testing.T, raw interface{}) []byte {
|
||||
result, err := KeyId(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("KeyId error: %s", err)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// testPrivateKey creates an ECDSA based private key. Both a crypto.Signer and
|
||||
// the key in PEM form are returned.
|
||||
//
|
||||
// NOTE(banks): this was memoized to save entropy during tests but it turns out
|
||||
// crypto/rand will never block and always reads from /dev/urandom on unix OSes
|
||||
// which does not consume entropy.
|
||||
//
|
||||
// If we find by profiling it's taking a lot of cycles we could optimise/cache
|
||||
// again but we at least need to use different keys for each distinct CA (when
|
||||
// multiple CAs are generated at once e.g. to test cross-signing) and a
|
||||
// different one again for the leafs otherwise we risk tests that have false
|
||||
// positives since signatures from different logical cert's keys are
|
||||
// indistinguishable, but worse we build validation chains using AuthorityKeyID
|
||||
// which will be the same for multiple CAs/Leafs. Also note that our UUID
|
||||
// generator also reads from crypto rand and is called far more often during
|
||||
// tests than this will be.
|
||||
func testPrivateKey(t testing.T) (crypto.Signer, string) {
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating private key: %s", err)
|
||||
}
|
||||
|
||||
bs, err := x509.MarshalECPrivateKey(pk)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating private key: %s", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: bs})
|
||||
if err != nil {
|
||||
t.Fatalf("error encoding private key: %s", err)
|
||||
}
|
||||
|
||||
return pk, buf.String()
|
||||
}
|
||||
|
||||
// testSerialNumber generates a serial number suitable for a certificate.
|
||||
// For testing, this just sets it to a random number.
|
||||
//
|
||||
// This function is taken directly from the Vault implementation.
|
||||
func testSerialNumber() (*big.Int, error) {
|
||||
return rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil))
|
||||
}
|
||||
|
||||
// testUUID generates a UUID for testing.
|
||||
func testUUID(t testing.T) string {
|
||||
ret, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate a UUID, %s", err)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// TestAgentRPC is an interface that an RPC client must implement. This is a
|
||||
// helper interface that is implemented by the agent delegate so that test
|
||||
// helpers can make RPCs without introducing an import cycle on `agent`.
|
||||
type TestAgentRPC interface {
|
||||
RPC(method string, args interface{}, reply interface{}) error
|
||||
}
|
||||
|
||||
// TestCAConfigSet sets a CARoot returned by TestCA into the TestAgent state. It
|
||||
// requires that TestAgent had connect enabled in it's config. If ca is nil, a
|
||||
// new CA is created.
|
||||
//
|
||||
// It returns the CARoot passed or created.
|
||||
//
|
||||
// Note that we have to use an interface for the TestAgent.RPC method since we
|
||||
// can't introduce an import cycle by importing `agent.TestAgent` here directly.
|
||||
// It also means this will work in a few other places we mock that method.
|
||||
func TestCAConfigSet(t testing.T, a TestAgentRPC,
|
||||
ca *structs.CARoot) *structs.CARoot {
|
||||
t.Helper()
|
||||
|
||||
if ca == nil {
|
||||
ca = TestCA(t, nil)
|
||||
}
|
||||
newConfig := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": ca.SigningKey,
|
||||
"RootCert": ca.RootCert,
|
||||
"RotationPeriod": 180 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
args := &structs.CARequest{
|
||||
Datacenter: "dc1",
|
||||
Config: newConfig,
|
||||
}
|
||||
var reply interface{}
|
||||
|
||||
err := a.RPC("ConnectCA.ConfigurationSet", args, &reply)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set test CA config: %s", err)
|
||||
}
|
||||
return ca
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// hasOpenSSL is used to determine if the openssl CLI exists for unit tests.
|
||||
var hasOpenSSL bool
|
||||
|
||||
func init() {
|
||||
_, err := exec.LookPath("openssl")
|
||||
hasOpenSSL = err == nil
|
||||
}
|
||||
|
||||
// Test that the TestCA and TestLeaf functions generate valid certificates.
|
||||
func TestTestCAAndLeaf(t *testing.T) {
|
||||
if !hasOpenSSL {
|
||||
t.Skip("openssl not found")
|
||||
return
|
||||
}
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
// Create the certs
|
||||
ca := TestCA(t, nil)
|
||||
leaf, _ := TestLeaf(t, "web", ca)
|
||||
|
||||
// Create a temporary directory for storing the certs
|
||||
td, err := ioutil.TempDir("", "consul")
|
||||
assert.Nil(err)
|
||||
defer os.RemoveAll(td)
|
||||
|
||||
// Write the cert
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), []byte(ca.RootCert), 0644))
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf.pem"), []byte(leaf), 0644))
|
||||
|
||||
// Use OpenSSL to verify so we have an external, known-working process
|
||||
// that can verify this outside of our own implementations.
|
||||
cmd := exec.Command(
|
||||
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf.pem")
|
||||
cmd.Dir = td
|
||||
output, err := cmd.Output()
|
||||
t.Log(string(output))
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
// Test cross-signing.
|
||||
func TestTestCAAndLeaf_xc(t *testing.T) {
|
||||
if !hasOpenSSL {
|
||||
t.Skip("openssl not found")
|
||||
return
|
||||
}
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
// Create the certs
|
||||
ca1 := TestCA(t, nil)
|
||||
ca2 := TestCA(t, ca1)
|
||||
leaf1, _ := TestLeaf(t, "web", ca1)
|
||||
leaf2, _ := TestLeaf(t, "web", ca2)
|
||||
|
||||
// Create a temporary directory for storing the certs
|
||||
td, err := ioutil.TempDir("", "consul")
|
||||
assert.Nil(err)
|
||||
defer os.RemoveAll(td)
|
||||
|
||||
// Write the cert
|
||||
xcbundle := []byte(ca1.RootCert)
|
||||
xcbundle = append(xcbundle, '\n')
|
||||
xcbundle = append(xcbundle, []byte(ca2.SigningCert)...)
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "ca.pem"), xcbundle, 0644))
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf1.pem"), []byte(leaf1), 0644))
|
||||
assert.Nil(ioutil.WriteFile(filepath.Join(td, "leaf2.pem"), []byte(leaf2), 0644))
|
||||
|
||||
// OpenSSL verify the cross-signed leaf (leaf2)
|
||||
{
|
||||
cmd := exec.Command(
|
||||
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf2.pem")
|
||||
cmd.Dir = td
|
||||
output, err := cmd.Output()
|
||||
t.Log(string(output))
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
// OpenSSL verify the old leaf (leaf1)
|
||||
{
|
||||
cmd := exec.Command(
|
||||
"openssl", "verify", "-verbose", "-CAfile", "ca.pem", "leaf1.pem")
|
||||
cmd.Dir = td
|
||||
output, err := cmd.Output()
|
||||
t.Log(string(output))
|
||||
assert.Nil(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
)
|
||||
|
||||
// TestSpiffeIDService returns a SPIFFE ID representing a service.
|
||||
func TestSpiffeIDService(t testing.T, service string) *SpiffeIDService {
|
||||
return TestSpiffeIDServiceWithHost(t, service, TestClusterID+".consul")
|
||||
}
|
||||
|
||||
// TestSpiffeIDServiceWithHost returns a SPIFFE ID representing a service with
|
||||
// the specified trust domain.
|
||||
func TestSpiffeIDServiceWithHost(t testing.T, service, host string) *SpiffeIDService {
|
||||
return &SpiffeIDService{
|
||||
Host: host,
|
||||
Namespace: "default",
|
||||
Datacenter: "dc1",
|
||||
Service: service,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// CertURI represents a Connect-valid URI value for a TLS certificate.
|
||||
// The user should type switch on the various implementations in this
|
||||
// package to determine the type of URI and the data encoded within it.
|
||||
//
|
||||
// Note that the current implementations of this are all also SPIFFE IDs.
|
||||
// However, we anticipate that we may accept URIs that are also not SPIFFE
|
||||
// compliant and therefore the interface is named as such.
|
||||
type CertURI interface {
|
||||
// Authorize tests the authorization for this URI as a client
|
||||
// for the given intention. The return value `auth` is only valid if
|
||||
// the second value `match` is true. If the second value `match` is
|
||||
// false, then the intention doesn't match this client and any
|
||||
// result should be ignored.
|
||||
Authorize(*structs.Intention) (auth bool, match bool)
|
||||
|
||||
// URI is the valid URI value used in the cert.
|
||||
URI() *url.URL
|
||||
}
|
||||
|
||||
var (
|
||||
spiffeIDServiceRegexp = regexp.MustCompile(
|
||||
`^/ns/([^/]+)/dc/([^/]+)/svc/([^/]+)$`)
|
||||
)
|
||||
|
||||
// ParseCertURI parses a the URI value from a TLS certificate.
|
||||
func ParseCertURI(input *url.URL) (CertURI, error) {
|
||||
if input.Scheme != "spiffe" {
|
||||
return nil, fmt.Errorf("SPIFFE ID must have 'spiffe' scheme")
|
||||
}
|
||||
|
||||
// Path is the raw value of the path without url decoding values.
|
||||
// RawPath is empty if there were no encoded values so we must
|
||||
// check both.
|
||||
path := input.Path
|
||||
if input.RawPath != "" {
|
||||
path = input.RawPath
|
||||
}
|
||||
|
||||
// Test for service IDs
|
||||
if v := spiffeIDServiceRegexp.FindStringSubmatch(path); v != nil {
|
||||
// Determine the values. We assume they're sane to save cycles,
|
||||
// but if the raw path is not empty that means that something is
|
||||
// URL encoded so we go to the slow path.
|
||||
ns := v[1]
|
||||
dc := v[2]
|
||||
service := v[3]
|
||||
if input.RawPath != "" {
|
||||
var err error
|
||||
if ns, err = url.PathUnescape(v[1]); err != nil {
|
||||
return nil, fmt.Errorf("Invalid namespace: %s", err)
|
||||
}
|
||||
if dc, err = url.PathUnescape(v[2]); err != nil {
|
||||
return nil, fmt.Errorf("Invalid datacenter: %s", err)
|
||||
}
|
||||
if service, err = url.PathUnescape(v[3]); err != nil {
|
||||
return nil, fmt.Errorf("Invalid service: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &SpiffeIDService{
|
||||
Host: input.Host,
|
||||
Namespace: ns,
|
||||
Datacenter: dc,
|
||||
Service: service,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Test for signing ID
|
||||
if input.Path == "" {
|
||||
idx := strings.Index(input.Host, ".")
|
||||
if idx > 0 {
|
||||
return &SpiffeIDSigning{
|
||||
ClusterID: input.Host[:idx],
|
||||
Domain: input.Host[idx+1:],
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("SPIFFE ID is not in the expected format")
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// SpiffeIDService is the structure to represent the SPIFFE ID for a service.
|
||||
type SpiffeIDService struct {
|
||||
Host string
|
||||
Namespace string
|
||||
Datacenter string
|
||||
Service string
|
||||
}
|
||||
|
||||
// URI returns the *url.URL for this SPIFFE ID.
|
||||
func (id *SpiffeIDService) URI() *url.URL {
|
||||
var result url.URL
|
||||
result.Scheme = "spiffe"
|
||||
result.Host = id.Host
|
||||
result.Path = fmt.Sprintf("/ns/%s/dc/%s/svc/%s",
|
||||
id.Namespace, id.Datacenter, id.Service)
|
||||
return &result
|
||||
}
|
||||
|
||||
// CertURI impl.
|
||||
func (id *SpiffeIDService) Authorize(ixn *structs.Intention) (bool, bool) {
|
||||
if ixn.SourceNS != structs.IntentionWildcard && ixn.SourceNS != id.Namespace {
|
||||
// Non-matching namespace
|
||||
return false, false
|
||||
}
|
||||
|
||||
if ixn.SourceName != structs.IntentionWildcard && ixn.SourceName != id.Service {
|
||||
// Non-matching name
|
||||
return false, false
|
||||
}
|
||||
|
||||
// Match, return allow value
|
||||
return ixn.Action == structs.IntentionActionAllow, true
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSpiffeIDServiceAuthorize(t *testing.T) {
|
||||
ns := structs.IntentionDefaultNamespace
|
||||
serviceWeb := &SpiffeIDService{
|
||||
Host: "1234.consul",
|
||||
Namespace: structs.IntentionDefaultNamespace,
|
||||
Datacenter: "dc01",
|
||||
Service: "web",
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
URI *SpiffeIDService
|
||||
Ixn *structs.Intention
|
||||
Auth bool
|
||||
Match bool
|
||||
}{
|
||||
{
|
||||
"exact source, not matching namespace",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: "different",
|
||||
SourceName: "db",
|
||||
},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"exact source, not matching name",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: ns,
|
||||
SourceName: "db",
|
||||
},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"exact source, allow",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: serviceWeb.Namespace,
|
||||
SourceName: serviceWeb.Service,
|
||||
Action: structs.IntentionActionAllow,
|
||||
},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"exact source, deny",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: serviceWeb.Namespace,
|
||||
SourceName: serviceWeb.Service,
|
||||
Action: structs.IntentionActionDeny,
|
||||
},
|
||||
false,
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"exact namespace, wildcard service, deny",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: serviceWeb.Namespace,
|
||||
SourceName: structs.IntentionWildcard,
|
||||
Action: structs.IntentionActionDeny,
|
||||
},
|
||||
false,
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"exact namespace, wildcard service, allow",
|
||||
serviceWeb,
|
||||
&structs.Intention{
|
||||
SourceNS: serviceWeb.Namespace,
|
||||
SourceName: structs.IntentionWildcard,
|
||||
Action: structs.IntentionActionAllow,
|
||||
},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
auth, match := tc.URI.Authorize(tc.Ixn)
|
||||
assert.Equal(t, tc.Auth, auth)
|
||||
assert.Equal(t, tc.Match, match)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// SpiffeIDSigning is the structure to represent the SPIFFE ID for a
|
||||
// signing certificate (not a leaf service).
|
||||
type SpiffeIDSigning struct {
|
||||
ClusterID string // Unique cluster ID
|
||||
Domain string // The domain, usually "consul"
|
||||
}
|
||||
|
||||
// URI returns the *url.URL for this SPIFFE ID.
|
||||
func (id *SpiffeIDSigning) URI() *url.URL {
|
||||
var result url.URL
|
||||
result.Scheme = "spiffe"
|
||||
result.Host = id.Host()
|
||||
return &result
|
||||
}
|
||||
|
||||
// Host is the canonical representation as a DNS-compatible hostname.
|
||||
func (id *SpiffeIDSigning) Host() string {
|
||||
return strings.ToLower(fmt.Sprintf("%s.%s", id.ClusterID, id.Domain))
|
||||
}
|
||||
|
||||
// CertURI impl.
|
||||
func (id *SpiffeIDSigning) Authorize(ixn *structs.Intention) (bool, bool) {
|
||||
// Never authorize as a client.
|
||||
return false, true
|
||||
}
|
||||
|
||||
// CanSign takes any CertURI and returns whether or not this signing entity is
|
||||
// allowed to sign CSRs for that entity (i.e. represents the trust domain for
|
||||
// that entity).
|
||||
//
|
||||
// I choose to make this a fixed centralised method here for now rather than a
|
||||
// method on CertURI interface since we don't intend this to be extensible
|
||||
// outside and it's easier to reason about the security properties when they are
|
||||
// all in one place with "whitelist" semantics.
|
||||
func (id *SpiffeIDSigning) CanSign(cu CertURI) bool {
|
||||
switch other := cu.(type) {
|
||||
case *SpiffeIDSigning:
|
||||
// We can only sign other CA certificates for the same trust domain. Note
|
||||
// that we could open this up later for example to support external
|
||||
// federation of roots and cross-signing external roots that have different
|
||||
// URI structure but it's simpler to start off restrictive.
|
||||
return id == other
|
||||
case *SpiffeIDService:
|
||||
// The host component of the service must be an exact match for now under
|
||||
// ascii case folding (since hostnames are case-insensitive). Later we might
|
||||
// worry about Unicode domains if we start allowing customisation beyond the
|
||||
// built-in cluster ids.
|
||||
return strings.ToLower(other.Host) == id.Host()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SpiffeIDSigningForCluster returns the SPIFFE signing identifier (trust
|
||||
// domain) representation of the given CA config. If config is nil this function
|
||||
// will panic.
|
||||
//
|
||||
// NOTE(banks): we intentionally fix the tld `.consul` for now rather than tie
|
||||
// this to the `domain` config used for DNS because changing DNS domain can't
|
||||
// break all certificate validation. That does mean that DNS prefix might not
|
||||
// match the identity URIs and so the trust domain might not actually resolve
|
||||
// which we would like but don't actually need.
|
||||
func SpiffeIDSigningForCluster(config *structs.CAConfiguration) *SpiffeIDSigning {
|
||||
return &SpiffeIDSigning{ClusterID: config.ClusterID, Domain: "consul"}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Signing ID should never authorize
|
||||
func TestSpiffeIDSigningAuthorize(t *testing.T) {
|
||||
var id SpiffeIDSigning
|
||||
auth, ok := id.Authorize(nil)
|
||||
assert.False(t, auth)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestSpiffeIDSigningForCluster(t *testing.T) {
|
||||
// For now it should just append .consul to the ID.
|
||||
config := &structs.CAConfiguration{
|
||||
ClusterID: TestClusterID,
|
||||
}
|
||||
id := SpiffeIDSigningForCluster(config)
|
||||
assert.Equal(t, id.URI().String(), "spiffe://"+TestClusterID+".consul")
|
||||
}
|
||||
|
||||
// fakeCertURI is a CertURI implementation that our implementation doesn't know
|
||||
// about
|
||||
type fakeCertURI string
|
||||
|
||||
func (f fakeCertURI) Authorize(*structs.Intention) (auth bool, match bool) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (f fakeCertURI) URI() *url.URL {
|
||||
u, _ := url.Parse(string(f))
|
||||
return u
|
||||
}
|
||||
func TestSpiffeIDSigning_CanSign(t *testing.T) {
|
||||
|
||||
testSigning := &SpiffeIDSigning{
|
||||
ClusterID: TestClusterID,
|
||||
Domain: "consul",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id *SpiffeIDSigning
|
||||
input CertURI
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "same signing ID",
|
||||
id: testSigning,
|
||||
input: testSigning,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other signing ID",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDSigning{
|
||||
ClusterID: "fakedomain",
|
||||
Domain: "consul",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "different TLD signing ID",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDSigning{
|
||||
ClusterID: TestClusterID,
|
||||
Domain: "evil",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
id: testSigning,
|
||||
input: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unrecognised CertURI implementation",
|
||||
id: testSigning,
|
||||
input: fakeCertURI("spiffe://foo.bar/baz"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "service - good",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDService{TestClusterID + ".consul", "default", "dc1", "web"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "service - good midex case",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDService{strings.ToUpper(TestClusterID) + ".CONsuL", "defAUlt", "dc1", "WEB"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "service - different cluster",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDService{"55555555-4444-3333-2222-111111111111.consul", "default", "dc1", "web"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "service - different TLD",
|
||||
id: testSigning,
|
||||
input: &SpiffeIDService{TestClusterID + ".fake", "default", "dc1", "web"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.id.CanSign(tt.input)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package connect
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// testCertURICases contains the test cases for parsing and encoding
|
||||
// the SPIFFE IDs. This is a global since it is used in multiple test functions.
|
||||
var testCertURICases = []struct {
|
||||
Name string
|
||||
URI string
|
||||
Struct interface{}
|
||||
ParseError string
|
||||
}{
|
||||
{
|
||||
"invalid scheme",
|
||||
"http://google.com/",
|
||||
nil,
|
||||
"scheme",
|
||||
},
|
||||
|
||||
{
|
||||
"basic service ID",
|
||||
"spiffe://1234.consul/ns/default/dc/dc01/svc/web",
|
||||
&SpiffeIDService{
|
||||
Host: "1234.consul",
|
||||
Namespace: "default",
|
||||
Datacenter: "dc01",
|
||||
Service: "web",
|
||||
},
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"service with URL-encoded values",
|
||||
"spiffe://1234.consul/ns/foo%2Fbar/dc/bar%2Fbaz/svc/baz%2Fqux",
|
||||
&SpiffeIDService{
|
||||
Host: "1234.consul",
|
||||
Namespace: "foo/bar",
|
||||
Datacenter: "bar/baz",
|
||||
Service: "baz/qux",
|
||||
},
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"signing ID",
|
||||
"spiffe://1234.consul",
|
||||
&SpiffeIDSigning{
|
||||
ClusterID: "1234",
|
||||
Domain: "consul",
|
||||
},
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
func TestParseCertURI(t *testing.T) {
|
||||
for _, tc := range testCertURICases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Parse the URI, should always be valid
|
||||
uri, err := url.Parse(tc.URI)
|
||||
assert.Nil(err)
|
||||
|
||||
// Parse the ID and check the error/return value
|
||||
actual, err := ParseCertURI(uri)
|
||||
if err != nil {
|
||||
t.Logf("parse error: %s", err.Error())
|
||||
}
|
||||
assert.Equal(tc.ParseError != "", err != nil, "error value")
|
||||
if err != nil {
|
||||
assert.Contains(err.Error(), tc.ParseError)
|
||||
return
|
||||
}
|
||||
assert.Equal(tc.Struct, actual)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// GET /v1/connect/ca/roots
|
||||
func (s *HTTPServer) ConnectCARoots(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
var args structs.DCSpecificRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply structs.IndexedCARoots
|
||||
defer setMeta(resp, &reply.QueryMeta)
|
||||
if err := s.agent.RPC("ConnectCA.Roots", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// /v1/connect/ca/configuration
|
||||
func (s *HTTPServer) ConnectCAConfiguration(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
return s.ConnectCAConfigurationGet(resp, req)
|
||||
|
||||
case "PUT":
|
||||
return s.ConnectCAConfigurationSet(resp, req)
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "POST"}}
|
||||
}
|
||||
}
|
||||
|
||||
// GEt /v1/connect/ca/configuration
|
||||
func (s *HTTPServer) ConnectCAConfigurationGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in ConnectCAConfiguration
|
||||
var args structs.DCSpecificRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply structs.CAConfiguration
|
||||
err := s.agent.RPC("ConnectCA.ConfigurationGet", &args, &reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fixupConfig(&reply)
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// PUT /v1/connect/ca/configuration
|
||||
func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in ConnectCAConfiguration
|
||||
|
||||
var args structs.CARequest
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
if err := decodeBody(req, &args.Config, nil); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Request decode failed: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply interface{}
|
||||
err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// A hack to fix up the config types inside of the map[string]interface{}
|
||||
// so that they get formatted correctly during json.Marshal. Without this,
|
||||
// string values that get converted to []uint8 end up getting output back
|
||||
// to the user in base64-encoded form.
|
||||
func fixupConfig(conf *structs.CAConfiguration) {
|
||||
for k, v := range conf.Config {
|
||||
if raw, ok := v.([]uint8); ok {
|
||||
strVal := ca.Uint8ToString(raw)
|
||||
conf.Config[k] = strVal
|
||||
switch conf.Provider {
|
||||
case structs.ConsulCAProvider:
|
||||
if k == "PrivateKey" && strVal != "" {
|
||||
conf.Config["PrivateKey"] = "hidden"
|
||||
}
|
||||
case structs.VaultCAProvider:
|
||||
if k == "Token" && strVal != "" {
|
||||
conf.Config["Token"] = "hidden"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConnectCARoots_empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "connect { enabled = false }")
|
||||
defer a.Shutdown()
|
||||
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ConnectCARoots(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(structs.IndexedCARoots)
|
||||
assert.Equal(value.ActiveRootID, "")
|
||||
assert.Len(value.Roots, 0)
|
||||
}
|
||||
|
||||
func TestConnectCARoots_list(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Set some CAs. Note that NewTestAgent already bootstraps one CA so this just
|
||||
// adds a second and makes it active.
|
||||
ca2 := connect.TestCAConfigSet(t, a, nil)
|
||||
|
||||
// List
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/ca/roots", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ConnectCARoots(resp, req)
|
||||
assert.NoError(err)
|
||||
|
||||
value := obj.(structs.IndexedCARoots)
|
||||
assert.Equal(value.ActiveRootID, ca2.ID)
|
||||
assert.Len(value.Roots, 2)
|
||||
|
||||
// We should never have the secret information
|
||||
for _, r := range value.Roots {
|
||||
assert.Equal("", r.SigningCert)
|
||||
assert.Equal("", r.SigningKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectCAConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
expected := &structs.ConsulCAProviderConfig{
|
||||
RotationPeriod: 90 * 24 * time.Hour,
|
||||
}
|
||||
|
||||
// Get the initial config.
|
||||
{
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/ca/configuration", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ConnectCAConfiguration(resp, req)
|
||||
assert.NoError(err)
|
||||
|
||||
value := obj.(structs.CAConfiguration)
|
||||
parsed, err := ca.ParseConsulCAConfig(value.Config)
|
||||
assert.NoError(err)
|
||||
assert.Equal("consul", value.Provider)
|
||||
assert.Equal(expected, parsed)
|
||||
}
|
||||
|
||||
// Set the config.
|
||||
{
|
||||
body := bytes.NewBuffer([]byte(`
|
||||
{
|
||||
"Provider": "consul",
|
||||
"Config": {
|
||||
"RotationPeriod": 3600000000000
|
||||
}
|
||||
}`))
|
||||
req, _ := http.NewRequest("PUT", "/v1/connect/ca/configuration", body)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ConnectCAConfiguration(resp, req)
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
// The config should be updated now.
|
||||
{
|
||||
expected.RotationPeriod = time.Hour
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/ca/configuration", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ConnectCAConfiguration(resp, req)
|
||||
assert.NoError(err)
|
||||
|
||||
value := obj.(structs.CAConfiguration)
|
||||
parsed, err := ca.ParseConsulCAConfig(value.Config)
|
||||
assert.NoError(err)
|
||||
assert.Equal("consul", value.Provider)
|
||||
assert.Equal(expected, parsed)
|
||||
}
|
||||
}
|
|
@ -454,6 +454,33 @@ func (f *aclFilter) filterCoordinates(coords *structs.Coordinates) {
|
|||
*coords = c
|
||||
}
|
||||
|
||||
// filterIntentions is used to filter intentions based on ACL rules.
|
||||
// We prune entries the user doesn't have access to, and we redact any tokens
|
||||
// if the user doesn't have a management token.
|
||||
func (f *aclFilter) filterIntentions(ixns *structs.Intentions) {
|
||||
// Management tokens can see everything with no filtering.
|
||||
if f.acl.ACLList() {
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, we need to see what the token has access to.
|
||||
ret := make(structs.Intentions, 0, len(*ixns))
|
||||
for _, ixn := range *ixns {
|
||||
// If no prefix ACL applies to this then filter it, since
|
||||
// we know at this point the user doesn't have a management
|
||||
// token, otherwise see what the policy says.
|
||||
prefix, ok := ixn.GetACLPrefix()
|
||||
if !ok || !f.acl.IntentionRead(prefix) {
|
||||
f.logger.Printf("[DEBUG] consul: dropping intention %q from result due to ACLs", ixn.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
ret = append(ret, ixn)
|
||||
}
|
||||
|
||||
*ixns = ret
|
||||
}
|
||||
|
||||
// filterNodeDump is used to filter through all parts of a node dump and
|
||||
// remove elements the provided ACL token cannot access.
|
||||
func (f *aclFilter) filterNodeDump(dump *structs.NodeDump) {
|
||||
|
@ -598,6 +625,9 @@ func (s *Server) filterACL(token string, subj interface{}) error {
|
|||
case *structs.IndexedHealthChecks:
|
||||
filt.filterHealthChecks(&v.HealthChecks)
|
||||
|
||||
case *structs.IndexedIntentions:
|
||||
filt.filterIntentions(&v.Intentions)
|
||||
|
||||
case *structs.IndexedNodeDump:
|
||||
filt.filterNodeDump(&v.Dump)
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testACLPolicy = `
|
||||
|
@ -847,6 +848,58 @@ node "node1" {
|
|||
}
|
||||
}
|
||||
|
||||
func TestACL_filterIntentions(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert := assert.New(t)
|
||||
|
||||
fill := func() structs.Intentions {
|
||||
return structs.Intentions{
|
||||
&structs.Intention{
|
||||
ID: "f004177f-2c28-83b7-4229-eacc25fe55d1",
|
||||
DestinationName: "bar",
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: "f004177f-2c28-83b7-4229-eacc25fe55d2",
|
||||
DestinationName: "foo",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Try permissive filtering.
|
||||
{
|
||||
ixns := fill()
|
||||
filt := newACLFilter(acl.AllowAll(), nil, false)
|
||||
filt.filterIntentions(&ixns)
|
||||
assert.Len(ixns, 2)
|
||||
}
|
||||
|
||||
// Try restrictive filtering.
|
||||
{
|
||||
ixns := fill()
|
||||
filt := newACLFilter(acl.DenyAll(), nil, false)
|
||||
filt.filterIntentions(&ixns)
|
||||
assert.Len(ixns, 0)
|
||||
}
|
||||
|
||||
// Policy to see one
|
||||
policy, err := acl.Parse(`
|
||||
service "foo" {
|
||||
policy = "read"
|
||||
}
|
||||
`, nil)
|
||||
assert.Nil(err)
|
||||
perms, err := acl.New(acl.DenyAll(), policy, nil)
|
||||
assert.Nil(err)
|
||||
|
||||
// Filter
|
||||
{
|
||||
ixns := fill()
|
||||
filt := newACLFilter(perms, nil, false)
|
||||
filt.filterIntentions(&ixns)
|
||||
assert.Len(ixns, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestACL_filterServices(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create some services
|
||||
|
|
|
@ -47,6 +47,13 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
|
|||
|
||||
// Handle a service registration.
|
||||
if args.Service != nil {
|
||||
// Validate the service. This is in addition to the below since
|
||||
// the above just hasn't been moved over yet. We should move it over
|
||||
// in time.
|
||||
if err := args.Service.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If no service id, but service name, use default
|
||||
if args.Service.ID == "" && args.Service.Service != "" {
|
||||
args.Service.ID = args.Service.Service
|
||||
|
@ -73,6 +80,13 @@ func (c *Catalog) Register(args *structs.RegisterRequest, reply *struct{}) error
|
|||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// Proxies must have write permission on their destination
|
||||
if args.Service.Kind == structs.ServiceKindConnectProxy {
|
||||
if rule != nil && !rule.ServiceWrite(args.Service.ProxyDestination, nil) {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Move the old format single check into the slice, and fixup IDs.
|
||||
|
@ -244,24 +258,52 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
|
|||
return fmt.Errorf("Must provide service name")
|
||||
}
|
||||
|
||||
// Determine the function we'll call
|
||||
var f func(memdb.WatchSet, *state.Store) (uint64, structs.ServiceNodes, error)
|
||||
switch {
|
||||
case args.Connect:
|
||||
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
|
||||
return s.ConnectServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
|
||||
default:
|
||||
f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.ServiceNodes, error) {
|
||||
if args.ServiceAddress != "" {
|
||||
return s.ServiceAddressNodes(ws, args.ServiceAddress)
|
||||
}
|
||||
|
||||
if args.TagFilter {
|
||||
return s.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
|
||||
}
|
||||
|
||||
return s.ServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
}
|
||||
|
||||
// If we're doing a connect query, we need read access to the service
|
||||
// we're trying to find proxies for, so check that.
|
||||
if args.Connect {
|
||||
// Fetch the ACL token, if any.
|
||||
rule, err := c.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule != nil && !rule.ServiceRead(args.ServiceName) {
|
||||
// Just return nil, which will return an empty response (tested)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
err := c.srv.blockingQuery(
|
||||
&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
var index uint64
|
||||
var services structs.ServiceNodes
|
||||
var err error
|
||||
if args.TagFilter {
|
||||
index, services, err = state.ServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
|
||||
} else {
|
||||
index, services, err = state.ServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
if args.ServiceAddress != "" {
|
||||
index, services, err = state.ServiceAddressNodes(ws, args.ServiceAddress)
|
||||
}
|
||||
index, services, err := f(ws, state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.ServiceNodes = index, services
|
||||
if len(args.NodeMetaFilters) > 0 {
|
||||
var filtered structs.ServiceNodes
|
||||
|
@ -280,17 +322,24 @@ func (c *Catalog) ServiceNodes(args *structs.ServiceSpecificRequest, reply *stru
|
|||
|
||||
// Provide some metrics
|
||||
if err == nil {
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", "service", "query"}, 1,
|
||||
// For metrics, we separate Connect-based lookups from non-Connect
|
||||
key := "service"
|
||||
if args.Connect {
|
||||
key = "connect"
|
||||
}
|
||||
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", key, "query"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
|
||||
if args.ServiceTag != "" {
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", "service", "query-tag"}, 1,
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", key, "query-tag"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}})
|
||||
}
|
||||
if len(reply.ServiceNodes) == 0 {
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", "service", "not-found"}, 1,
|
||||
metrics.IncrCounterWithLabels([]string{"catalog", key, "not-found"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,8 @@ import (
|
|||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCatalog_Register(t *testing.T) {
|
||||
|
@ -332,6 +334,147 @@ func TestCatalog_Register_ForwardDC(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCatalog_Register_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
|
||||
// Register
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
|
||||
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
|
||||
}
|
||||
|
||||
// Test an invalid ConnectProxy. We don't need to exhaustively test because
|
||||
// this is all tested in structs on the Validate method.
|
||||
func TestCatalog_Register_ConnectProxy_invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Service.ProxyDestination = ""
|
||||
|
||||
// Register
|
||||
var out struct{}
|
||||
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "ProxyDestination")
|
||||
}
|
||||
|
||||
// Test that write is required for the proxy destination to register a proxy.
|
||||
func TestCatalog_Register_ConnectProxy_ACLProxyDestination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLMasterToken = "root"
|
||||
c.ACLDefaultPolicy = "deny"
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Create the ACL.
|
||||
arg := structs.ACLRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.ACLSet,
|
||||
ACL: structs.ACL{
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: `
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
}
|
||||
`,
|
||||
},
|
||||
WriteRequest: structs.WriteRequest{Token: "root"},
|
||||
}
|
||||
var token string
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &token))
|
||||
|
||||
// Register should fail because we don't have permission on the destination
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "foo"
|
||||
args.Service.ProxyDestination = "bar"
|
||||
args.WriteRequest.Token = token
|
||||
var out struct{}
|
||||
err := msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
|
||||
assert.True(acl.IsErrPermissionDenied(err))
|
||||
|
||||
// Register should fail with the right destination but wrong name
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "bar"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.WriteRequest.Token = token
|
||||
err = msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out)
|
||||
assert.True(acl.IsErrPermissionDenied(err))
|
||||
|
||||
// Register should work with the right destination
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "foo"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.WriteRequest.Token = token
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
}
|
||||
|
||||
func TestCatalog_Register_ConnectNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
args := structs.TestRegisterRequest(t)
|
||||
args.Service.Connect.Native = true
|
||||
|
||||
// Register
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(structs.ServiceKindTypical, v.ServiceKind)
|
||||
assert.True(v.ServiceConnect.Native)
|
||||
}
|
||||
|
||||
func TestCatalog_Deregister(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir1, s1 := testServer(t)
|
||||
|
@ -1599,6 +1742,246 @@ func TestCatalog_ListServiceNodes_DistanceSort(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCatalog_ListServiceNodes_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the service
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
TagFilter: false,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
|
||||
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
|
||||
}
|
||||
|
||||
func TestCatalog_ListServiceNodes_ConnectDestination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the proxy service
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// Register the service
|
||||
{
|
||||
dst := args.Service.ProxyDestination
|
||||
args := structs.TestRegisterRequest(t)
|
||||
args.Service.Service = dst
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
}
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.ProxyDestination,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.ServiceKind)
|
||||
assert.Equal(args.Service.ProxyDestination, v.ServiceProxyDestination)
|
||||
|
||||
// List by non-Connect
|
||||
req = structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.ProxyDestination,
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v = resp.ServiceNodes[0]
|
||||
assert.Equal(args.Service.ProxyDestination, v.ServiceName)
|
||||
assert.Equal("", v.ServiceProxyDestination)
|
||||
}
|
||||
|
||||
// Test that calling ServiceNodes with Connect: true will return
|
||||
// Connect native services.
|
||||
func TestCatalog_ListServiceNodes_ConnectDestinationNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the native service
|
||||
args := structs.TestRegisterRequest(t)
|
||||
args.Service.Connect.Native = true
|
||||
var out struct{}
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
require.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
require.Equal(args.Service.Service, v.ServiceName)
|
||||
|
||||
// List by non-Connect
|
||||
req = structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
}
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
require.Len(resp.ServiceNodes, 1)
|
||||
v = resp.ServiceNodes[0]
|
||||
require.Equal(args.Service.Service, v.ServiceName)
|
||||
}
|
||||
|
||||
func TestCatalog_ListServiceNodes_ConnectProxy_ACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLMasterToken = "root"
|
||||
c.ACLDefaultPolicy = "deny"
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Create the ACL.
|
||||
arg := structs.ACLRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.ACLSet,
|
||||
ACL: structs.ACL{
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: `
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
}
|
||||
`,
|
||||
},
|
||||
WriteRequest: structs.WriteRequest{Token: "root"},
|
||||
}
|
||||
var token string
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &token))
|
||||
|
||||
{
|
||||
// Register a proxy
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "foo-proxy"
|
||||
args.Service.ProxyDestination = "bar"
|
||||
args.WriteRequest.Token = "root"
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// Register a proxy
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "foo-proxy"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.WriteRequest.Token = "root"
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// Register a proxy
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.Service.Service = "another-proxy"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.WriteRequest.Token = "root"
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
}
|
||||
|
||||
// List w/ token. This should disallow because we don't have permission
|
||||
// to read "bar"
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: "bar",
|
||||
QueryOptions: structs.QueryOptions{Token: token},
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 0)
|
||||
|
||||
// List w/ token. This should work since we're requesting "foo", but should
|
||||
// also only contain the proxies with names that adhere to our ACL.
|
||||
req = structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: "foo",
|
||||
QueryOptions: structs.QueryOptions{Token: token},
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal("foo-proxy", v.ServiceName)
|
||||
}
|
||||
|
||||
func TestCatalog_ListServiceNodes_ConnectNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the service
|
||||
args := structs.TestRegisterRequest(t)
|
||||
args.Service.Connect.Native = true
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
ServiceName: args.Service.Service,
|
||||
TagFilter: false,
|
||||
}
|
||||
var resp structs.IndexedServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.ServiceNodes, 1)
|
||||
v := resp.ServiceNodes[0]
|
||||
assert.Equal(args.Service.Connect.Native, v.ServiceConnect.Native)
|
||||
}
|
||||
|
||||
func TestCatalog_NodeServices(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir1, s1 := testServer(t)
|
||||
|
@ -1649,6 +2032,67 @@ func TestCatalog_NodeServices(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the service
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.NodeSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: args.Node,
|
||||
}
|
||||
var resp structs.IndexedNodeServices
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
|
||||
|
||||
assert.Len(resp.NodeServices.Services, 1)
|
||||
v := resp.NodeServices.Services[args.Service.Service]
|
||||
assert.Equal(structs.ServiceKindConnectProxy, v.Kind)
|
||||
assert.Equal(args.Service.ProxyDestination, v.ProxyDestination)
|
||||
}
|
||||
|
||||
func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Register the service
|
||||
args := structs.TestRegisterRequest(t)
|
||||
var out struct{}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", args, &out))
|
||||
|
||||
// List
|
||||
req := structs.NodeSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: args.Node,
|
||||
}
|
||||
var resp structs.IndexedNodeServices
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.NodeServices", &req, &resp))
|
||||
|
||||
assert.Len(resp.NodeServices.Services, 1)
|
||||
v := resp.NodeServices.Services[args.Service.Service]
|
||||
assert.Equal(args.Service.Connect.Native, v.Connect.Native)
|
||||
}
|
||||
|
||||
// Used to check for a regression against a known bug
|
||||
func TestCatalog_Register_FailedCase1(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/tlsutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
|
@ -346,6 +347,13 @@ type Config struct {
|
|||
// autopilot tasks, such as promoting eligible non-voters and removing
|
||||
// dead servers.
|
||||
AutopilotInterval time.Duration
|
||||
|
||||
// ConnectEnabled is whether to enable Connect features such as the CA.
|
||||
ConnectEnabled bool
|
||||
|
||||
// CAConfig is used to apply the initial Connect CA configuration when
|
||||
// bootstrapping.
|
||||
CAConfig *structs.CAConfiguration
|
||||
}
|
||||
|
||||
// CheckProtocolVersion validates the protocol version.
|
||||
|
@ -425,6 +433,13 @@ func DefaultConfig() *Config {
|
|||
ServerStabilizationTime: 10 * time.Second,
|
||||
},
|
||||
|
||||
CAConfig: &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"RotationPeriod": "2160h",
|
||||
},
|
||||
},
|
||||
|
||||
ServerHealthInterval: 2 * time.Second,
|
||||
AutopilotInterval: 10 * time.Second,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,393 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
var ErrConnectNotEnabled = errors.New("Connect must be enabled in order to use this endpoint")
|
||||
|
||||
// ConnectCA manages the Connect CA.
|
||||
type ConnectCA struct {
|
||||
// srv is a pointer back to the server.
|
||||
srv *Server
|
||||
}
|
||||
|
||||
// ConfigurationGet returns the configuration for the CA.
|
||||
func (s *ConnectCA) ConfigurationGet(
|
||||
args *structs.DCSpecificRequest,
|
||||
reply *structs.CAConfiguration) error {
|
||||
// Exit early if Connect hasn't been enabled.
|
||||
if !s.srv.config.ConnectEnabled {
|
||||
return ErrConnectNotEnabled
|
||||
}
|
||||
|
||||
if done, err := s.srv.forward("ConnectCA.ConfigurationGet", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// This action requires operator read access.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rule != nil && !rule.OperatorRead() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
state := s.srv.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*reply = *config
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigurationSet updates the configuration for the CA.
|
||||
func (s *ConnectCA) ConfigurationSet(
|
||||
args *structs.CARequest,
|
||||
reply *interface{}) error {
|
||||
// Exit early if Connect hasn't been enabled.
|
||||
if !s.srv.config.ConnectEnabled {
|
||||
return ErrConnectNotEnabled
|
||||
}
|
||||
|
||||
if done, err := s.srv.forward("ConnectCA.ConfigurationSet", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// This action requires operator write access.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rule != nil && !rule.OperatorWrite() {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Exit early if it's a no-op change
|
||||
state := s.srv.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
args.Config.ClusterID = config.ClusterID
|
||||
if args.Config.Provider == config.Provider && reflect.DeepEqual(args.Config.Config, config.Config) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create a new instance of the provider described by the config
|
||||
// and get the current active root CA. This acts as a good validation
|
||||
// of the config and makes sure the provider is functioning correctly
|
||||
// before we commit any changes to Raft.
|
||||
newProvider, err := s.srv.createCAProvider(args.Config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not initialize provider: %v", err)
|
||||
}
|
||||
|
||||
newRootPEM, err := newProvider.ActiveRoot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newActiveRoot, err := parseCARoot(newRootPEM, args.Config.Provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compare the new provider's root CA ID to the current one. If they
|
||||
// match, just update the existing provider with the new config.
|
||||
// If they don't match, begin the root rotation process.
|
||||
_, root, err := state.CARootActive(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if root != nil && root.ID == newActiveRoot.ID {
|
||||
args.Op = structs.CAOpSetConfig
|
||||
resp, err := s.srv.raftApply(structs.ConnectCARequestType, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
// If the config has been committed, update the local provider instance
|
||||
s.srv.setCAProvider(newProvider, newActiveRoot)
|
||||
|
||||
s.srv.logger.Printf("[INFO] connect: CA provider config updated")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// At this point, we know the config change has trigged a root rotation,
|
||||
// either by swapping the provider type or changing the provider's config
|
||||
// to use a different root certificate.
|
||||
|
||||
// If it's a config change that would trigger a rotation (different provider/root):
|
||||
// 1. Get the root from the new provider.
|
||||
// 2. Call CrossSignCA on the old provider to sign the new root with the old one to
|
||||
// get a cross-signed certificate.
|
||||
// 3. Take the active root for the new provider and append the intermediate from step 2
|
||||
// to its list of intermediates.
|
||||
newRoot, err := connect.ParseCert(newRootPEM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Have the old provider cross-sign the new intermediate
|
||||
oldProvider, _ := s.srv.getCAProvider()
|
||||
if oldProvider == nil {
|
||||
return fmt.Errorf("internal error: CA provider is nil")
|
||||
}
|
||||
xcCert, err := oldProvider.CrossSignCA(newRoot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the cross signed cert to the new root's intermediates.
|
||||
newActiveRoot.IntermediateCerts = []string{xcCert}
|
||||
intermediate, err := newProvider.GenerateIntermediate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if intermediate != newRootPEM {
|
||||
newActiveRoot.IntermediateCerts = append(newActiveRoot.IntermediateCerts, intermediate)
|
||||
}
|
||||
|
||||
// Update the roots and CA config in the state store at the same time
|
||||
idx, roots, err := state.CARoots(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var newRoots structs.CARoots
|
||||
for _, r := range roots {
|
||||
newRoot := *r
|
||||
if newRoot.Active {
|
||||
newRoot.Active = false
|
||||
}
|
||||
newRoots = append(newRoots, &newRoot)
|
||||
}
|
||||
newRoots = append(newRoots, newActiveRoot)
|
||||
|
||||
args.Op = structs.CAOpSetRootsAndConfig
|
||||
args.Index = idx
|
||||
args.Roots = newRoots
|
||||
resp, err := s.srv.raftApply(structs.ConnectCARequestType, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
// If the config has been committed, update the local provider instance
|
||||
// and call teardown on the old provider
|
||||
s.srv.setCAProvider(newProvider, newActiveRoot)
|
||||
|
||||
if err := oldProvider.Cleanup(); err != nil {
|
||||
s.srv.logger.Printf("[WARN] connect: failed to clean up old provider %q", config.Provider)
|
||||
}
|
||||
|
||||
s.srv.logger.Printf("[INFO] connect: CA rotated to new root under provider %q", args.Config.Provider)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Roots returns the currently trusted root certificates.
|
||||
func (s *ConnectCA) Roots(
|
||||
args *structs.DCSpecificRequest,
|
||||
reply *structs.IndexedCARoots) error {
|
||||
// Forward if necessary
|
||||
if done, err := s.srv.forward("ConnectCA.Roots", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load the ClusterID to generate TrustDomain. We do this outside the loop
|
||||
// since by definition this value should be immutable once set for lifetime of
|
||||
// the cluster so we don't need to look it up more than once. We also don't
|
||||
// have to worry about non-atomicity between the config fetch transaction and
|
||||
// the CARoots transaction below since this field must remain immutable. Do
|
||||
// not re-use this state/config for other logic that might care about changes
|
||||
// of config during the blocking query below.
|
||||
{
|
||||
state := s.srv.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check CA is actually bootstrapped...
|
||||
if config != nil {
|
||||
// Build TrustDomain based on the ClusterID stored.
|
||||
signingID := connect.SpiffeIDSigningForCluster(config)
|
||||
if signingID == nil {
|
||||
// If CA is bootstrapped at all then this should never happen but be
|
||||
// defensive.
|
||||
return errors.New("no cluster trust domain setup")
|
||||
}
|
||||
reply.TrustDomain = signingID.Host()
|
||||
}
|
||||
}
|
||||
|
||||
return s.srv.blockingQuery(
|
||||
&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, roots, err := state.CARoots(ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.Roots = index, roots
|
||||
if reply.Roots == nil {
|
||||
reply.Roots = make(structs.CARoots, 0)
|
||||
}
|
||||
|
||||
// The API response must NEVER contain the secret information
|
||||
// such as keys and so on. We use a whitelist below to copy the
|
||||
// specific fields we want to expose.
|
||||
for i, r := range reply.Roots {
|
||||
// IMPORTANT: r must NEVER be modified, since it is a pointer
|
||||
// directly to the structure in the memdb store.
|
||||
|
||||
reply.Roots[i] = &structs.CARoot{
|
||||
ID: r.ID,
|
||||
Name: r.Name,
|
||||
SerialNumber: r.SerialNumber,
|
||||
SigningKeyID: r.SigningKeyID,
|
||||
NotBefore: r.NotBefore,
|
||||
NotAfter: r.NotAfter,
|
||||
RootCert: r.RootCert,
|
||||
IntermediateCerts: r.IntermediateCerts,
|
||||
RaftIndex: r.RaftIndex,
|
||||
Active: r.Active,
|
||||
}
|
||||
|
||||
if r.Active {
|
||||
reply.ActiveRootID = r.ID
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Sign signs a certificate for a service.
|
||||
func (s *ConnectCA) Sign(
|
||||
args *structs.CASignRequest,
|
||||
reply *structs.IssuedCert) error {
|
||||
// Exit early if Connect hasn't been enabled.
|
||||
if !s.srv.config.ConnectEnabled {
|
||||
return ErrConnectNotEnabled
|
||||
}
|
||||
|
||||
if done, err := s.srv.forward("ConnectCA.Sign", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the CSR
|
||||
csr, err := connect.ParseCSR(args.CSR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the SPIFFE ID
|
||||
spiffeID, err := connect.ParseCertURI(csr.URIs[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serviceID, ok := spiffeID.(*connect.SpiffeIDService)
|
||||
if !ok {
|
||||
return fmt.Errorf("SPIFFE ID in CSR must be a service ID")
|
||||
}
|
||||
|
||||
provider, caRoot := s.srv.getCAProvider()
|
||||
if provider == nil {
|
||||
return fmt.Errorf("internal error: CA provider is nil")
|
||||
}
|
||||
|
||||
// Verify that the CSR entity is in the cluster's trust domain
|
||||
state := s.srv.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
signingID := connect.SpiffeIDSigningForCluster(config)
|
||||
if !signingID.CanSign(serviceID) {
|
||||
return fmt.Errorf("SPIFFE ID in CSR from a different trust domain: %s, "+
|
||||
"we are %s", serviceID.Host, signingID.Host())
|
||||
}
|
||||
|
||||
// Verify that the ACL token provided has permission to act as this service
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rule != nil && !rule.ServiceWrite(serviceID.Service, nil) {
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Verify that the DC in the service URI matches us. We might relax this
|
||||
// requirement later but being restrictive for now is safer.
|
||||
if serviceID.Datacenter != s.srv.config.Datacenter {
|
||||
return fmt.Errorf("SPIFFE ID in CSR from a different datacenter: %s, "+
|
||||
"we are %s", serviceID.Datacenter, s.srv.config.Datacenter)
|
||||
}
|
||||
|
||||
// All seems to be in order, actually sign it.
|
||||
pem, err := provider.Sign(csr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Append any intermediates needed by this root.
|
||||
for _, p := range caRoot.IntermediateCerts {
|
||||
pem = strings.TrimSpace(pem) + "\n" + p
|
||||
}
|
||||
|
||||
// TODO(banks): when we implement IssuedCerts table we can use the insert to
|
||||
// that as the raft index to return in response. Right now we can rely on only
|
||||
// the built-in provider being supported and the implementation detail that we
|
||||
// have to write a SerialIndex update to the provider config table for every
|
||||
// cert issued so in all cases this index will be higher than any previous
|
||||
// sign response. This has to be reloaded after the provider.Sign call to
|
||||
// observe the index update.
|
||||
state = s.srv.fsm.State()
|
||||
modIdx, _, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := connect.ParseCert(pem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the response
|
||||
*reply = structs.IssuedCert{
|
||||
SerialNumber: connect.HexString(cert.SerialNumber.Bytes()),
|
||||
CertPEM: pem,
|
||||
Service: serviceID.Service,
|
||||
ServiceURI: cert.URIs[0].String(),
|
||||
ValidAfter: cert.NotBefore,
|
||||
ValidBefore: cert.NotAfter,
|
||||
RaftIndex: structs.RaftIndex{
|
||||
ModifyIndex: modIdx,
|
||||
CreateIndex: modIdx,
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,434 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testParseCert(t *testing.T, pemValue string) *x509.Certificate {
|
||||
cert, err := connect.ParseCert(pemValue)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
// Test listing root CAs.
|
||||
func TestConnectCARoots(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Insert some CAs
|
||||
state := s1.fsm.State()
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
ca2.Active = false
|
||||
idx, _, err := state.CARoots(nil)
|
||||
require.NoError(err)
|
||||
ok, err := state.CARootSetCAS(idx, idx, []*structs.CARoot{ca1, ca2})
|
||||
assert.True(ok)
|
||||
require.NoError(err)
|
||||
_, caCfg, err := state.CAConfig()
|
||||
require.NoError(err)
|
||||
|
||||
// Request
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.IndexedCARoots
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
|
||||
|
||||
// Verify
|
||||
assert.Equal(ca1.ID, reply.ActiveRootID)
|
||||
assert.Len(reply.Roots, 2)
|
||||
for _, r := range reply.Roots {
|
||||
// These must never be set, for security
|
||||
assert.Equal("", r.SigningCert)
|
||||
assert.Equal("", r.SigningKey)
|
||||
}
|
||||
assert.Equal(fmt.Sprintf("%s.consul", caCfg.ClusterID), reply.TrustDomain)
|
||||
}
|
||||
|
||||
func TestConnectCAConfig_GetSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Get the starting config
|
||||
{
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.CAConfiguration
|
||||
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
|
||||
|
||||
actual, err := ca.ParseConsulCAConfig(reply.Config)
|
||||
assert.NoError(err)
|
||||
expected, err := ca.ParseConsulCAConfig(s1.config.CAConfig.Config)
|
||||
assert.NoError(err)
|
||||
assert.Equal(reply.Provider, s1.config.CAConfig.Provider)
|
||||
assert.Equal(actual, expected)
|
||||
}
|
||||
|
||||
// Update a config value
|
||||
newConfig := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "",
|
||||
"RootCert": "",
|
||||
"RotationPeriod": 180 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
{
|
||||
args := &structs.CARequest{
|
||||
Datacenter: "dc1",
|
||||
Config: newConfig,
|
||||
}
|
||||
var reply interface{}
|
||||
|
||||
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
|
||||
}
|
||||
|
||||
// Verify the new config was set
|
||||
{
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.CAConfiguration
|
||||
assert.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
|
||||
|
||||
actual, err := ca.ParseConsulCAConfig(reply.Config)
|
||||
assert.NoError(err)
|
||||
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
|
||||
assert.NoError(err)
|
||||
assert.Equal(reply.Provider, newConfig.Provider)
|
||||
assert.Equal(actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectCAConfig_TriggerRotation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Store the current root
|
||||
rootReq := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var rootList structs.IndexedCARoots
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", rootReq, &rootList))
|
||||
assert.Len(rootList.Roots, 1)
|
||||
oldRoot := rootList.Roots[0]
|
||||
|
||||
// Update the provider config to use a new private key, which should
|
||||
// cause a rotation.
|
||||
_, newKey, err := connect.GeneratePrivateKey()
|
||||
assert.NoError(err)
|
||||
newConfig := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": newKey,
|
||||
"RootCert": "",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
{
|
||||
args := &structs.CARequest{
|
||||
Datacenter: "dc1",
|
||||
Config: newConfig,
|
||||
}
|
||||
var reply interface{}
|
||||
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationSet", args, &reply))
|
||||
}
|
||||
|
||||
// Make sure the new root has been added along with an intermediate
|
||||
// cross-signed by the old root.
|
||||
var newRootPEM string
|
||||
{
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.IndexedCARoots
|
||||
require.Nil(msgpackrpc.CallWithCodec(codec, "ConnectCA.Roots", args, &reply))
|
||||
assert.Len(reply.Roots, 2)
|
||||
|
||||
for _, r := range reply.Roots {
|
||||
if r.ID == oldRoot.ID {
|
||||
// The old root should no longer be marked as the active root,
|
||||
// and none of its other fields should have changed.
|
||||
assert.False(r.Active)
|
||||
assert.Equal(r.Name, oldRoot.Name)
|
||||
assert.Equal(r.RootCert, oldRoot.RootCert)
|
||||
assert.Equal(r.SigningCert, oldRoot.SigningCert)
|
||||
assert.Equal(r.IntermediateCerts, oldRoot.IntermediateCerts)
|
||||
} else {
|
||||
newRootPEM = r.RootCert
|
||||
// The new root should have a valid cross-signed cert from the old
|
||||
// root as an intermediate.
|
||||
assert.True(r.Active)
|
||||
assert.Len(r.IntermediateCerts, 1)
|
||||
|
||||
xc := testParseCert(t, r.IntermediateCerts[0])
|
||||
oldRootCert := testParseCert(t, oldRoot.RootCert)
|
||||
newRootCert := testParseCert(t, r.RootCert)
|
||||
|
||||
// Should have the authority key ID and signature algo of the
|
||||
// (old) signing CA.
|
||||
assert.Equal(xc.AuthorityKeyId, oldRootCert.AuthorityKeyId)
|
||||
assert.NotEqual(xc.SubjectKeyId, oldRootCert.SubjectKeyId)
|
||||
assert.Equal(xc.SignatureAlgorithm, oldRootCert.SignatureAlgorithm)
|
||||
|
||||
// The common name and SAN should not have changed.
|
||||
assert.Equal(xc.Subject.CommonName, newRootCert.Subject.CommonName)
|
||||
assert.Equal(xc.URIs, newRootCert.URIs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the new config was set.
|
||||
{
|
||||
args := &structs.DCSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
var reply structs.CAConfiguration
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.ConfigurationGet", args, &reply))
|
||||
|
||||
actual, err := ca.ParseConsulCAConfig(reply.Config)
|
||||
require.NoError(err)
|
||||
expected, err := ca.ParseConsulCAConfig(newConfig.Config)
|
||||
require.NoError(err)
|
||||
assert.Equal(reply.Provider, newConfig.Provider)
|
||||
assert.Equal(actual, expected)
|
||||
}
|
||||
|
||||
// Verify that new leaf certs get the cross-signed intermediate bundled
|
||||
{
|
||||
// Generate a CSR and request signing
|
||||
spiffeId := connect.TestSpiffeIDService(t, "web")
|
||||
csr, _ := connect.TestCSR(t, spiffeId)
|
||||
args := &structs.CASignRequest{
|
||||
Datacenter: "dc1",
|
||||
CSR: csr,
|
||||
}
|
||||
var reply structs.IssuedCert
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
|
||||
|
||||
// Verify that the cert is signed by the new CA
|
||||
{
|
||||
roots := x509.NewCertPool()
|
||||
require.True(roots.AppendCertsFromPEM([]byte(newRootPEM)))
|
||||
leaf, err := connect.ParseCert(reply.CertPEM)
|
||||
require.NoError(err)
|
||||
_, err = leaf.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
})
|
||||
require.NoError(err)
|
||||
}
|
||||
|
||||
// And that it validates via the intermediate
|
||||
{
|
||||
roots := x509.NewCertPool()
|
||||
assert.True(roots.AppendCertsFromPEM([]byte(oldRoot.RootCert)))
|
||||
leaf, err := connect.ParseCert(reply.CertPEM)
|
||||
require.NoError(err)
|
||||
|
||||
// Make sure the intermediate was returned as well as leaf
|
||||
_, rest := pem.Decode([]byte(reply.CertPEM))
|
||||
require.NotEmpty(rest)
|
||||
|
||||
intermediates := x509.NewCertPool()
|
||||
require.True(intermediates.AppendCertsFromPEM(rest))
|
||||
|
||||
_, err = leaf.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
Intermediates: intermediates,
|
||||
})
|
||||
require.NoError(err)
|
||||
}
|
||||
|
||||
// Verify other fields
|
||||
assert.Equal("web", reply.Service)
|
||||
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CA signing
|
||||
func TestConnectCASign(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Generate a CSR and request signing
|
||||
spiffeId := connect.TestSpiffeIDService(t, "web")
|
||||
csr, _ := connect.TestCSR(t, spiffeId)
|
||||
args := &structs.CASignRequest{
|
||||
Datacenter: "dc1",
|
||||
CSR: csr,
|
||||
}
|
||||
var reply structs.IssuedCert
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply))
|
||||
|
||||
// Get the current CA
|
||||
state := s1.fsm.State()
|
||||
_, ca, err := state.CARootActive(nil)
|
||||
require.NoError(err)
|
||||
|
||||
// Verify that the cert is signed by the CA
|
||||
roots := x509.NewCertPool()
|
||||
assert.True(roots.AppendCertsFromPEM([]byte(ca.RootCert)))
|
||||
leaf, err := connect.ParseCert(reply.CertPEM)
|
||||
require.NoError(err)
|
||||
_, err = leaf.Verify(x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
})
|
||||
require.NoError(err)
|
||||
|
||||
// Verify other fields
|
||||
assert.Equal("web", reply.Service)
|
||||
assert.Equal(spiffeId.URI().String(), reply.ServiceURI)
|
||||
}
|
||||
|
||||
func TestConnectCASignValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLMasterToken = "root"
|
||||
c.ACLDefaultPolicy = "deny"
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Create an ACL token with service:write for web*
|
||||
var webToken string
|
||||
{
|
||||
arg := structs.ACLRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.ACLSet,
|
||||
ACL: structs.ACL{
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: `
|
||||
service "web" {
|
||||
policy = "write"
|
||||
}`,
|
||||
},
|
||||
WriteRequest: structs.WriteRequest{Token: "root"},
|
||||
}
|
||||
require.NoError(t, msgpackrpc.CallWithCodec(codec, "ACL.Apply", &arg, &webToken))
|
||||
}
|
||||
|
||||
testWebID := connect.TestSpiffeIDService(t, "web")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id connect.CertURI
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "different cluster",
|
||||
id: &connect.SpiffeIDService{
|
||||
Host: "55555555-4444-3333-2222-111111111111.consul",
|
||||
Namespace: testWebID.Namespace,
|
||||
Datacenter: testWebID.Datacenter,
|
||||
Service: testWebID.Service,
|
||||
},
|
||||
wantErr: "different trust domain",
|
||||
},
|
||||
{
|
||||
name: "same cluster should validate",
|
||||
id: testWebID,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "same cluster, CSR for a different DC should NOT validate",
|
||||
id: &connect.SpiffeIDService{
|
||||
Host: testWebID.Host,
|
||||
Namespace: testWebID.Namespace,
|
||||
Datacenter: "dc2",
|
||||
Service: testWebID.Service,
|
||||
},
|
||||
wantErr: "different datacenter",
|
||||
},
|
||||
{
|
||||
name: "same cluster and DC, different service should not have perms",
|
||||
id: &connect.SpiffeIDService{
|
||||
Host: testWebID.Host,
|
||||
Namespace: testWebID.Namespace,
|
||||
Datacenter: testWebID.Datacenter,
|
||||
Service: "db",
|
||||
},
|
||||
wantErr: "Permission denied",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
csr, _ := connect.TestCSR(t, tt.id)
|
||||
args := &structs.CASignRequest{
|
||||
Datacenter: "dc1",
|
||||
CSR: csr,
|
||||
WriteRequest: structs.WriteRequest{Token: webToken},
|
||||
}
|
||||
var reply structs.IssuedCert
|
||||
err := msgpackrpc.CallWithCodec(codec, "ConnectCA.Sign", args, &reply)
|
||||
if tt.wantErr == "" {
|
||||
require.NoError(t, err)
|
||||
// No other validation that is handled in different tests
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// consulCADelegate providers callbacks for the Consul CA provider
|
||||
// to use the state store for its operations.
|
||||
type consulCADelegate struct {
|
||||
srv *Server
|
||||
}
|
||||
|
||||
func (c *consulCADelegate) State() *state.Store {
|
||||
return c.srv.fsm.State()
|
||||
}
|
||||
|
||||
func (c *consulCADelegate) ApplyCARequest(req *structs.CARequest) error {
|
||||
resp, err := c.srv.raftApply(structs.ConnectCARequestType, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -20,6 +20,8 @@ func init() {
|
|||
registerCommand(structs.PreparedQueryRequestType, (*FSM).applyPreparedQueryOperation)
|
||||
registerCommand(structs.TxnRequestType, (*FSM).applyTxn)
|
||||
registerCommand(structs.AutopilotRequestType, (*FSM).applyAutopilotUpdate)
|
||||
registerCommand(structs.IntentionRequestType, (*FSM).applyIntentionOperation)
|
||||
registerCommand(structs.ConnectCARequestType, (*FSM).applyConnectCAOperation)
|
||||
}
|
||||
|
||||
func (c *FSM) applyRegister(buf []byte, index uint64) interface{} {
|
||||
|
@ -246,3 +248,85 @@ func (c *FSM) applyAutopilotUpdate(buf []byte, index uint64) interface{} {
|
|||
}
|
||||
return c.state.AutopilotSetConfig(index, &req.Config)
|
||||
}
|
||||
|
||||
// applyIntentionOperation applies the given intention operation to the state store.
|
||||
func (c *FSM) applyIntentionOperation(buf []byte, index uint64) interface{} {
|
||||
var req structs.IntentionRequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
|
||||
defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "intention"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "intention"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
|
||||
switch req.Op {
|
||||
case structs.IntentionOpCreate, structs.IntentionOpUpdate:
|
||||
return c.state.IntentionSet(index, req.Intention)
|
||||
case structs.IntentionOpDelete:
|
||||
return c.state.IntentionDelete(index, req.Intention.ID)
|
||||
default:
|
||||
c.logger.Printf("[WARN] consul.fsm: Invalid Intention operation '%s'", req.Op)
|
||||
return fmt.Errorf("Invalid Intention operation '%s'", req.Op)
|
||||
}
|
||||
}
|
||||
|
||||
// applyConnectCAOperation applies the given CA operation to the state store.
|
||||
func (c *FSM) applyConnectCAOperation(buf []byte, index uint64) interface{} {
|
||||
var req structs.CARequest
|
||||
if err := structs.Decode(buf, &req); err != nil {
|
||||
panic(fmt.Errorf("failed to decode request: %v", err))
|
||||
}
|
||||
|
||||
defer metrics.MeasureSinceWithLabels([]string{"consul", "fsm", "ca"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
|
||||
defer metrics.MeasureSinceWithLabels([]string{"fsm", "ca"}, time.Now(),
|
||||
[]metrics.Label{{Name: "op", Value: string(req.Op)}})
|
||||
switch req.Op {
|
||||
case structs.CAOpSetConfig:
|
||||
if req.Config.ModifyIndex != 0 {
|
||||
act, err := c.state.CACheckAndSetConfig(index, req.Config.ModifyIndex, req.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return act
|
||||
}
|
||||
|
||||
return c.state.CASetConfig(index, req.Config)
|
||||
case structs.CAOpSetRoots:
|
||||
act, err := c.state.CARootSetCAS(index, req.Index, req.Roots)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return act
|
||||
case structs.CAOpSetProviderState:
|
||||
act, err := c.state.CASetProviderState(index, req.ProviderState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return act
|
||||
case structs.CAOpDeleteProviderState:
|
||||
if err := c.state.CADeleteProviderState(req.ProviderState.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return true
|
||||
case structs.CAOpSetRootsAndConfig:
|
||||
act, err := c.state.CARootSetCAS(index, req.Index, req.Roots)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.state.CASetConfig(index+1, req.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return act
|
||||
default:
|
||||
c.logger.Printf("[WARN] consul.fsm: Invalid CA operation '%s'", req.Op)
|
||||
return fmt.Errorf("Invalid CA operation '%s'", req.Op)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,13 +8,16 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func generateUUID() (ret string) {
|
||||
|
@ -1148,3 +1151,209 @@ func TestFSM_Autopilot(t *testing.T) {
|
|||
t.Fatalf("bad: %v", config.CleanupDeadServers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_Intention_CRUD(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
assert.Nil(err)
|
||||
|
||||
// Create a new intention.
|
||||
ixn := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: structs.TestIntention(t),
|
||||
}
|
||||
ixn.Intention.ID = generateUUID()
|
||||
ixn.Intention.UpdatePrecedence()
|
||||
|
||||
{
|
||||
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
|
||||
assert.Nil(err)
|
||||
assert.Nil(fsm.Apply(makeLog(buf)))
|
||||
}
|
||||
|
||||
// Verify it's in the state store.
|
||||
{
|
||||
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
|
||||
assert.Nil(err)
|
||||
|
||||
actual.CreateIndex, actual.ModifyIndex = 0, 0
|
||||
actual.CreatedAt = ixn.Intention.CreatedAt
|
||||
actual.UpdatedAt = ixn.Intention.UpdatedAt
|
||||
assert.Equal(ixn.Intention, actual)
|
||||
}
|
||||
|
||||
// Make an update
|
||||
ixn.Op = structs.IntentionOpUpdate
|
||||
ixn.Intention.SourceName = "api"
|
||||
{
|
||||
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
|
||||
assert.Nil(err)
|
||||
assert.Nil(fsm.Apply(makeLog(buf)))
|
||||
}
|
||||
|
||||
// Verify the update.
|
||||
{
|
||||
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
|
||||
assert.Nil(err)
|
||||
|
||||
actual.CreateIndex, actual.ModifyIndex = 0, 0
|
||||
actual.CreatedAt = ixn.Intention.CreatedAt
|
||||
actual.UpdatedAt = ixn.Intention.UpdatedAt
|
||||
assert.Equal(ixn.Intention, actual)
|
||||
}
|
||||
|
||||
// Delete
|
||||
ixn.Op = structs.IntentionOpDelete
|
||||
{
|
||||
buf, err := structs.Encode(structs.IntentionRequestType, ixn)
|
||||
assert.Nil(err)
|
||||
assert.Nil(fsm.Apply(makeLog(buf)))
|
||||
}
|
||||
|
||||
// Make sure it's gone.
|
||||
{
|
||||
_, actual, err := fsm.state.IntentionGet(nil, ixn.Intention.ID)
|
||||
assert.Nil(err)
|
||||
assert.Nil(actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_CAConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
assert.Nil(err)
|
||||
|
||||
// Set the autopilot config using a request.
|
||||
req := structs.CARequest{
|
||||
Op: structs.CAOpSetConfig,
|
||||
Config: &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "asdf",
|
||||
"RootCert": "qwer",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
},
|
||||
}
|
||||
buf, err := structs.Encode(structs.ConnectCARequestType, req)
|
||||
assert.Nil(err)
|
||||
resp := fsm.Apply(makeLog(buf))
|
||||
if _, ok := resp.(error); ok {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
|
||||
// Verify key is set directly in the state store.
|
||||
_, config, err := fsm.state.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
var conf *structs.ConsulCAProviderConfig
|
||||
if err := mapstructure.WeakDecode(config.Config, &conf); err != nil {
|
||||
t.Fatalf("error decoding config: %s, %v", err, config.Config)
|
||||
}
|
||||
if got, want := config.Provider, req.Config.Provider; got != want {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := conf.PrivateKey, "asdf"; got != want {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := conf.RootCert, "qwer"; got != want {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := conf.RotationPeriod, 90*24*time.Hour; got != want {
|
||||
t.Fatalf("got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
// Now use CAS and provide an old index
|
||||
req.Config.Provider = "static"
|
||||
req.Config.ModifyIndex = config.ModifyIndex - 1
|
||||
buf, err = structs.Encode(structs.ConnectCARequestType, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
resp = fsm.Apply(makeLog(buf))
|
||||
if _, ok := resp.(error); ok {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
|
||||
_, config, err = fsm.state.CAConfig()
|
||||
assert.Nil(err)
|
||||
if config.Provider != "static" {
|
||||
t.Fatalf("bad: %v", config.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_CARoots(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
assert.Nil(err)
|
||||
|
||||
// Roots
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
ca2.Active = false
|
||||
|
||||
// Create a new request.
|
||||
req := structs.CARequest{
|
||||
Op: structs.CAOpSetRoots,
|
||||
Roots: []*structs.CARoot{ca1, ca2},
|
||||
}
|
||||
|
||||
{
|
||||
buf, err := structs.Encode(structs.ConnectCARequestType, req)
|
||||
assert.Nil(err)
|
||||
assert.True(fsm.Apply(makeLog(buf)).(bool))
|
||||
}
|
||||
|
||||
// Verify it's in the state store.
|
||||
{
|
||||
_, roots, err := fsm.state.CARoots(nil)
|
||||
assert.Nil(err)
|
||||
assert.Len(roots, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFSM_CABuiltinProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
assert.Nil(err)
|
||||
|
||||
// Provider state.
|
||||
expected := &structs.CAConsulProviderState{
|
||||
ID: "foo",
|
||||
PrivateKey: "a",
|
||||
RootCert: "b",
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 1,
|
||||
ModifyIndex: 1,
|
||||
},
|
||||
}
|
||||
|
||||
// Create a new request.
|
||||
req := structs.CARequest{
|
||||
Op: structs.CAOpSetProviderState,
|
||||
ProviderState: expected,
|
||||
}
|
||||
|
||||
{
|
||||
buf, err := structs.Encode(structs.ConnectCARequestType, req)
|
||||
assert.Nil(err)
|
||||
assert.True(fsm.Apply(makeLog(buf)).(bool))
|
||||
}
|
||||
|
||||
// Verify it's in the state store.
|
||||
{
|
||||
_, state, err := fsm.state.CAProviderState("foo")
|
||||
assert.Nil(err)
|
||||
assert.Equal(expected, state)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@ func init() {
|
|||
registerRestorer(structs.CoordinateBatchUpdateType, restoreCoordinates)
|
||||
registerRestorer(structs.PreparedQueryRequestType, restorePreparedQuery)
|
||||
registerRestorer(structs.AutopilotRequestType, restoreAutopilot)
|
||||
registerRestorer(structs.IntentionRequestType, restoreIntention)
|
||||
registerRestorer(structs.ConnectCARequestType, restoreConnectCA)
|
||||
}
|
||||
|
||||
func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error {
|
||||
|
@ -44,6 +46,12 @@ func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) err
|
|||
if err := s.persistAutopilot(sink, encoder); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.persistIntentions(sink, encoder); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.persistConnectCA(sink, encoder); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -258,6 +266,42 @@ func (s *snapshot) persistAutopilot(sink raft.SnapshotSink,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *snapshot) persistConnectCA(sink raft.SnapshotSink,
|
||||
encoder *codec.Encoder) error {
|
||||
roots, err := s.state.CARoots()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range roots {
|
||||
if _, err := sink.Write([]byte{byte(structs.ConnectCARequestType)}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := encoder.Encode(r); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *snapshot) persistIntentions(sink raft.SnapshotSink,
|
||||
encoder *codec.Encoder) error {
|
||||
ixns, err := s.state.Intentions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, ixn := range ixns {
|
||||
if _, err := sink.Write([]byte{byte(structs.IntentionRequestType)}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := encoder.Encode(ixn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreRegistration(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
|
||||
var req structs.RegisterRequest
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
|
@ -364,3 +408,25 @@ func restoreAutopilot(header *snapshotHeader, restore *state.Restore, decoder *c
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreIntention(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
|
||||
var req structs.Intention
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := restore.Intention(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreConnectCA(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error {
|
||||
var req structs.CARoot
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := restore.CARoot(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,16 +7,20 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
fsm, err := New(nil, os.Stderr)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
|
@ -98,6 +102,27 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Intentions
|
||||
ixn := structs.TestIntention(t)
|
||||
ixn.ID = generateUUID()
|
||||
ixn.RaftIndex = structs.RaftIndex{
|
||||
CreateIndex: 14,
|
||||
ModifyIndex: 14,
|
||||
}
|
||||
assert.Nil(fsm.state.IntentionSet(14, ixn))
|
||||
|
||||
// CA Roots
|
||||
roots := []*structs.CARoot{
|
||||
connect.TestCA(t, nil),
|
||||
connect.TestCA(t, nil),
|
||||
}
|
||||
for _, r := range roots[1:] {
|
||||
r.Active = false
|
||||
}
|
||||
ok, err := fsm.state.CARootSetCAS(15, 0, roots)
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Snapshot
|
||||
snap, err := fsm.Snapshot()
|
||||
if err != nil {
|
||||
|
@ -260,6 +285,17 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) {
|
|||
t.Fatalf("bad: %#v, %#v", restoredConf, autopilotConf)
|
||||
}
|
||||
|
||||
// Verify intentions are restored.
|
||||
_, ixns, err := fsm2.state.Intentions(nil)
|
||||
assert.Nil(err)
|
||||
assert.Len(ixns, 1)
|
||||
assert.Equal(ixn, ixns[0])
|
||||
|
||||
// Verify CA roots are restored.
|
||||
_, roots, err = fsm2.state.CARoots(nil)
|
||||
assert.Nil(err)
|
||||
assert.Len(roots, 2)
|
||||
|
||||
// Snapshot
|
||||
snap, err = fsm2.Snapshot()
|
||||
if err != nil {
|
||||
|
|
|
@ -111,18 +111,37 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
|
|||
return fmt.Errorf("Must provide service name")
|
||||
}
|
||||
|
||||
// Determine the function we'll call
|
||||
var f func(memdb.WatchSet, *state.Store, *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error)
|
||||
switch {
|
||||
case args.Connect:
|
||||
f = h.serviceNodesConnect
|
||||
case args.TagFilter:
|
||||
f = h.serviceNodesTagFilter
|
||||
default:
|
||||
f = h.serviceNodesDefault
|
||||
}
|
||||
|
||||
// If we're doing a connect query, we need read access to the service
|
||||
// we're trying to find proxies for, so check that.
|
||||
if args.Connect {
|
||||
// Fetch the ACL token, if any.
|
||||
rule, err := h.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule != nil && !rule.ServiceRead(args.ServiceName) {
|
||||
// Just return nil, which will return an empty response (tested)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
err := h.srv.blockingQuery(
|
||||
&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
var index uint64
|
||||
var nodes structs.CheckServiceNodes
|
||||
var err error
|
||||
if args.TagFilter {
|
||||
index, nodes, err = state.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
|
||||
} else {
|
||||
index, nodes, err = state.CheckServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
index, nodes, err := f(ws, state, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -139,16 +158,37 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc
|
|||
|
||||
// Provide some metrics
|
||||
if err == nil {
|
||||
metrics.IncrCounterWithLabels([]string{"health", "service", "query"}, 1,
|
||||
// For metrics, we separate Connect-based lookups from non-Connect
|
||||
key := "service"
|
||||
if args.Connect {
|
||||
key = "connect"
|
||||
}
|
||||
|
||||
metrics.IncrCounterWithLabels([]string{"health", key, "query"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
|
||||
if args.ServiceTag != "" {
|
||||
metrics.IncrCounterWithLabels([]string{"health", "service", "query-tag"}, 1,
|
||||
metrics.IncrCounterWithLabels([]string{"health", key, "query-tag"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}})
|
||||
}
|
||||
if len(reply.Nodes) == 0 {
|
||||
metrics.IncrCounterWithLabels([]string{"health", "service", "not-found"}, 1,
|
||||
metrics.IncrCounterWithLabels([]string{"health", key, "not-found"}, 1,
|
||||
[]metrics.Label{{Name: "service", Value: args.ServiceName}})
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// The serviceNodes* functions below are the various lookup methods that
|
||||
// can be used by the ServiceNodes endpoint.
|
||||
|
||||
func (h *Health) serviceNodesConnect(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.CheckConnectServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
|
||||
func (h *Health) serviceNodesTagFilter(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag)
|
||||
}
|
||||
|
||||
func (h *Health) serviceNodesDefault(ws memdb.WatchSet, s *state.Store, args *structs.ServiceSpecificRequest) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.CheckServiceNodes(ws, args.ServiceName)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHealth_ChecksInState(t *testing.T) {
|
||||
|
@ -821,6 +822,106 @@ func TestHealth_ServiceNodes_DistanceSort(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHealth_ServiceNodes_ConnectProxy_ACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLMasterToken = "root"
|
||||
c.ACLDefaultPolicy = "deny"
|
||||
c.ACLEnforceVersion8 = false
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
// Create the ACL.
|
||||
arg := structs.ACLRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.ACLSet,
|
||||
ACL: structs.ACL{
|
||||
Name: "User token",
|
||||
Type: structs.ACLTypeClient,
|
||||
Rules: `
|
||||
service "foo" {
|
||||
policy = "write"
|
||||
}
|
||||
`,
|
||||
},
|
||||
WriteRequest: structs.WriteRequest{Token: "root"},
|
||||
}
|
||||
var token string
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "ACL.Apply", arg, &token))
|
||||
|
||||
{
|
||||
var out struct{}
|
||||
|
||||
// Register a service
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.WriteRequest.Token = "root"
|
||||
args.Service.ID = "foo-proxy-0"
|
||||
args.Service.Service = "foo-proxy"
|
||||
args.Service.ProxyDestination = "bar"
|
||||
args.Check = &structs.HealthCheck{
|
||||
Name: "proxy",
|
||||
Status: api.HealthPassing,
|
||||
ServiceID: args.Service.ID,
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// Register a service
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.WriteRequest.Token = "root"
|
||||
args.Service.Service = "foo-proxy"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.Check = &structs.HealthCheck{
|
||||
Name: "proxy",
|
||||
Status: api.HealthPassing,
|
||||
ServiceID: args.Service.Service,
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
|
||||
// Register a service
|
||||
args = structs.TestRegisterRequestProxy(t)
|
||||
args.WriteRequest.Token = "root"
|
||||
args.Service.Service = "another-proxy"
|
||||
args.Service.ProxyDestination = "foo"
|
||||
args.Check = &structs.HealthCheck{
|
||||
Name: "proxy",
|
||||
Status: api.HealthPassing,
|
||||
ServiceID: args.Service.Service,
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &args, &out))
|
||||
}
|
||||
|
||||
// List w/ token. This should disallow because we don't have permission
|
||||
// to read "bar"
|
||||
req := structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: "bar",
|
||||
QueryOptions: structs.QueryOptions{Token: token},
|
||||
}
|
||||
var resp structs.IndexedCheckServiceNodes
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.Nodes, 0)
|
||||
|
||||
// List w/ token. This should work since we're requesting "foo", but should
|
||||
// also only contain the proxies with names that adhere to our ACL.
|
||||
req = structs.ServiceSpecificRequest{
|
||||
Connect: true,
|
||||
Datacenter: "dc1",
|
||||
ServiceName: "foo",
|
||||
QueryOptions: structs.QueryOptions{Token: token},
|
||||
}
|
||||
assert.Nil(msgpackrpc.CallWithCodec(codec, "Health.ServiceNodes", &req, &resp))
|
||||
assert.Len(resp.Nodes, 1)
|
||||
}
|
||||
|
||||
func TestHealth_NodeChecks_FilterACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir, token, srv, codec := testACLFilterServer(t)
|
||||
|
|
|
@ -0,0 +1,358 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrIntentionNotFound is returned if the intention lookup failed.
|
||||
ErrIntentionNotFound = errors.New("Intention not found")
|
||||
)
|
||||
|
||||
// Intention manages the Connect intentions.
|
||||
type Intention struct {
|
||||
// srv is a pointer back to the server.
|
||||
srv *Server
|
||||
}
|
||||
|
||||
// Apply creates or updates an intention in the data store.
|
||||
func (s *Intention) Apply(
|
||||
args *structs.IntentionRequest,
|
||||
reply *string) error {
|
||||
if done, err := s.srv.forward("Intention.Apply", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
defer metrics.MeasureSince([]string{"consul", "intention", "apply"}, time.Now())
|
||||
defer metrics.MeasureSince([]string{"intention", "apply"}, time.Now())
|
||||
|
||||
// Always set a non-nil intention to avoid nil-access below
|
||||
if args.Intention == nil {
|
||||
args.Intention = &structs.Intention{}
|
||||
}
|
||||
|
||||
// If no ID is provided, generate a new ID. This must be done prior to
|
||||
// appending to the Raft log, because the ID is not deterministic. Once
|
||||
// the entry is in the log, the state update MUST be deterministic or
|
||||
// the followers will not converge.
|
||||
if args.Op == structs.IntentionOpCreate {
|
||||
if args.Intention.ID != "" {
|
||||
return fmt.Errorf("ID must be empty when creating a new intention")
|
||||
}
|
||||
|
||||
state := s.srv.fsm.State()
|
||||
for {
|
||||
var err error
|
||||
args.Intention.ID, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
s.srv.logger.Printf("[ERR] consul.intention: UUID generation failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, ixn, err := state.IntentionGet(nil, args.Intention.ID)
|
||||
if err != nil {
|
||||
s.srv.logger.Printf("[ERR] consul.intention: intention lookup failed: %v", err)
|
||||
return err
|
||||
}
|
||||
if ixn == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Set the created at
|
||||
args.Intention.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
*reply = args.Intention.ID
|
||||
|
||||
// Get the ACL token for the request for the checks below.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform the ACL check
|
||||
if prefix, ok := args.Intention.GetACLPrefix(); ok {
|
||||
if rule != nil && !rule.IntentionWrite(prefix) {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention '%s' denied due to ACLs", args.Intention.ID)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// If this is not a create, then we have to verify the ID.
|
||||
if args.Op != structs.IntentionOpCreate {
|
||||
state := s.srv.fsm.State()
|
||||
_, ixn, err := state.IntentionGet(nil, args.Intention.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Intention lookup failed: %v", err)
|
||||
}
|
||||
if ixn == nil {
|
||||
return fmt.Errorf("Cannot modify non-existent intention: '%s'", args.Intention.ID)
|
||||
}
|
||||
|
||||
// Perform the ACL check that we have write to the old prefix too,
|
||||
// which must be true to perform any rename.
|
||||
if prefix, ok := ixn.GetACLPrefix(); ok {
|
||||
if rule != nil && !rule.IntentionWrite(prefix) {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention '%s' denied due to ACLs", args.Intention.ID)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We always update the updatedat field. This has no effect for deletion.
|
||||
args.Intention.UpdatedAt = time.Now().UTC()
|
||||
|
||||
// Default source type
|
||||
if args.Intention.SourceType == "" {
|
||||
args.Intention.SourceType = structs.IntentionSourceConsul
|
||||
}
|
||||
|
||||
// Until we support namespaces, we force all namespaces to be default
|
||||
if args.Intention.SourceNS == "" {
|
||||
args.Intention.SourceNS = structs.IntentionDefaultNamespace
|
||||
}
|
||||
if args.Intention.DestinationNS == "" {
|
||||
args.Intention.DestinationNS = structs.IntentionDefaultNamespace
|
||||
}
|
||||
|
||||
// Validate. We do not validate on delete since it is valid to only
|
||||
// send an ID in that case.
|
||||
if args.Op != structs.IntentionOpDelete {
|
||||
// Set the precedence
|
||||
args.Intention.UpdatePrecedence()
|
||||
|
||||
if err := args.Intention.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit
|
||||
resp, err := s.srv.raftApply(structs.IntentionRequestType, args)
|
||||
if err != nil {
|
||||
s.srv.logger.Printf("[ERR] consul.intention: Apply failed %v", err)
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a single intention by ID.
|
||||
func (s *Intention) Get(
|
||||
args *structs.IntentionQueryRequest,
|
||||
reply *structs.IndexedIntentions) error {
|
||||
// Forward if necessary
|
||||
if done, err := s.srv.forward("Intention.Get", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.srv.blockingQuery(
|
||||
&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, ixn, err := state.IntentionGet(ws, args.IntentionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ixn == nil {
|
||||
return ErrIntentionNotFound
|
||||
}
|
||||
|
||||
reply.Index = index
|
||||
reply.Intentions = structs.Intentions{ixn}
|
||||
|
||||
// Filter
|
||||
if err := s.srv.filterACL(args.Token, reply); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If ACLs prevented any responses, error
|
||||
if len(reply.Intentions) == 0 {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: Request to get intention '%s' denied due to ACLs", args.IntentionID)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// List returns all the intentions.
|
||||
func (s *Intention) List(
|
||||
args *structs.DCSpecificRequest,
|
||||
reply *structs.IndexedIntentions) error {
|
||||
// Forward if necessary
|
||||
if done, err := s.srv.forward("Intention.List", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.srv.blockingQuery(
|
||||
&args.QueryOptions, &reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, ixns, err := state.Intentions(ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index, reply.Intentions = index, ixns
|
||||
if reply.Intentions == nil {
|
||||
reply.Intentions = make(structs.Intentions, 0)
|
||||
}
|
||||
|
||||
return s.srv.filterACL(args.Token, reply)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Match returns the set of intentions that match the given source/destination.
|
||||
func (s *Intention) Match(
|
||||
args *structs.IntentionQueryRequest,
|
||||
reply *structs.IndexedIntentionMatches) error {
|
||||
// Forward if necessary
|
||||
if done, err := s.srv.forward("Intention.Match", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the ACL token for the request for the checks below.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rule != nil {
|
||||
// We go through each entry and test the destination to check if it
|
||||
// matches.
|
||||
for _, entry := range args.Match.Entries {
|
||||
if prefix := entry.Name; prefix != "" && !rule.IntentionRead(prefix) {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: Operation on intention prefix '%s' denied due to ACLs", prefix)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.srv.blockingQuery(
|
||||
&args.QueryOptions,
|
||||
&reply.QueryMeta,
|
||||
func(ws memdb.WatchSet, state *state.Store) error {
|
||||
index, matches, err := state.IntentionMatch(ws, args.Match)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Index = index
|
||||
reply.Matches = matches
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Check tests a source/destination and returns whether it would be allowed
|
||||
// or denied based on the current ACL configuration.
|
||||
//
|
||||
// Note: Whenever the logic for this method is changed, you should take
|
||||
// a look at the agent authorize endpoint (agent/agent_endpoint.go) since
|
||||
// the logic there is similar.
|
||||
func (s *Intention) Check(
|
||||
args *structs.IntentionQueryRequest,
|
||||
reply *structs.IntentionQueryCheckResponse) error {
|
||||
// Forward maybe
|
||||
if done, err := s.srv.forward("Intention.Check", args, args, reply); done {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the test args, and defensively guard against nil
|
||||
query := args.Check
|
||||
if query == nil {
|
||||
return errors.New("Check must be specified on args")
|
||||
}
|
||||
|
||||
// Build the URI
|
||||
var uri connect.CertURI
|
||||
switch query.SourceType {
|
||||
case structs.IntentionSourceConsul:
|
||||
uri = &connect.SpiffeIDService{
|
||||
Namespace: query.SourceNS,
|
||||
Service: query.SourceName,
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported SourceType: %q", query.SourceType)
|
||||
}
|
||||
|
||||
// Get the ACL token for the request for the checks below.
|
||||
rule, err := s.srv.resolveToken(args.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform the ACL check. For Check we only require ServiceRead and
|
||||
// NOT IntentionRead because the Check API only returns pass/fail and
|
||||
// returns no other information about the intentions used.
|
||||
if prefix, ok := query.GetACLPrefix(); ok {
|
||||
if rule != nil && !rule.ServiceRead(prefix) {
|
||||
s.srv.logger.Printf("[WARN] consul.intention: test on intention '%s' denied due to ACLs", prefix)
|
||||
return acl.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// Get the matches for this destination
|
||||
state := s.srv.fsm.State()
|
||||
_, matches, err := state.IntentionMatch(nil, &structs.IntentionQueryMatch{
|
||||
Type: structs.IntentionMatchDestination,
|
||||
Entries: []structs.IntentionMatchEntry{
|
||||
structs.IntentionMatchEntry{
|
||||
Namespace: query.DestinationNS,
|
||||
Name: query.DestinationName,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(matches) != 1 {
|
||||
// This should never happen since the documented behavior of the
|
||||
// Match call is that it'll always return exactly the number of results
|
||||
// as entries passed in. But we guard against misbehavior.
|
||||
return errors.New("internal error loading matches")
|
||||
}
|
||||
|
||||
// Check the authorization for each match
|
||||
for _, ixn := range matches[0] {
|
||||
if auth, ok := uri.Authorize(ixn); ok {
|
||||
reply.Allowed = auth
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// No match, we need to determine the default behavior. We do this by
|
||||
// specifying the anonymous token token, which will get that behavior.
|
||||
// The default behavior if ACLs are disabled is to allow connections
|
||||
// to mimic the behavior of Consul itself: everything is allowed if
|
||||
// ACLs are disabled.
|
||||
//
|
||||
// NOTE(mitchellh): This is the same behavior as the agent authorize
|
||||
// endpoint. If this behavior is incorrect, we should also change it there
|
||||
// which is much more important.
|
||||
rule, err = s.srv.resolveToken("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reply.Allowed = true
|
||||
if rule != nil {
|
||||
reply.Allowed = rule.IntentionDefaultAllow()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -4,16 +4,20 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/types"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/hashicorp/raft"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
|
@ -210,6 +214,12 @@ func (s *Server) establishLeadership() error {
|
|||
|
||||
s.getOrCreateAutopilotConfig()
|
||||
s.autopilot.Start()
|
||||
|
||||
// todo(kyhavlov): start a goroutine here for handling periodic CA rotation
|
||||
if err := s.initializeCA(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.setConsistentReadReady()
|
||||
return nil
|
||||
}
|
||||
|
@ -226,6 +236,8 @@ func (s *Server) revokeLeadership() error {
|
|||
return err
|
||||
}
|
||||
|
||||
s.setCAProvider(nil, nil)
|
||||
|
||||
s.resetConsistentReadReady()
|
||||
s.autopilot.Stop()
|
||||
return nil
|
||||
|
@ -359,6 +371,185 @@ func (s *Server) getOrCreateAutopilotConfig() *autopilot.Config {
|
|||
return config
|
||||
}
|
||||
|
||||
// initializeCAConfig is used to initialize the CA config if necessary
|
||||
// when setting up the CA during establishLeadership
|
||||
func (s *Server) initializeCAConfig() (*structs.CAConfiguration, error) {
|
||||
state := s.fsm.State()
|
||||
_, config, err := state.CAConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if config != nil {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
config = s.config.CAConfig
|
||||
if config.ClusterID == "" {
|
||||
id, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.ClusterID = id
|
||||
}
|
||||
|
||||
req := structs.CARequest{
|
||||
Op: structs.CAOpSetConfig,
|
||||
Config: config,
|
||||
}
|
||||
if _, err = s.raftApply(structs.ConnectCARequestType, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// initializeCA sets up the CA provider when gaining leadership, bootstrapping
|
||||
// the root in the state store if necessary.
|
||||
func (s *Server) initializeCA() error {
|
||||
// Bail if connect isn't enabled.
|
||||
if !s.config.ConnectEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
conf, err := s.initializeCAConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the right provider based on the config
|
||||
provider, err := s.createCAProvider(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the active root cert from the CA
|
||||
rootPEM, err := provider.ActiveRoot()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting root cert: %v", err)
|
||||
}
|
||||
|
||||
rootCA, err := parseCARoot(rootPEM, conf.Provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(banks): in the case that we've just gained leadership in an already
|
||||
// configured cluster. We really need to fetch RootCA from state to provide it
|
||||
// in setCAProvider. This matters because if the current active root has
|
||||
// intermediates, parsing the rootCA from only the root cert PEM above will
|
||||
// not include them and so leafs we sign will not bundle the intermediates.
|
||||
|
||||
s.setCAProvider(provider, rootCA)
|
||||
|
||||
// Check if the CA root is already initialized and exit if it is.
|
||||
// Every change to the CA after this initial bootstrapping should
|
||||
// be done through the rotation process.
|
||||
state := s.fsm.State()
|
||||
_, activeRoot, err := state.CARootActive(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if activeRoot != nil {
|
||||
if activeRoot.ID != rootCA.ID {
|
||||
// TODO(banks): this seems like a pretty catastrophic state to get into.
|
||||
// Shouldn't we do something stronger than warn and continue signing with
|
||||
// a key that's not the active CA according to the state?
|
||||
s.logger.Printf("[WARN] connect: CA root %q is not the active root (%q)", rootCA.ID, activeRoot.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the highest index
|
||||
idx, _, err := state.CARoots(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the root cert in raft
|
||||
resp, err := s.raftApply(structs.ConnectCARequestType, &structs.CARequest{
|
||||
Op: structs.CAOpSetRoots,
|
||||
Index: idx,
|
||||
Roots: []*structs.CARoot{rootCA},
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Printf("[ERR] connect: Apply failed %v", err)
|
||||
return err
|
||||
}
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return respErr
|
||||
}
|
||||
|
||||
s.logger.Printf("[INFO] connect: initialized CA with provider %q", conf.Provider)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseCARoot returns a filled-in structs.CARoot from a raw PEM value.
|
||||
func parseCARoot(pemValue, provider string) (*structs.CARoot, error) {
|
||||
id, err := connect.CalculateCertFingerprint(pemValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing root fingerprint: %v", err)
|
||||
}
|
||||
rootCert, err := connect.ParseCert(pemValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing root cert: %v", err)
|
||||
}
|
||||
return &structs.CARoot{
|
||||
ID: id,
|
||||
Name: fmt.Sprintf("%s CA Root Cert", strings.Title(provider)),
|
||||
SerialNumber: rootCert.SerialNumber.Uint64(),
|
||||
SigningKeyID: connect.HexString(rootCert.AuthorityKeyId),
|
||||
NotBefore: rootCert.NotBefore,
|
||||
NotAfter: rootCert.NotAfter,
|
||||
RootCert: pemValue,
|
||||
Active: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createProvider returns a connect CA provider from the given config.
|
||||
func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, error) {
|
||||
switch conf.Provider {
|
||||
case structs.ConsulCAProvider:
|
||||
return ca.NewConsulProvider(conf.Config, &consulCADelegate{s})
|
||||
case structs.VaultCAProvider:
|
||||
return ca.NewVaultProvider(conf.Config, conf.ClusterID)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown CA provider %q", conf.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getCAProvider() (ca.Provider, *structs.CARoot) {
|
||||
retries := 0
|
||||
var result ca.Provider
|
||||
var resultRoot *structs.CARoot
|
||||
for result == nil {
|
||||
s.caProviderLock.RLock()
|
||||
result = s.caProvider
|
||||
resultRoot = s.caProviderRoot
|
||||
s.caProviderLock.RUnlock()
|
||||
|
||||
// In cases where an agent is started with managed proxies, we may ask
|
||||
// for the provider before establishLeadership completes. If we're the
|
||||
// leader, then wait and get the provider again
|
||||
if result == nil && s.IsLeader() && retries < 10 {
|
||||
retries++
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return result, resultRoot
|
||||
}
|
||||
|
||||
func (s *Server) setCAProvider(newProvider ca.Provider, root *structs.CARoot) {
|
||||
s.caProviderLock.Lock()
|
||||
defer s.caProviderLock.Unlock()
|
||||
s.caProvider = newProvider
|
||||
s.caProviderRoot = root
|
||||
}
|
||||
|
||||
// reconcileReaped is used to reconcile nodes that have failed and been reaped
|
||||
// from Serf but remain in the catalog. This is done by looking for unknown nodes with serfHealth checks registered.
|
||||
// We generate a "reap" event to cause the node to be cleaned up.
|
||||
|
|
|
@ -354,7 +354,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
|
|||
}
|
||||
|
||||
// Execute the query for the local DC.
|
||||
if err := p.execute(query, reply); err != nil {
|
||||
if err := p.execute(query, reply, args.Connect); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -450,7 +450,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest,
|
|||
// by the query setup.
|
||||
if len(reply.Nodes) == 0 {
|
||||
wrapper := &queryServerWrapper{p.srv}
|
||||
if err := queryFailover(wrapper, query, args.Limit, args.QueryOptions, reply); err != nil {
|
||||
if err := queryFailover(wrapper, query, args, reply); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -479,7 +479,7 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe
|
|||
}
|
||||
|
||||
// Run the query locally to see what we can find.
|
||||
if err := p.execute(&args.Query, reply); err != nil {
|
||||
if err := p.execute(&args.Query, reply, args.Connect); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -509,9 +509,18 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe
|
|||
// execute runs a prepared query in the local DC without any failover. We don't
|
||||
// apply any sorting options or ACL checks at this level - it should be done up above.
|
||||
func (p *PreparedQuery) execute(query *structs.PreparedQuery,
|
||||
reply *structs.PreparedQueryExecuteResponse) error {
|
||||
reply *structs.PreparedQueryExecuteResponse,
|
||||
forceConnect bool) error {
|
||||
state := p.srv.fsm.State()
|
||||
_, nodes, err := state.CheckServiceNodes(nil, query.Service.Service)
|
||||
|
||||
// If we're requesting Connect-capable services, then switch the
|
||||
// lookup to be the Connect function.
|
||||
f := state.CheckServiceNodes
|
||||
if query.Service.Connect || forceConnect {
|
||||
f = state.CheckConnectServiceNodes
|
||||
}
|
||||
|
||||
_, nodes, err := f(nil, query.Service.Service)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -651,7 +660,7 @@ func (q *queryServerWrapper) ForwardDC(method, dc string, args interface{}, repl
|
|||
// queryFailover runs an algorithm to determine which DCs to try and then calls
|
||||
// them to try to locate alternative services.
|
||||
func queryFailover(q queryServer, query *structs.PreparedQuery,
|
||||
limit int, options structs.QueryOptions,
|
||||
args *structs.PreparedQueryExecuteRequest,
|
||||
reply *structs.PreparedQueryExecuteResponse) error {
|
||||
|
||||
// Pull the list of other DCs. This is sorted by RTT in case the user
|
||||
|
@ -719,8 +728,9 @@ func queryFailover(q queryServer, query *structs.PreparedQuery,
|
|||
remote := &structs.PreparedQueryExecuteRemoteRequest{
|
||||
Datacenter: dc,
|
||||
Query: *query,
|
||||
Limit: limit,
|
||||
QueryOptions: options,
|
||||
Limit: args.Limit,
|
||||
QueryOptions: args.QueryOptions,
|
||||
Connect: args.Connect,
|
||||
}
|
||||
if err := q.ForwardDC("PreparedQuery.ExecuteRemote", dc, remote, reply); err != nil {
|
||||
q.GetLogger().Printf("[WARN] consul.prepared_query: Failed querying for service '%s' in datacenter '%s': %s", query.Service.Service, dc, err)
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/hashicorp/consul/types"
|
||||
"github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPreparedQuery_Apply(t *testing.T) {
|
||||
|
@ -2617,6 +2618,159 @@ func TestPreparedQuery_Execute_ForwardLeader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPreparedQuery_Execute_ConnectExact(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
dir1, s1 := testServer(t)
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
// Setup 3 services on 3 nodes: one is non-Connect, one is Connect native,
|
||||
// and one is a proxy to the non-Connect one.
|
||||
for i := 0; i < 3; i++ {
|
||||
req := structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: fmt.Sprintf("node%d", i+1),
|
||||
Address: fmt.Sprintf("127.0.0.%d", i+1),
|
||||
Service: &structs.NodeService{
|
||||
Service: "foo",
|
||||
Port: 8000,
|
||||
},
|
||||
}
|
||||
|
||||
switch i {
|
||||
case 0:
|
||||
// Default do nothing
|
||||
|
||||
case 1:
|
||||
// Connect native
|
||||
req.Service.Connect.Native = true
|
||||
|
||||
case 2:
|
||||
// Connect proxy
|
||||
req.Service.Kind = structs.ServiceKindConnectProxy
|
||||
req.Service.ProxyDestination = req.Service.Service
|
||||
req.Service.Service = "proxy"
|
||||
}
|
||||
|
||||
var reply struct{}
|
||||
require.NoError(msgpackrpc.CallWithCodec(codec, "Catalog.Register", &req, &reply))
|
||||
}
|
||||
|
||||
// The query, start with connect disabled
|
||||
query := structs.PreparedQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.PreparedQueryCreate,
|
||||
Query: &structs.PreparedQuery{
|
||||
Name: "test",
|
||||
Service: structs.ServiceQuery{
|
||||
Service: "foo",
|
||||
},
|
||||
DNS: structs.QueryDNSOptions{
|
||||
TTL: "10s",
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
|
||||
|
||||
// In the future we'll run updates
|
||||
query.Op = structs.PreparedQueryUpdate
|
||||
|
||||
// Run the registered query.
|
||||
{
|
||||
req := structs.PreparedQueryExecuteRequest{
|
||||
Datacenter: "dc1",
|
||||
QueryIDOrName: query.Query.ID,
|
||||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Execute", &req, &reply))
|
||||
|
||||
// Result should have two because it omits the proxy whose name
|
||||
// doesn't match the query.
|
||||
require.Len(reply.Nodes, 2)
|
||||
require.Equal(query.Query.Service.Service, reply.Service)
|
||||
require.Equal(query.Query.DNS, reply.DNS)
|
||||
require.True(reply.QueryMeta.KnownLeader, "queried leader")
|
||||
}
|
||||
|
||||
// Run with the Connect setting specified on the request
|
||||
{
|
||||
req := structs.PreparedQueryExecuteRequest{
|
||||
Datacenter: "dc1",
|
||||
QueryIDOrName: query.Query.ID,
|
||||
Connect: true,
|
||||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Execute", &req, &reply))
|
||||
|
||||
// Result should have two because we should get the native AND
|
||||
// the proxy (since the destination matches our service name).
|
||||
require.Len(reply.Nodes, 2)
|
||||
require.Equal(query.Query.Service.Service, reply.Service)
|
||||
require.Equal(query.Query.DNS, reply.DNS)
|
||||
require.True(reply.QueryMeta.KnownLeader, "queried leader")
|
||||
|
||||
// Make sure the native is the first one
|
||||
if !reply.Nodes[0].Service.Connect.Native {
|
||||
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
|
||||
}
|
||||
|
||||
require.True(reply.Nodes[0].Service.Connect.Native, "native")
|
||||
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
|
||||
|
||||
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
|
||||
require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination)
|
||||
}
|
||||
|
||||
// Update the query
|
||||
query.Query.Service.Connect = true
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
|
||||
|
||||
// Run the registered query.
|
||||
{
|
||||
req := structs.PreparedQueryExecuteRequest{
|
||||
Datacenter: "dc1",
|
||||
QueryIDOrName: query.Query.ID,
|
||||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Execute", &req, &reply))
|
||||
|
||||
// Result should have two because we should get the native AND
|
||||
// the proxy (since the destination matches our service name).
|
||||
require.Len(reply.Nodes, 2)
|
||||
require.Equal(query.Query.Service.Service, reply.Service)
|
||||
require.Equal(query.Query.DNS, reply.DNS)
|
||||
require.True(reply.QueryMeta.KnownLeader, "queried leader")
|
||||
|
||||
// Make sure the native is the first one
|
||||
if !reply.Nodes[0].Service.Connect.Native {
|
||||
reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0]
|
||||
}
|
||||
|
||||
require.True(reply.Nodes[0].Service.Connect.Native, "native")
|
||||
require.Equal(reply.Service, reply.Nodes[0].Service.Service)
|
||||
|
||||
require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind)
|
||||
require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination)
|
||||
}
|
||||
|
||||
// Unset the query
|
||||
query.Query.Service.Connect = false
|
||||
require.NoError(msgpackrpc.CallWithCodec(
|
||||
codec, "PreparedQuery.Apply", &query, &query.Query.ID))
|
||||
}
|
||||
|
||||
func TestPreparedQuery_tagFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
testNodes := func() structs.CheckServiceNodes {
|
||||
|
@ -2820,7 +2974,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 {
|
||||
|
@ -2836,7 +2990,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply)
|
||||
err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply)
|
||||
if err == nil || !strings.Contains(err.Error(), "XXX") {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
@ -2853,7 +3007,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 {
|
||||
|
@ -2876,7 +3030,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -2904,7 +3058,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -2925,7 +3079,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 0 ||
|
||||
|
@ -2954,7 +3108,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -2983,7 +3137,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -3012,7 +3166,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -3047,7 +3201,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -3079,7 +3233,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
@ -3115,7 +3269,10 @@ func TestPreparedQuery_queryFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
var reply structs.PreparedQueryExecuteResponse
|
||||
if err := queryFailover(mock, query, 5, structs.QueryOptions{RequireConsistent: true}, &reply); err != nil {
|
||||
if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{
|
||||
Limit: 5,
|
||||
QueryOptions: structs.QueryOptions{RequireConsistent: true},
|
||||
}, &reply); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(reply.Nodes) != 3 ||
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
ca "github.com/hashicorp/consul/agent/connect/ca"
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/consul/agent/consul/fsm"
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
|
@ -96,6 +97,16 @@ type Server struct {
|
|||
// autopilotWaitGroup is used to block until Autopilot shuts down.
|
||||
autopilotWaitGroup sync.WaitGroup
|
||||
|
||||
// caProvider is the current CA provider in use for Connect. This is
|
||||
// only non-nil when we are the leader.
|
||||
caProvider ca.Provider
|
||||
// caProviderRoot is the CARoot that was stored along with the ca.Provider
|
||||
// active. It's only updated in lock-step with the caProvider. This prevents
|
||||
// races between state updates to active roots and the fetch of the provider
|
||||
// instance.
|
||||
caProviderRoot *structs.CARoot
|
||||
caProviderLock sync.RWMutex
|
||||
|
||||
// Consul configuration
|
||||
config *Config
|
||||
|
||||
|
|
|
@ -4,7 +4,9 @@ func init() {
|
|||
registerEndpoint(func(s *Server) interface{} { return &ACL{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Catalog{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return NewCoordinate(s) })
|
||||
registerEndpoint(func(s *Server) interface{} { return &ConnectCA{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Health{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Intention{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Internal{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &KVS{s} })
|
||||
registerEndpoint(func(s *Server) interface{} { return &Operator{s} })
|
||||
|
|
|
@ -10,7 +10,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
"github.com/hashicorp/consul/lib/freeport"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
|
@ -91,6 +93,17 @@ func testServerConfig(t *testing.T) (string, *Config) {
|
|||
// looks like several depend on it.
|
||||
config.RPCHoldTimeout = 5 * time.Second
|
||||
|
||||
config.ConnectEnabled = true
|
||||
config.CAConfig = &structs.CAConfiguration{
|
||||
ClusterID: connect.TestClusterID,
|
||||
Provider: structs.ConsulCAProvider,
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "",
|
||||
"RootCert": "",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
return dir, config
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,10 @@ import (
|
|||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
const (
|
||||
servicesTableName = "services"
|
||||
)
|
||||
|
||||
// nodesTableSchema returns a new table schema used for storing node
|
||||
// information.
|
||||
func nodesTableSchema() *memdb.TableSchema {
|
||||
|
@ -87,6 +91,12 @@ func servicesTableSchema() *memdb.TableSchema {
|
|||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
"connect": &memdb.IndexSchema{
|
||||
Name: "connect",
|
||||
AllowMissing: true,
|
||||
Unique: false,
|
||||
Indexer: &IndexConnectService{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -779,15 +789,39 @@ func maxIndexForService(tx *memdb.Txn, serviceName string, checks bool) uint64 {
|
|||
return maxIndexTxn(tx, "nodes", "services")
|
||||
}
|
||||
|
||||
// ConnectServiceNodes returns the nodes associated with a Connect
|
||||
// compatible destination for the given service name. This will include
|
||||
// both proxies and native integrations.
|
||||
func (s *Store) ConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) {
|
||||
return s.serviceNodes(ws, serviceName, true)
|
||||
}
|
||||
|
||||
// ServiceNodes returns the nodes associated with a given service name.
|
||||
func (s *Store) ServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.ServiceNodes, error) {
|
||||
return s.serviceNodes(ws, serviceName, false)
|
||||
}
|
||||
|
||||
func (s *Store) serviceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.ServiceNodes, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexForService(tx, serviceName, false)
|
||||
|
||||
// Function for lookup
|
||||
var f func() (memdb.ResultIterator, error)
|
||||
if !connect {
|
||||
f = func() (memdb.ResultIterator, error) {
|
||||
return tx.Get("services", "service", serviceName)
|
||||
}
|
||||
} else {
|
||||
f = func() (memdb.ResultIterator, error) {
|
||||
return tx.Get("services", "connect", serviceName)
|
||||
}
|
||||
}
|
||||
|
||||
// List all the services.
|
||||
services, err := tx.Get("services", "service", serviceName)
|
||||
services, err := f()
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed service lookup: %s", err)
|
||||
}
|
||||
|
@ -1479,14 +1513,36 @@ func (s *Store) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID t
|
|||
|
||||
// CheckServiceNodes is used to query all nodes and checks for a given service.
|
||||
func (s *Store) CheckServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.checkServiceNodes(ws, serviceName, false)
|
||||
}
|
||||
|
||||
// CheckConnectServiceNodes is used to query all nodes and checks for Connect
|
||||
// compatible endpoints for a given service.
|
||||
func (s *Store) CheckConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) {
|
||||
return s.checkServiceNodes(ws, serviceName, true)
|
||||
}
|
||||
|
||||
func (s *Store) checkServiceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.CheckServiceNodes, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexForService(tx, serviceName, true)
|
||||
|
||||
// Function for lookup
|
||||
var f func() (memdb.ResultIterator, error)
|
||||
if !connect {
|
||||
f = func() (memdb.ResultIterator, error) {
|
||||
return tx.Get("services", "service", serviceName)
|
||||
}
|
||||
} else {
|
||||
f = func() (memdb.ResultIterator, error) {
|
||||
return tx.Get("services", "connect", serviceName)
|
||||
}
|
||||
}
|
||||
|
||||
// Query the state store for the service.
|
||||
iter, err := tx.Get("services", "service", serviceName)
|
||||
iter, err := f()
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed service lookup: %s", err)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/hashicorp/go-memdb"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func makeRandomNodeID(t *testing.T) types.NodeID {
|
||||
|
@ -981,6 +982,35 @@ func TestStateStore_EnsureService(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStateStore_EnsureService_connectProxy(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Create the service registration.
|
||||
ns1 := &structs.NodeService{
|
||||
Kind: structs.ServiceKindConnectProxy,
|
||||
ID: "connect-proxy",
|
||||
Service: "connect-proxy",
|
||||
Address: "1.1.1.1",
|
||||
Port: 1111,
|
||||
ProxyDestination: "foo",
|
||||
}
|
||||
|
||||
// Service successfully registers into the state store.
|
||||
testRegisterNode(t, s, 0, "node1")
|
||||
assert.Nil(s.EnsureService(10, "node1", ns1))
|
||||
|
||||
// Retrieve and verify
|
||||
_, out, err := s.NodeServices(nil, "node1")
|
||||
assert.Nil(err)
|
||||
assert.NotNil(out)
|
||||
assert.Len(out.Services, 1)
|
||||
|
||||
expect1 := *ns1
|
||||
expect1.CreateIndex, expect1.ModifyIndex = 10, 10
|
||||
assert.Equal(&expect1, out.Services["connect-proxy"])
|
||||
}
|
||||
|
||||
func TestStateStore_Services(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
|
@ -1542,6 +1572,51 @@ func TestStateStore_DeleteService(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStateStore_ConnectServiceNodes(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Listing with no results returns an empty list.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, nodes, err := s.ConnectServiceNodes(ws, "db")
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(0))
|
||||
assert.Len(nodes, 0)
|
||||
|
||||
// Create some nodes and services.
|
||||
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
|
||||
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
|
||||
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
|
||||
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
|
||||
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
|
||||
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
|
||||
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "native-db", Service: "db", Connect: structs.ServiceConnect{Native: true}}))
|
||||
assert.Nil(s.EnsureService(17, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}))
|
||||
assert.True(watchFired(ws))
|
||||
|
||||
// Read everything back.
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, nodes, err = s.ConnectServiceNodes(ws, "db")
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(idx))
|
||||
assert.Len(nodes, 3)
|
||||
|
||||
for _, n := range nodes {
|
||||
assert.True(
|
||||
n.ServiceKind == structs.ServiceKindConnectProxy ||
|
||||
n.ServiceConnect.Native,
|
||||
"either proxy or connect native")
|
||||
}
|
||||
|
||||
// Registering some unrelated node should not fire the watch.
|
||||
testRegisterNode(t, s, 17, "nope")
|
||||
assert.False(watchFired(ws))
|
||||
|
||||
// But removing a node with the "db" service should fire the watch.
|
||||
assert.Nil(s.DeleteNode(18, "bar"))
|
||||
assert.True(watchFired(ws))
|
||||
}
|
||||
|
||||
func TestStateStore_Service_Snapshot(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
|
@ -2457,6 +2532,48 @@ func TestStateStore_CheckServiceNodes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStateStore_CheckConnectServiceNodes(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Listing with no results returns an empty list.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, nodes, err := s.CheckConnectServiceNodes(ws, "db")
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(0))
|
||||
assert.Len(nodes, 0)
|
||||
|
||||
// Create some nodes and services.
|
||||
assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"}))
|
||||
assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"}))
|
||||
assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000}))
|
||||
assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000}))
|
||||
assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
|
||||
assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000}))
|
||||
assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001}))
|
||||
assert.True(watchFired(ws))
|
||||
|
||||
// Register node checks
|
||||
testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing)
|
||||
testRegisterCheck(t, s, 18, "bar", "", "check2", api.HealthPassing)
|
||||
|
||||
// Register checks against the services.
|
||||
testRegisterCheck(t, s, 19, "foo", "db", "check3", api.HealthPassing)
|
||||
testRegisterCheck(t, s, 20, "bar", "proxy", "check4", api.HealthPassing)
|
||||
|
||||
// Read everything back.
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, nodes, err = s.CheckConnectServiceNodes(ws, "db")
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(idx))
|
||||
assert.Len(nodes, 2)
|
||||
|
||||
for _, n := range nodes {
|
||||
assert.Equal(structs.ServiceKindConnectProxy, n.Service.Kind)
|
||||
assert.Equal("db", n.Service.ProxyDestination)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCheckServiceNodes(b *testing.B) {
|
||||
s, err := NewStateStore(nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,435 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
const (
|
||||
caBuiltinProviderTableName = "connect-ca-builtin"
|
||||
caConfigTableName = "connect-ca-config"
|
||||
caRootTableName = "connect-ca-roots"
|
||||
)
|
||||
|
||||
// caBuiltinProviderTableSchema returns a new table schema used for storing
|
||||
// the built-in CA provider's state for connect. This is only used by
|
||||
// the internal Consul CA provider.
|
||||
func caBuiltinProviderTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: caBuiltinProviderTableName,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// caConfigTableSchema returns a new table schema used for storing
|
||||
// the CA config for Connect.
|
||||
func caConfigTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: caConfigTableName,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
// This table only stores one row, so this just ignores the ID field
|
||||
// and always overwrites the same config object.
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: true,
|
||||
Unique: true,
|
||||
Indexer: &memdb.ConditionalIndex{
|
||||
Conditional: func(obj interface{}) (bool, error) { return true, nil },
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// caRootTableSchema returns a new table schema used for storing
|
||||
// CA roots for Connect.
|
||||
func caRootTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: caRootTableName,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
registerSchema(caBuiltinProviderTableSchema)
|
||||
registerSchema(caConfigTableSchema)
|
||||
registerSchema(caRootTableSchema)
|
||||
}
|
||||
|
||||
// CAConfig is used to pull the CA config from the snapshot.
|
||||
func (s *Snapshot) CAConfig() (*structs.CAConfiguration, error) {
|
||||
c, err := s.tx.First(caConfigTableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config, ok := c.(*structs.CAConfiguration)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// CAConfig is used when restoring from a snapshot.
|
||||
func (s *Restore) CAConfig(config *structs.CAConfiguration) error {
|
||||
if err := s.tx.Insert(caConfigTableName, config); err != nil {
|
||||
return fmt.Errorf("failed restoring CA config: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CAConfig is used to get the current CA configuration.
|
||||
func (s *Store) CAConfig() (uint64, *structs.CAConfiguration, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the CA config
|
||||
c, err := tx.First(caConfigTableName, "id")
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed CA config lookup: %s", err)
|
||||
}
|
||||
|
||||
config, ok := c.(*structs.CAConfiguration)
|
||||
if !ok {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
return config.ModifyIndex, config, nil
|
||||
}
|
||||
|
||||
// CASetConfig is used to set the current CA configuration.
|
||||
func (s *Store) CASetConfig(idx uint64, config *structs.CAConfiguration) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.caSetConfigTxn(idx, tx, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CACheckAndSetConfig is used to try updating the CA configuration with a
|
||||
// given Raft index. If the CAS index specified is not equal to the last observed index
|
||||
// for the config, then the call is a noop,
|
||||
func (s *Store) CACheckAndSetConfig(idx, cidx uint64, config *structs.CAConfiguration) (bool, error) {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// Check for an existing config
|
||||
existing, err := tx.First(caConfigTableName, "id")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed CA config lookup: %s", err)
|
||||
}
|
||||
|
||||
// If the existing index does not match the provided CAS
|
||||
// index arg, then we shouldn't update anything and can safely
|
||||
// return early here.
|
||||
e, ok := existing.(*structs.CAConfiguration)
|
||||
if !ok || e.ModifyIndex != cidx {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := s.caSetConfigTxn(idx, tx, config); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *Store) caSetConfigTxn(idx uint64, tx *memdb.Txn, config *structs.CAConfiguration) error {
|
||||
// Check for an existing config
|
||||
prev, err := tx.First(caConfigTableName, "id")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed CA config lookup: %s", err)
|
||||
}
|
||||
|
||||
// Set the indexes, prevent the cluster ID from changing.
|
||||
if prev != nil {
|
||||
existing := prev.(*structs.CAConfiguration)
|
||||
config.CreateIndex = existing.CreateIndex
|
||||
config.ClusterID = existing.ClusterID
|
||||
} else {
|
||||
config.CreateIndex = idx
|
||||
}
|
||||
config.ModifyIndex = idx
|
||||
|
||||
if err := tx.Insert(caConfigTableName, config); err != nil {
|
||||
return fmt.Errorf("failed updating CA config: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CARoots is used to pull all the CA roots for the snapshot.
|
||||
func (s *Snapshot) CARoots() (structs.CARoots, error) {
|
||||
ixns, err := s.tx.Get(caRootTableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret structs.CARoots
|
||||
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
|
||||
ret = append(ret, wrapped.(*structs.CARoot))
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// CARoots is used when restoring from a snapshot.
|
||||
func (s *Restore) CARoot(r *structs.CARoot) error {
|
||||
// Insert
|
||||
if err := s.tx.Insert(caRootTableName, r); err != nil {
|
||||
return fmt.Errorf("failed restoring CA root: %s", err)
|
||||
}
|
||||
if err := indexUpdateMaxTxn(s.tx, r.ModifyIndex, caRootTableName); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CARoots returns the list of all CA roots.
|
||||
func (s *Store) CARoots(ws memdb.WatchSet) (uint64, structs.CARoots, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the index
|
||||
idx := maxIndexTxn(tx, caRootTableName)
|
||||
|
||||
// Get all
|
||||
iter, err := tx.Get(caRootTableName, "id")
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed CA root lookup: %s", err)
|
||||
}
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
var results structs.CARoots
|
||||
for v := iter.Next(); v != nil; v = iter.Next() {
|
||||
results = append(results, v.(*structs.CARoot))
|
||||
}
|
||||
return idx, results, nil
|
||||
}
|
||||
|
||||
// CARootActive returns the currently active CARoot.
|
||||
func (s *Store) CARootActive(ws memdb.WatchSet) (uint64, *structs.CARoot, error) {
|
||||
// Get all the roots since there should never be that many and just
|
||||
// do the filtering in this method.
|
||||
var result *structs.CARoot
|
||||
idx, roots, err := s.CARoots(ws)
|
||||
if err == nil {
|
||||
for _, r := range roots {
|
||||
if r.Active {
|
||||
result = r
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return idx, result, err
|
||||
}
|
||||
|
||||
// CARootSetCAS sets the current CA root state using a check-and-set operation.
|
||||
// On success, this will replace the previous set of CARoots completely with
|
||||
// the given set of roots.
|
||||
//
|
||||
// The first boolean result returns whether the transaction succeeded or not.
|
||||
func (s *Store) CARootSetCAS(idx, cidx uint64, rs []*structs.CARoot) (bool, error) {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// There must be exactly one active CA root.
|
||||
activeCount := 0
|
||||
for _, r := range rs {
|
||||
if r.Active {
|
||||
activeCount++
|
||||
}
|
||||
}
|
||||
if activeCount != 1 {
|
||||
return false, fmt.Errorf("there must be exactly one active CA")
|
||||
}
|
||||
|
||||
// Get the current max index
|
||||
if midx := maxIndexTxn(tx, caRootTableName); midx != cidx {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Go through and find any existing matching CAs so we can preserve and
|
||||
// update their Create/ModifyIndex values.
|
||||
for _, r := range rs {
|
||||
if r.ID == "" {
|
||||
return false, ErrMissingCARootID
|
||||
}
|
||||
|
||||
existing, err := tx.First(caRootTableName, "id", r.ID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed CA root lookup: %s", err)
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
r.CreateIndex = existing.(*structs.CARoot).CreateIndex
|
||||
} else {
|
||||
r.CreateIndex = idx
|
||||
}
|
||||
r.ModifyIndex = idx
|
||||
}
|
||||
|
||||
// Delete all
|
||||
_, err := tx.DeleteAll(caRootTableName, "id")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Insert all
|
||||
for _, r := range rs {
|
||||
if err := tx.Insert(caRootTableName, r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Update the index
|
||||
if err := tx.Insert("index", &IndexEntry{caRootTableName, idx}); err != nil {
|
||||
return false, fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CAProviderState is used to pull the built-in provider states from the snapshot.
|
||||
func (s *Snapshot) CAProviderState() ([]*structs.CAConsulProviderState, error) {
|
||||
ixns, err := s.tx.Get(caBuiltinProviderTableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret []*structs.CAConsulProviderState
|
||||
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
|
||||
ret = append(ret, wrapped.(*structs.CAConsulProviderState))
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// CAProviderState is used when restoring from a snapshot.
|
||||
func (s *Restore) CAProviderState(state *structs.CAConsulProviderState) error {
|
||||
if err := s.tx.Insert(caBuiltinProviderTableName, state); err != nil {
|
||||
return fmt.Errorf("failed restoring built-in CA state: %s", err)
|
||||
}
|
||||
if err := indexUpdateMaxTxn(s.tx, state.ModifyIndex, caBuiltinProviderTableName); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CAProviderState is used to get the Consul CA provider state for the given ID.
|
||||
func (s *Store) CAProviderState(id string) (uint64, *structs.CAConsulProviderState, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the index
|
||||
idx := maxIndexTxn(tx, caBuiltinProviderTableName)
|
||||
|
||||
// Get the provider config
|
||||
c, err := tx.First(caBuiltinProviderTableName, "id", id)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed built-in CA state lookup: %s", err)
|
||||
}
|
||||
|
||||
state, ok := c.(*structs.CAConsulProviderState)
|
||||
if !ok {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
return idx, state, nil
|
||||
}
|
||||
|
||||
// CASetProviderState is used to set the current built-in CA provider state.
|
||||
func (s *Store) CASetProviderState(idx uint64, state *structs.CAConsulProviderState) (bool, error) {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// Check for an existing config
|
||||
existing, err := tx.First(caBuiltinProviderTableName, "id", state.ID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed built-in CA state lookup: %s", err)
|
||||
}
|
||||
|
||||
// Set the indexes.
|
||||
if existing != nil {
|
||||
state.CreateIndex = existing.(*structs.CAConsulProviderState).CreateIndex
|
||||
} else {
|
||||
state.CreateIndex = idx
|
||||
}
|
||||
state.ModifyIndex = idx
|
||||
|
||||
if err := tx.Insert(caBuiltinProviderTableName, state); err != nil {
|
||||
return false, fmt.Errorf("failed updating built-in CA state: %s", err)
|
||||
}
|
||||
|
||||
// Update the index
|
||||
if err := tx.Insert("index", &IndexEntry{caBuiltinProviderTableName, idx}); err != nil {
|
||||
return false, fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CADeleteProviderState is used to remove the built-in Consul CA provider
|
||||
// state for the given ID.
|
||||
func (s *Store) CADeleteProviderState(id string) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the index
|
||||
idx := maxIndexTxn(tx, caBuiltinProviderTableName)
|
||||
|
||||
// Check for an existing config
|
||||
existing, err := tx.First(caBuiltinProviderTableName, "id", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed built-in CA state lookup: %s", err)
|
||||
}
|
||||
if existing == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providerState := existing.(*structs.CAConsulProviderState)
|
||||
|
||||
// Do the delete and update the index
|
||||
if err := tx.Delete(caBuiltinProviderTableName, providerState); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Insert("index", &IndexEntry{caBuiltinProviderTableName, idx}); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,449 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/connect"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStore_CAConfig(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
expected := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "asdf",
|
||||
"RootCert": "qwer",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.CASetConfig(0, expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
idx, config, err := s.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if idx != 0 {
|
||||
t.Fatalf("bad: %d", idx)
|
||||
}
|
||||
if !reflect.DeepEqual(expected, config) {
|
||||
t.Fatalf("bad: %#v, %#v", expected, config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CAConfigCAS(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
|
||||
expected := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
}
|
||||
|
||||
if err := s.CASetConfig(0, expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Do an extra operation to move the index up by 1 for the
|
||||
// check-and-set operation after this
|
||||
if err := s.CASetConfig(1, expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Do a CAS with an index lower than the entry
|
||||
ok, err := s.CACheckAndSetConfig(2, 0, &structs.CAConfiguration{
|
||||
Provider: "static",
|
||||
})
|
||||
if ok || err != nil {
|
||||
t.Fatalf("expected (false, nil), got: (%v, %#v)", ok, err)
|
||||
}
|
||||
|
||||
// Check that the index is untouched and the entry
|
||||
// has not been updated.
|
||||
idx, config, err := s.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if idx != 1 {
|
||||
t.Fatalf("bad: %d", idx)
|
||||
}
|
||||
if config.Provider != "consul" {
|
||||
t.Fatalf("bad: %#v", config)
|
||||
}
|
||||
|
||||
// Do another CAS, this time with the correct index
|
||||
ok, err = s.CACheckAndSetConfig(2, 1, &structs.CAConfiguration{
|
||||
Provider: "static",
|
||||
})
|
||||
if !ok || err != nil {
|
||||
t.Fatalf("expected (true, nil), got: (%v, %#v)", ok, err)
|
||||
}
|
||||
|
||||
// Make sure the config was updated
|
||||
idx, config, err = s.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if idx != 2 {
|
||||
t.Fatalf("bad: %d", idx)
|
||||
}
|
||||
if config.Provider != "static" {
|
||||
t.Fatalf("bad: %#v", config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CAConfig_Snapshot_Restore(t *testing.T) {
|
||||
s := testStateStore(t)
|
||||
before := &structs.CAConfiguration{
|
||||
Provider: "consul",
|
||||
Config: map[string]interface{}{
|
||||
"PrivateKey": "asdf",
|
||||
"RootCert": "qwer",
|
||||
"RotationPeriod": 90 * 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
if err := s.CASetConfig(99, before); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
snap := s.Snapshot()
|
||||
defer snap.Close()
|
||||
|
||||
after := &structs.CAConfiguration{
|
||||
Provider: "static",
|
||||
Config: map[string]interface{}{},
|
||||
}
|
||||
if err := s.CASetConfig(100, after); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
snapped, err := snap.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
verify.Values(t, "", before, snapped)
|
||||
|
||||
s2 := testStateStore(t)
|
||||
restore := s2.Restore()
|
||||
if err := restore.CAConfig(snapped); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
idx, res, err := s2.CAConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if idx != 99 {
|
||||
t.Fatalf("bad index: %d", idx)
|
||||
}
|
||||
verify.Values(t, "", before, res)
|
||||
}
|
||||
|
||||
func TestStore_CARootSetList(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call list to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(s.maxIndex(caRootTableName), uint64(1))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back out and verify it.
|
||||
expected := *ca1
|
||||
expected.RaftIndex = structs.RaftIndex{
|
||||
CreateIndex: 1,
|
||||
ModifyIndex: 1,
|
||||
}
|
||||
|
||||
ws = memdb.NewWatchSet()
|
||||
_, roots, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
assert.Len(roots, 1)
|
||||
actual := roots[0]
|
||||
assert.Equal(&expected, actual)
|
||||
}
|
||||
|
||||
func TestStore_CARootSet_emptyID(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call list to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca1.ID = ""
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), ErrMissingCARootID.Error())
|
||||
assert.False(ok)
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(s.maxIndex(caRootTableName), uint64(0))
|
||||
assert.False(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back out and verify it.
|
||||
ws = memdb.NewWatchSet()
|
||||
_, roots, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
assert.Len(roots, 0)
|
||||
}
|
||||
|
||||
func TestStore_CARootSet_noActive(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call list to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca1.Active = false
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
ca2.Active = false
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "exactly one active")
|
||||
assert.False(ok)
|
||||
}
|
||||
|
||||
func TestStore_CARootSet_multipleActive(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call list to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.CARoots(ws)
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "exactly one active")
|
||||
assert.False(ok)
|
||||
}
|
||||
|
||||
func TestStore_CARootActive_valid(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Build a valid value
|
||||
ca1 := connect.TestCA(t, nil)
|
||||
ca1.Active = false
|
||||
ca2 := connect.TestCA(t, nil)
|
||||
ca3 := connect.TestCA(t, nil)
|
||||
ca3.Active = false
|
||||
|
||||
// Set
|
||||
ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2, ca3})
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Query
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.CARootActive(ws)
|
||||
assert.Equal(idx, uint64(1))
|
||||
assert.Nil(err)
|
||||
assert.NotNil(res)
|
||||
assert.Equal(ca2.ID, res.ID)
|
||||
}
|
||||
|
||||
// Test that querying the active CA returns the correct value.
|
||||
func TestStore_CARootActive_none(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Querying with no results returns nil.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.CARootActive(ws)
|
||||
assert.Equal(idx, uint64(0))
|
||||
assert.Nil(res)
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Create some intentions.
|
||||
roots := structs.CARoots{
|
||||
connect.TestCA(t, nil),
|
||||
connect.TestCA(t, nil),
|
||||
connect.TestCA(t, nil),
|
||||
}
|
||||
for _, r := range roots[1:] {
|
||||
r.Active = false
|
||||
}
|
||||
|
||||
// Force the sort order of the UUIDs before we create them so the
|
||||
// order is deterministic.
|
||||
id := testUUID()
|
||||
roots[0].ID = "a" + id[1:]
|
||||
roots[1].ID = "b" + id[1:]
|
||||
roots[2].ID = "c" + id[1:]
|
||||
|
||||
// Now create
|
||||
ok, err := s.CARootSetCAS(1, 0, roots)
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Snapshot the queries.
|
||||
snap := s.Snapshot()
|
||||
defer snap.Close()
|
||||
|
||||
// Alter the real state store.
|
||||
ok, err = s.CARootSetCAS(2, 1, roots[:1])
|
||||
assert.Nil(err)
|
||||
assert.True(ok)
|
||||
|
||||
// Verify the snapshot.
|
||||
assert.Equal(snap.LastIndex(), uint64(1))
|
||||
dump, err := snap.CARoots()
|
||||
assert.Nil(err)
|
||||
assert.Equal(roots, dump)
|
||||
|
||||
// Restore the values into a new state store.
|
||||
func() {
|
||||
s := testStateStore(t)
|
||||
restore := s.Restore()
|
||||
for _, r := range dump {
|
||||
assert.Nil(restore.CARoot(r))
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
// Read the restored values back out and verify that they match.
|
||||
idx, actual, err := s.CARoots(nil)
|
||||
assert.Nil(err)
|
||||
assert.Equal(idx, uint64(2))
|
||||
assert.Equal(roots, actual)
|
||||
}()
|
||||
}
|
||||
|
||||
func TestStore_CABuiltinProvider(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
{
|
||||
expected := &structs.CAConsulProviderState{
|
||||
ID: "foo",
|
||||
PrivateKey: "a",
|
||||
RootCert: "b",
|
||||
}
|
||||
|
||||
ok, err := s.CASetProviderState(0, expected)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
|
||||
idx, state, err := s.CAProviderState(expected.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(0))
|
||||
assert.Equal(expected, state)
|
||||
}
|
||||
|
||||
{
|
||||
expected := &structs.CAConsulProviderState{
|
||||
ID: "bar",
|
||||
PrivateKey: "c",
|
||||
RootCert: "d",
|
||||
}
|
||||
|
||||
ok, err := s.CASetProviderState(1, expected)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
|
||||
idx, state, err := s.CAProviderState(expected.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(1))
|
||||
assert.Equal(expected, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Create multiple state entries.
|
||||
before := []*structs.CAConsulProviderState{
|
||||
{
|
||||
ID: "bar",
|
||||
PrivateKey: "y",
|
||||
RootCert: "z",
|
||||
},
|
||||
{
|
||||
ID: "foo",
|
||||
PrivateKey: "a",
|
||||
RootCert: "b",
|
||||
},
|
||||
}
|
||||
|
||||
for i, state := range before {
|
||||
ok, err := s.CASetProviderState(uint64(98+i), state)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
}
|
||||
|
||||
// Take a snapshot.
|
||||
snap := s.Snapshot()
|
||||
defer snap.Close()
|
||||
|
||||
// Modify the state store.
|
||||
after := &structs.CAConsulProviderState{
|
||||
ID: "foo",
|
||||
PrivateKey: "c",
|
||||
RootCert: "d",
|
||||
}
|
||||
ok, err := s.CASetProviderState(100, after)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
|
||||
snapped, err := snap.CAProviderState()
|
||||
assert.NoError(err)
|
||||
assert.Equal(before, snapped)
|
||||
|
||||
// Restore onto a new state store.
|
||||
s2 := testStateStore(t)
|
||||
restore := s2.Restore()
|
||||
for _, entry := range snapped {
|
||||
assert.NoError(restore.CAProviderState(entry))
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
// Verify the restored values match those from before the snapshot.
|
||||
for _, state := range before {
|
||||
idx, res, err := s2.CAProviderState(state.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(99))
|
||||
assert.Equal(state, res)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// IndexConnectService indexes a *struct.ServiceNode for querying by
|
||||
// services that support Connect to some target service. This will
|
||||
// properly index the proxy destination for proxies and the service name
|
||||
// for native services.
|
||||
type IndexConnectService struct{}
|
||||
|
||||
func (idx *IndexConnectService) FromObject(obj interface{}) (bool, []byte, error) {
|
||||
sn, ok := obj.(*structs.ServiceNode)
|
||||
if !ok {
|
||||
return false, nil, fmt.Errorf("Object must be ServiceNode, got %T", obj)
|
||||
}
|
||||
|
||||
var result []byte
|
||||
switch {
|
||||
case sn.ServiceKind == structs.ServiceKindConnectProxy:
|
||||
// For proxies, this service supports Connect for the destination
|
||||
result = []byte(strings.ToLower(sn.ServiceProxyDestination))
|
||||
|
||||
case sn.ServiceConnect.Native:
|
||||
// For native, this service supports Connect directly
|
||||
result = []byte(strings.ToLower(sn.ServiceName))
|
||||
|
||||
default:
|
||||
// Doesn't support Connect at all
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Return the result with the null terminator appended so we can
|
||||
// differentiate prefix vs. non-prefix matches.
|
||||
return true, append(result, '\x00'), nil
|
||||
}
|
||||
|
||||
func (idx *IndexConnectService) FromArgs(args ...interface{}) ([]byte, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, fmt.Errorf("must provide only a single argument")
|
||||
}
|
||||
|
||||
arg, ok := args[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
|
||||
}
|
||||
|
||||
// Add the null character as a terminator
|
||||
return append([]byte(strings.ToLower(arg)), '\x00'), nil
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIndexConnectService_FromObject(t *testing.T) {
|
||||
cases := []struct {
|
||||
Name string
|
||||
Input interface{}
|
||||
ExpectMatch bool
|
||||
ExpectVal []byte
|
||||
ExpectErr string
|
||||
}{
|
||||
{
|
||||
"not a ServiceNode",
|
||||
42,
|
||||
false,
|
||||
nil,
|
||||
"ServiceNode",
|
||||
},
|
||||
|
||||
{
|
||||
"typical service, not native",
|
||||
&structs.ServiceNode{
|
||||
ServiceName: "db",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"typical service, is native",
|
||||
&structs.ServiceNode{
|
||||
ServiceName: "dB",
|
||||
ServiceConnect: structs.ServiceConnect{Native: true},
|
||||
},
|
||||
true,
|
||||
[]byte("db\x00"),
|
||||
"",
|
||||
},
|
||||
|
||||
{
|
||||
"proxy service",
|
||||
&structs.ServiceNode{
|
||||
ServiceKind: structs.ServiceKindConnectProxy,
|
||||
ServiceName: "db",
|
||||
ServiceProxyDestination: "fOo",
|
||||
},
|
||||
true,
|
||||
[]byte("foo\x00"),
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
var idx IndexConnectService
|
||||
match, val, err := idx.FromObject(tc.Input)
|
||||
if tc.ExpectErr != "" {
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), tc.ExpectErr)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
require.Equal(tc.ExpectMatch, match)
|
||||
require.Equal(tc.ExpectVal, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexConnectService_FromArgs(t *testing.T) {
|
||||
cases := []struct {
|
||||
Name string
|
||||
Args []interface{}
|
||||
ExpectVal []byte
|
||||
ExpectErr string
|
||||
}{
|
||||
{
|
||||
"multiple arguments",
|
||||
[]interface{}{"foo", "bar"},
|
||||
nil,
|
||||
"single",
|
||||
},
|
||||
|
||||
{
|
||||
"not a string",
|
||||
[]interface{}{42},
|
||||
nil,
|
||||
"must be a string",
|
||||
},
|
||||
|
||||
{
|
||||
"string",
|
||||
[]interface{}{"fOO"},
|
||||
[]byte("foo\x00"),
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
var idx IndexConnectService
|
||||
val, err := idx.FromArgs(tc.Args...)
|
||||
if tc.ExpectErr != "" {
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), tc.ExpectErr)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
require.Equal(tc.ExpectVal, val)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,366 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
const (
|
||||
intentionsTableName = "connect-intentions"
|
||||
)
|
||||
|
||||
// intentionsTableSchema returns a new table schema used for storing
|
||||
// intentions for Connect.
|
||||
func intentionsTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: intentionsTableName,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"id": &memdb.IndexSchema{
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.UUIDFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
"destination": &memdb.IndexSchema{
|
||||
Name: "destination",
|
||||
AllowMissing: true,
|
||||
// This index is not unique since we need uniqueness across the whole
|
||||
// 4-tuple.
|
||||
Unique: false,
|
||||
Indexer: &memdb.CompoundIndex{
|
||||
Indexes: []memdb.Indexer{
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "DestinationNS",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "DestinationName",
|
||||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"source": &memdb.IndexSchema{
|
||||
Name: "source",
|
||||
AllowMissing: true,
|
||||
// This index is not unique since we need uniqueness across the whole
|
||||
// 4-tuple.
|
||||
Unique: false,
|
||||
Indexer: &memdb.CompoundIndex{
|
||||
Indexes: []memdb.Indexer{
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "SourceNS",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "SourceName",
|
||||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"source_destination": &memdb.IndexSchema{
|
||||
Name: "source_destination",
|
||||
AllowMissing: true,
|
||||
Unique: true,
|
||||
Indexer: &memdb.CompoundIndex{
|
||||
Indexes: []memdb.Indexer{
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "SourceNS",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "SourceName",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "DestinationNS",
|
||||
Lowercase: true,
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "DestinationName",
|
||||
Lowercase: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
registerSchema(intentionsTableSchema)
|
||||
}
|
||||
|
||||
// Intentions is used to pull all the intentions from the snapshot.
|
||||
func (s *Snapshot) Intentions() (structs.Intentions, error) {
|
||||
ixns, err := s.tx.Get(intentionsTableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret structs.Intentions
|
||||
for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() {
|
||||
ret = append(ret, wrapped.(*structs.Intention))
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Intention is used when restoring from a snapshot.
|
||||
func (s *Restore) Intention(ixn *structs.Intention) error {
|
||||
// Insert the intention
|
||||
if err := s.tx.Insert(intentionsTableName, ixn); err != nil {
|
||||
return fmt.Errorf("failed restoring intention: %s", err)
|
||||
}
|
||||
if err := indexUpdateMaxTxn(s.tx, ixn.ModifyIndex, intentionsTableName); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Intentions returns the list of all intentions.
|
||||
func (s *Store) Intentions(ws memdb.WatchSet) (uint64, structs.Intentions, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the index
|
||||
idx := maxIndexTxn(tx, intentionsTableName)
|
||||
if idx < 1 {
|
||||
idx = 1
|
||||
}
|
||||
|
||||
// Get all intentions
|
||||
iter, err := tx.Get(intentionsTableName, "id")
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
var results structs.Intentions
|
||||
for ixn := iter.Next(); ixn != nil; ixn = iter.Next() {
|
||||
results = append(results, ixn.(*structs.Intention))
|
||||
}
|
||||
|
||||
// Sort by precedence just because that's nicer and probably what most clients
|
||||
// want for presentation.
|
||||
sort.Sort(structs.IntentionPrecedenceSorter(results))
|
||||
|
||||
return idx, results, nil
|
||||
}
|
||||
|
||||
// IntentionSet creates or updates an intention.
|
||||
func (s *Store) IntentionSet(idx uint64, ixn *structs.Intention) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.intentionSetTxn(tx, idx, ixn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// intentionSetTxn is the inner method used to insert an intention with
|
||||
// the proper indexes into the state store.
|
||||
func (s *Store) intentionSetTxn(tx *memdb.Txn, idx uint64, ixn *structs.Intention) error {
|
||||
// ID is required
|
||||
if ixn.ID == "" {
|
||||
return ErrMissingIntentionID
|
||||
}
|
||||
|
||||
// Ensure Precedence is populated correctly on "write"
|
||||
ixn.UpdatePrecedence()
|
||||
|
||||
// Check for an existing intention
|
||||
existing, err := tx.First(intentionsTableName, "id", ixn.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
if existing != nil {
|
||||
oldIxn := existing.(*structs.Intention)
|
||||
ixn.CreateIndex = oldIxn.CreateIndex
|
||||
ixn.CreatedAt = oldIxn.CreatedAt
|
||||
} else {
|
||||
ixn.CreateIndex = idx
|
||||
}
|
||||
ixn.ModifyIndex = idx
|
||||
|
||||
// Check for duplicates on the 4-tuple.
|
||||
duplicate, err := tx.First(intentionsTableName, "source_destination",
|
||||
ixn.SourceNS, ixn.SourceName, ixn.DestinationNS, ixn.DestinationName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
if duplicate != nil {
|
||||
dupIxn := duplicate.(*structs.Intention)
|
||||
// Same ID is OK - this is an update
|
||||
if dupIxn.ID != ixn.ID {
|
||||
return fmt.Errorf("duplicate intention found: %s", dupIxn.String())
|
||||
}
|
||||
}
|
||||
|
||||
// We always force meta to be non-nil so that we its an empty map.
|
||||
// This makes it easy for API responses to not nil-check this everywhere.
|
||||
if ixn.Meta == nil {
|
||||
ixn.Meta = make(map[string]string)
|
||||
}
|
||||
|
||||
// Insert
|
||||
if err := tx.Insert(intentionsTableName, ixn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Insert("index", &IndexEntry{intentionsTableName, idx}); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IntentionGet returns the given intention by ID.
|
||||
func (s *Store) IntentionGet(ws memdb.WatchSet, id string) (uint64, *structs.Intention, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexTxn(tx, intentionsTableName)
|
||||
if idx < 1 {
|
||||
idx = 1
|
||||
}
|
||||
|
||||
// Look up by its ID.
|
||||
watchCh, intention, err := tx.FirstWatch(intentionsTableName, "id", id)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
ws.Add(watchCh)
|
||||
|
||||
// Convert the interface{} if it is non-nil
|
||||
var result *structs.Intention
|
||||
if intention != nil {
|
||||
result = intention.(*structs.Intention)
|
||||
}
|
||||
|
||||
return idx, result, nil
|
||||
}
|
||||
|
||||
// IntentionDelete deletes the given intention by ID.
|
||||
func (s *Store) IntentionDelete(idx uint64, id string) error {
|
||||
tx := s.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
|
||||
if err := s.intentionDeleteTxn(tx, idx, id); err != nil {
|
||||
return fmt.Errorf("failed intention delete: %s", err)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// intentionDeleteTxn is the inner method used to delete a intention
|
||||
// with the proper indexes into the state store.
|
||||
func (s *Store) intentionDeleteTxn(tx *memdb.Txn, idx uint64, queryID string) error {
|
||||
// Pull the query.
|
||||
wrapped, err := tx.First(intentionsTableName, "id", queryID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
if wrapped == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete the query and update the index.
|
||||
if err := tx.Delete(intentionsTableName, wrapped); err != nil {
|
||||
return fmt.Errorf("failed intention delete: %s", err)
|
||||
}
|
||||
if err := tx.Insert("index", &IndexEntry{intentionsTableName, idx}); err != nil {
|
||||
return fmt.Errorf("failed updating index: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IntentionMatch returns the list of intentions that match the namespace and
|
||||
// name for either a source or destination. This applies the resolution rules
|
||||
// so wildcards will match any value.
|
||||
//
|
||||
// The returned value is the list of intentions in the same order as the
|
||||
// entries in args. The intentions themselves are sorted based on the
|
||||
// intention precedence rules. i.e. result[0][0] is the highest precedent
|
||||
// rule to match for the first entry.
|
||||
func (s *Store) IntentionMatch(ws memdb.WatchSet, args *structs.IntentionQueryMatch) (uint64, []structs.Intentions, error) {
|
||||
tx := s.db.Txn(false)
|
||||
defer tx.Abort()
|
||||
|
||||
// Get the table index.
|
||||
idx := maxIndexTxn(tx, intentionsTableName)
|
||||
if idx < 1 {
|
||||
idx = 1
|
||||
}
|
||||
|
||||
// Make all the calls and accumulate the results
|
||||
results := make([]structs.Intentions, len(args.Entries))
|
||||
for i, entry := range args.Entries {
|
||||
// Each search entry may require multiple queries to memdb, so this
|
||||
// returns the arguments for each necessary Get. Note on performance:
|
||||
// this is not the most optimal set of queries since we repeat some
|
||||
// many times (such as */*). We can work on improving that in the
|
||||
// future, the test cases shouldn't have to change for that.
|
||||
getParams, err := s.intentionMatchGetParams(entry)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
// Perform each call and accumulate the result.
|
||||
var ixns structs.Intentions
|
||||
for _, params := range getParams {
|
||||
iter, err := tx.Get(intentionsTableName, string(args.Type), params...)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("failed intention lookup: %s", err)
|
||||
}
|
||||
|
||||
ws.Add(iter.WatchCh())
|
||||
|
||||
for ixn := iter.Next(); ixn != nil; ixn = iter.Next() {
|
||||
ixns = append(ixns, ixn.(*structs.Intention))
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the results by precedence
|
||||
sort.Sort(structs.IntentionPrecedenceSorter(ixns))
|
||||
|
||||
// Store the result
|
||||
results[i] = ixns
|
||||
}
|
||||
|
||||
return idx, results, nil
|
||||
}
|
||||
|
||||
// intentionMatchGetParams returns the tx.Get parameters to find all the
|
||||
// intentions for a certain entry.
|
||||
func (s *Store) intentionMatchGetParams(entry structs.IntentionMatchEntry) ([][]interface{}, error) {
|
||||
// We always query for "*/*" so include that. If the namespace is a
|
||||
// wildcard, then we're actually done.
|
||||
result := make([][]interface{}, 0, 3)
|
||||
result = append(result, []interface{}{"*", "*"})
|
||||
if entry.Namespace == structs.IntentionWildcard {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Search for NS/* intentions. If we have a wildcard name, then we're done.
|
||||
result = append(result, []interface{}{entry.Namespace, "*"})
|
||||
if entry.Name == structs.IntentionWildcard {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Search for the exact NS/N value.
|
||||
result = append(result, []interface{}{entry.Namespace, entry.Name})
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,559 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStore_IntentionGet_none(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Querying with no results returns nil.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.IntentionGet(ws, testUUID())
|
||||
assert.Equal(uint64(1), idx)
|
||||
assert.Nil(res)
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
func TestStore_IntentionSetGet_basic(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call Get to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.IntentionGet(ws, testUUID())
|
||||
assert.Nil(err)
|
||||
|
||||
// Build a valid intention
|
||||
ixn := &structs.Intention{
|
||||
ID: testUUID(),
|
||||
SourceNS: "default",
|
||||
SourceName: "*",
|
||||
DestinationNS: "default",
|
||||
DestinationName: "web",
|
||||
Meta: map[string]string{},
|
||||
}
|
||||
|
||||
// Inserting a with empty ID is disallowed.
|
||||
assert.NoError(s.IntentionSet(1, ixn))
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(uint64(1), s.maxIndex(intentionsTableName))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back out and verify it.
|
||||
expected := &structs.Intention{
|
||||
ID: ixn.ID,
|
||||
SourceNS: "default",
|
||||
SourceName: "*",
|
||||
DestinationNS: "default",
|
||||
DestinationName: "web",
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 1,
|
||||
ModifyIndex: 1,
|
||||
},
|
||||
}
|
||||
expected.UpdatePrecedence()
|
||||
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, actual, err := s.IntentionGet(ws, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(expected.CreateIndex, idx)
|
||||
assert.Equal(expected, actual)
|
||||
|
||||
// Change a value and test updating
|
||||
ixn.SourceNS = "foo"
|
||||
assert.NoError(s.IntentionSet(2, ixn))
|
||||
|
||||
// Change a value that isn't in the unique 4 tuple and check we don't
|
||||
// incorrectly consider this a duplicate when updating.
|
||||
ixn.Action = structs.IntentionActionDeny
|
||||
assert.NoError(s.IntentionSet(2, ixn))
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(uint64(2), s.maxIndex(intentionsTableName))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back and verify the data was updated
|
||||
expected.SourceNS = ixn.SourceNS
|
||||
expected.Action = structs.IntentionActionDeny
|
||||
expected.ModifyIndex = 2
|
||||
ws = memdb.NewWatchSet()
|
||||
idx, actual, err = s.IntentionGet(ws, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(expected.ModifyIndex, idx)
|
||||
assert.Equal(expected, actual)
|
||||
|
||||
// Attempt to insert another intention with duplicate 4-tuple
|
||||
ixn = &structs.Intention{
|
||||
ID: testUUID(),
|
||||
SourceNS: "default",
|
||||
SourceName: "*",
|
||||
DestinationNS: "default",
|
||||
DestinationName: "web",
|
||||
Meta: map[string]string{},
|
||||
}
|
||||
|
||||
// Duplicate 4-tuple should cause an error
|
||||
ws = memdb.NewWatchSet()
|
||||
assert.Error(s.IntentionSet(3, ixn))
|
||||
|
||||
// Make sure the index did NOT get updated.
|
||||
assert.Equal(uint64(2), s.maxIndex(intentionsTableName))
|
||||
assert.False(watchFired(ws), "watch not fired")
|
||||
}
|
||||
|
||||
func TestStore_IntentionSet_emptyId(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.IntentionGet(ws, testUUID())
|
||||
assert.NoError(err)
|
||||
|
||||
// Inserting a with empty ID is disallowed.
|
||||
err = s.IntentionSet(1, &structs.Intention{})
|
||||
assert.Error(err)
|
||||
assert.Contains(err.Error(), ErrMissingIntentionID.Error())
|
||||
|
||||
// Index is not updated if nothing is saved.
|
||||
assert.Equal(s.maxIndex(intentionsTableName), uint64(0))
|
||||
assert.False(watchFired(ws), "watch fired")
|
||||
}
|
||||
|
||||
func TestStore_IntentionSet_updateCreatedAt(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Build a valid intention
|
||||
now := time.Now()
|
||||
ixn := structs.Intention{
|
||||
ID: testUUID(),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
// Insert
|
||||
assert.NoError(s.IntentionSet(1, &ixn))
|
||||
|
||||
// Change a value and test updating
|
||||
ixnUpdate := ixn
|
||||
ixnUpdate.CreatedAt = now.Add(10 * time.Second)
|
||||
assert.NoError(s.IntentionSet(2, &ixnUpdate))
|
||||
|
||||
// Read it back and verify
|
||||
_, actual, err := s.IntentionGet(nil, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(now, actual.CreatedAt)
|
||||
}
|
||||
|
||||
func TestStore_IntentionSet_metaNil(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Build a valid intention
|
||||
ixn := structs.Intention{
|
||||
ID: testUUID(),
|
||||
}
|
||||
|
||||
// Insert
|
||||
assert.NoError(s.IntentionSet(1, &ixn))
|
||||
|
||||
// Read it back and verify
|
||||
_, actual, err := s.IntentionGet(nil, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(actual.Meta)
|
||||
}
|
||||
|
||||
func TestStore_IntentionSet_metaSet(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Build a valid intention
|
||||
ixn := structs.Intention{
|
||||
ID: testUUID(),
|
||||
Meta: map[string]string{"foo": "bar"},
|
||||
}
|
||||
|
||||
// Insert
|
||||
assert.NoError(s.IntentionSet(1, &ixn))
|
||||
|
||||
// Read it back and verify
|
||||
_, actual, err := s.IntentionGet(nil, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(ixn.Meta, actual.Meta)
|
||||
}
|
||||
|
||||
func TestStore_IntentionDelete(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Call Get to populate the watch set
|
||||
ws := memdb.NewWatchSet()
|
||||
_, _, err := s.IntentionGet(ws, testUUID())
|
||||
assert.NoError(err)
|
||||
|
||||
// Create
|
||||
ixn := &structs.Intention{ID: testUUID()}
|
||||
assert.NoError(s.IntentionSet(1, ixn))
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(s.maxIndex(intentionsTableName), uint64(1))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Delete
|
||||
assert.NoError(s.IntentionDelete(2, ixn.ID))
|
||||
|
||||
// Make sure the index got updated.
|
||||
assert.Equal(s.maxIndex(intentionsTableName), uint64(2))
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Sanity check to make sure it's not there.
|
||||
idx, actual, err := s.IntentionGet(nil, ixn.ID)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(2))
|
||||
assert.Nil(actual)
|
||||
}
|
||||
|
||||
func TestStore_IntentionsList(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Querying with no results returns nil.
|
||||
ws := memdb.NewWatchSet()
|
||||
idx, res, err := s.Intentions(ws)
|
||||
assert.NoError(err)
|
||||
assert.Nil(res)
|
||||
assert.Equal(uint64(1), idx)
|
||||
|
||||
// Create some intentions
|
||||
ixns := structs.Intentions{
|
||||
&structs.Intention{
|
||||
ID: testUUID(),
|
||||
Meta: map[string]string{},
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: testUUID(),
|
||||
Meta: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
// Force deterministic sort order
|
||||
ixns[0].ID = "a" + ixns[0].ID[1:]
|
||||
ixns[1].ID = "b" + ixns[1].ID[1:]
|
||||
|
||||
// Create
|
||||
for i, ixn := range ixns {
|
||||
assert.NoError(s.IntentionSet(uint64(1+i), ixn))
|
||||
}
|
||||
assert.True(watchFired(ws), "watch fired")
|
||||
|
||||
// Read it back and verify.
|
||||
expected := structs.Intentions{
|
||||
&structs.Intention{
|
||||
ID: ixns[0].ID,
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 1,
|
||||
ModifyIndex: 1,
|
||||
},
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: ixns[1].ID,
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 2,
|
||||
ModifyIndex: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range expected {
|
||||
expected[i].UpdatePrecedence() // to match what is returned...
|
||||
}
|
||||
idx, actual, err := s.Intentions(nil)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(2))
|
||||
assert.Equal(expected, actual)
|
||||
}
|
||||
|
||||
// Test the matrix of match logic.
|
||||
//
|
||||
// Note that this doesn't need to test the intention sort logic exhaustively
|
||||
// since this is tested in their sort implementation in the structs.
|
||||
func TestStore_IntentionMatch_table(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Insert [][]string // List of intentions to insert
|
||||
Query [][]string // List of intentions to match
|
||||
Expected [][][]string // List of matches, where each match is a list of intentions
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
"single exact namespace/name",
|
||||
[][]string{
|
||||
{"foo", "*"},
|
||||
{"foo", "bar"},
|
||||
{"foo", "baz"}, // shouldn't match
|
||||
{"bar", "bar"}, // shouldn't match
|
||||
{"bar", "*"}, // shouldn't match
|
||||
{"*", "*"},
|
||||
},
|
||||
[][]string{
|
||||
{"foo", "bar"},
|
||||
},
|
||||
[][][]string{
|
||||
{
|
||||
{"foo", "bar"},
|
||||
{"foo", "*"},
|
||||
{"*", "*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
"multiple exact namespace/name",
|
||||
[][]string{
|
||||
{"foo", "*"},
|
||||
{"foo", "bar"},
|
||||
{"foo", "baz"}, // shouldn't match
|
||||
{"bar", "bar"},
|
||||
{"bar", "*"},
|
||||
},
|
||||
[][]string{
|
||||
{"foo", "bar"},
|
||||
{"bar", "bar"},
|
||||
},
|
||||
[][][]string{
|
||||
{
|
||||
{"foo", "bar"},
|
||||
{"foo", "*"},
|
||||
},
|
||||
{
|
||||
{"bar", "bar"},
|
||||
{"bar", "*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
"single exact namespace/name with duplicate destinations",
|
||||
[][]string{
|
||||
// 4-tuple specifies src and destination to test duplicate destinations
|
||||
// with different sources. We flip them around to test in both
|
||||
// directions. The first pair are the ones searched on in both cases so
|
||||
// the duplicates need to be there.
|
||||
{"foo", "bar", "foo", "*"},
|
||||
{"foo", "bar", "bar", "*"},
|
||||
{"*", "*", "*", "*"},
|
||||
},
|
||||
[][]string{
|
||||
{"foo", "bar"},
|
||||
},
|
||||
[][][]string{
|
||||
{
|
||||
// Note the first two have the same precedence so we rely on arbitrary
|
||||
// lexicographical tie-break behaviour.
|
||||
{"foo", "bar", "bar", "*"},
|
||||
{"foo", "bar", "foo", "*"},
|
||||
{"*", "*", "*", "*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// testRunner implements the test for a single case, but can be
|
||||
// parameterized to run for both source and destination so we can
|
||||
// test both cases.
|
||||
testRunner := func(t *testing.T, tc testCase, typ structs.IntentionMatchType) {
|
||||
// Insert the set
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
var idx uint64 = 1
|
||||
for _, v := range tc.Insert {
|
||||
ixn := &structs.Intention{ID: testUUID()}
|
||||
switch typ {
|
||||
case structs.IntentionMatchDestination:
|
||||
ixn.DestinationNS = v[0]
|
||||
ixn.DestinationName = v[1]
|
||||
if len(v) == 4 {
|
||||
ixn.SourceNS = v[2]
|
||||
ixn.SourceName = v[3]
|
||||
}
|
||||
case structs.IntentionMatchSource:
|
||||
ixn.SourceNS = v[0]
|
||||
ixn.SourceName = v[1]
|
||||
if len(v) == 4 {
|
||||
ixn.DestinationNS = v[2]
|
||||
ixn.DestinationName = v[3]
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(s.IntentionSet(idx, ixn))
|
||||
|
||||
idx++
|
||||
}
|
||||
|
||||
// Build the arguments
|
||||
args := &structs.IntentionQueryMatch{Type: typ}
|
||||
for _, q := range tc.Query {
|
||||
args.Entries = append(args.Entries, structs.IntentionMatchEntry{
|
||||
Namespace: q[0],
|
||||
Name: q[1],
|
||||
})
|
||||
}
|
||||
|
||||
// Match
|
||||
_, matches, err := s.IntentionMatch(nil, args)
|
||||
assert.NoError(err)
|
||||
|
||||
// Should have equal lengths
|
||||
require.Len(t, matches, len(tc.Expected))
|
||||
|
||||
// Verify matches
|
||||
for i, expected := range tc.Expected {
|
||||
var actual [][]string
|
||||
for _, ixn := range matches[i] {
|
||||
switch typ {
|
||||
case structs.IntentionMatchDestination:
|
||||
if len(expected) > 1 && len(expected[0]) == 4 {
|
||||
actual = append(actual, []string{
|
||||
ixn.DestinationNS,
|
||||
ixn.DestinationName,
|
||||
ixn.SourceNS,
|
||||
ixn.SourceName,
|
||||
})
|
||||
} else {
|
||||
actual = append(actual, []string{ixn.DestinationNS, ixn.DestinationName})
|
||||
}
|
||||
case structs.IntentionMatchSource:
|
||||
if len(expected) > 1 && len(expected[0]) == 4 {
|
||||
actual = append(actual, []string{
|
||||
ixn.SourceNS,
|
||||
ixn.SourceName,
|
||||
ixn.DestinationNS,
|
||||
ixn.DestinationName,
|
||||
})
|
||||
} else {
|
||||
actual = append(actual, []string{ixn.SourceNS, ixn.SourceName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Name+" (destination)", func(t *testing.T) {
|
||||
testRunner(t, tc, structs.IntentionMatchDestination)
|
||||
})
|
||||
|
||||
t.Run(tc.Name+" (source)", func(t *testing.T) {
|
||||
testRunner(t, tc, structs.IntentionMatchSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Intention_Snapshot_Restore(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := testStateStore(t)
|
||||
|
||||
// Create some intentions.
|
||||
ixns := structs.Intentions{
|
||||
&structs.Intention{
|
||||
DestinationName: "foo",
|
||||
},
|
||||
&structs.Intention{
|
||||
DestinationName: "bar",
|
||||
},
|
||||
&structs.Intention{
|
||||
DestinationName: "baz",
|
||||
},
|
||||
}
|
||||
|
||||
// Force the sort order of the UUIDs before we create them so the
|
||||
// order is deterministic.
|
||||
id := testUUID()
|
||||
ixns[0].ID = "a" + id[1:]
|
||||
ixns[1].ID = "b" + id[1:]
|
||||
ixns[2].ID = "c" + id[1:]
|
||||
|
||||
// Now create
|
||||
for i, ixn := range ixns {
|
||||
assert.NoError(s.IntentionSet(uint64(4+i), ixn))
|
||||
}
|
||||
|
||||
// Snapshot the queries.
|
||||
snap := s.Snapshot()
|
||||
defer snap.Close()
|
||||
|
||||
// Alter the real state store.
|
||||
assert.NoError(s.IntentionDelete(7, ixns[0].ID))
|
||||
|
||||
// Verify the snapshot.
|
||||
assert.Equal(snap.LastIndex(), uint64(6))
|
||||
|
||||
// Expect them sorted in insertion order
|
||||
expected := structs.Intentions{
|
||||
&structs.Intention{
|
||||
ID: ixns[0].ID,
|
||||
DestinationName: "foo",
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 4,
|
||||
ModifyIndex: 4,
|
||||
},
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: ixns[1].ID,
|
||||
DestinationName: "bar",
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 5,
|
||||
ModifyIndex: 5,
|
||||
},
|
||||
},
|
||||
&structs.Intention{
|
||||
ID: ixns[2].ID,
|
||||
DestinationName: "baz",
|
||||
Meta: map[string]string{},
|
||||
RaftIndex: structs.RaftIndex{
|
||||
CreateIndex: 6,
|
||||
ModifyIndex: 6,
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range expected {
|
||||
expected[i].UpdatePrecedence() // to match what is returned...
|
||||
}
|
||||
dump, err := snap.Intentions()
|
||||
assert.NoError(err)
|
||||
assert.Equal(expected, dump)
|
||||
|
||||
// Restore the values into a new state store.
|
||||
func() {
|
||||
s := testStateStore(t)
|
||||
restore := s.Restore()
|
||||
for _, ixn := range dump {
|
||||
assert.NoError(restore.Intention(ixn))
|
||||
}
|
||||
restore.Commit()
|
||||
|
||||
// Read the restored values back out and verify that they match. Note that
|
||||
// Intentions are returned precedence sorted unlike the snapshot so we need
|
||||
// to rearrange the expected slice some.
|
||||
expected[0], expected[1], expected[2] = expected[1], expected[2], expected[0]
|
||||
idx, actual, err := s.Intentions(nil)
|
||||
assert.NoError(err)
|
||||
assert.Equal(idx, uint64(6))
|
||||
assert.Equal(expected, actual)
|
||||
}()
|
||||
}
|
|
@ -28,6 +28,14 @@ var (
|
|||
// ErrMissingQueryID is returned when a Query set is called on
|
||||
// a Query with an empty ID.
|
||||
ErrMissingQueryID = errors.New("Missing Query ID")
|
||||
|
||||
// ErrMissingCARootID is returned when an CARoot set is called
|
||||
// with an CARoot with an empty ID.
|
||||
ErrMissingCARootID = errors.New("Missing CA Root ID")
|
||||
|
||||
// ErrMissingIntentionID is returned when an Intention set is called
|
||||
// with an Intention with an empty ID.
|
||||
ErrMissingIntentionID = errors.New("Missing Intention ID")
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
23
agent/dns.go
23
agent/dns.go
|
@ -339,7 +339,7 @@ func (d *DNSServer) addSOA(msg *dns.Msg) {
|
|||
// nameservers returns the names and ip addresses of up to three random servers
|
||||
// in the current cluster which serve as authoritative name servers for zone.
|
||||
func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
|
||||
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "")
|
||||
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false)
|
||||
if err != nil {
|
||||
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
|
||||
return nil, nil
|
||||
|
@ -417,7 +417,7 @@ PARSE:
|
|||
n = n + 1
|
||||
}
|
||||
|
||||
switch labels[n-1] {
|
||||
switch kind := labels[n-1]; kind {
|
||||
case "service":
|
||||
if n == 1 {
|
||||
goto INVALID
|
||||
|
@ -435,7 +435,7 @@ PARSE:
|
|||
}
|
||||
|
||||
// _name._tag.service.consul
|
||||
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, req, resp)
|
||||
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp)
|
||||
|
||||
// Consul 0.3 and prior format for SRV queries
|
||||
} else {
|
||||
|
@ -447,9 +447,17 @@ PARSE:
|
|||
}
|
||||
|
||||
// tag[.tag].name.service.consul
|
||||
d.serviceLookup(network, datacenter, labels[n-2], tag, req, resp)
|
||||
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp)
|
||||
}
|
||||
|
||||
case "connect":
|
||||
if n == 1 {
|
||||
goto INVALID
|
||||
}
|
||||
|
||||
// name.connect.consul
|
||||
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp)
|
||||
|
||||
case "node":
|
||||
if n == 1 {
|
||||
goto INVALID
|
||||
|
@ -913,8 +921,9 @@ func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed
|
|||
}
|
||||
|
||||
// lookupServiceNodes returns nodes with a given service.
|
||||
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs.IndexedCheckServiceNodes, error) {
|
||||
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool) (structs.IndexedCheckServiceNodes, error) {
|
||||
args := structs.ServiceSpecificRequest{
|
||||
Connect: connect,
|
||||
Datacenter: datacenter,
|
||||
ServiceName: service,
|
||||
ServiceTag: tag,
|
||||
|
@ -950,8 +959,8 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs
|
|||
}
|
||||
|
||||
// serviceLookup is used to handle a service query
|
||||
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) {
|
||||
out, err := d.lookupServiceNodes(datacenter, service, tag)
|
||||
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) {
|
||||
out, err := d.lookupServiceNodes(datacenter, service, tag, connect)
|
||||
if err != nil {
|
||||
d.logger.Printf("[ERR] dns: rpc error: %v", err)
|
||||
resp.SetRcode(req, dns.RcodeServerFailure)
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -1125,6 +1126,51 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) {
|
|||
verify.Values(t, "extra", in.Extra, wantExtra)
|
||||
}
|
||||
|
||||
func TestDNS_ConnectServiceLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
{
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Address = "127.0.0.55"
|
||||
args.Service.ProxyDestination = "db"
|
||||
args.Service.Address = ""
|
||||
args.Service.Port = 12345
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
}
|
||||
|
||||
// Look up the service
|
||||
questions := []string{
|
||||
"db.connect.consul.",
|
||||
}
|
||||
for _, question := range questions {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(question, dns.TypeSRV)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, a.DNSAddr())
|
||||
assert.Nil(err)
|
||||
assert.Len(in.Answer, 1)
|
||||
|
||||
srvRec, ok := in.Answer[0].(*dns.SRV)
|
||||
assert.True(ok)
|
||||
assert.Equal(uint16(12345), srvRec.Port)
|
||||
assert.Equal("foo.node.dc1.consul.", srvRec.Target)
|
||||
assert.Equal(uint32(0), srvRec.Hdr.Ttl)
|
||||
|
||||
cnameRec, ok := in.Extra[0].(*dns.A)
|
||||
assert.True(ok)
|
||||
assert.Equal("foo.node.dc1.consul.", cnameRec.Hdr.Name)
|
||||
assert.Equal(uint32(0), srvRec.Hdr.Ttl)
|
||||
assert.Equal("127.0.0.55", cnameRec.A.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_ExternalServiceLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
|
|
|
@ -143,9 +143,17 @@ RETRY_ONCE:
|
|||
return out.HealthChecks, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) HealthConnectServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
return s.healthServiceNodes(resp, req, true)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
return s.healthServiceNodes(resp, req, false)
|
||||
}
|
||||
|
||||
func (s *HTTPServer) healthServiceNodes(resp http.ResponseWriter, req *http.Request, connect bool) (interface{}, error) {
|
||||
// Set default DC
|
||||
args := structs.ServiceSpecificRequest{}
|
||||
args := structs.ServiceSpecificRequest{Connect: connect}
|
||||
s.parseSource(req, &args.Source)
|
||||
args.NodeMetaFilters = s.parseMetaFilter(req)
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
|
@ -159,8 +167,14 @@ func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Requ
|
|||
args.TagFilter = true
|
||||
}
|
||||
|
||||
// Determine the prefix
|
||||
prefix := "/v1/health/service/"
|
||||
if connect {
|
||||
prefix = "/v1/health/connect/"
|
||||
}
|
||||
|
||||
// Pull out the service name
|
||||
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/health/service/")
|
||||
args.ServiceName = strings.TrimPrefix(req.URL.Path, prefix)
|
||||
if args.ServiceName == "" {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(resp, "Missing service name")
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/hashicorp/serf/coordinate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHealthChecksInState(t *testing.T) {
|
||||
|
@ -770,6 +771,105 @@ func TestHealthServiceNodes_WanTranslation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHealthConnectServiceNodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
var out struct{}
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?dc=dc1", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
// Should be a non-nil empty list for checks
|
||||
nodes := obj.(structs.CheckServiceNodes)
|
||||
assert.Len(nodes, 1)
|
||||
assert.Len(nodes[0].Checks, 0)
|
||||
}
|
||||
|
||||
func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register
|
||||
args := structs.TestRegisterRequestProxy(t)
|
||||
args.Check = &structs.HealthCheck{
|
||||
Node: args.Node,
|
||||
Name: "check",
|
||||
ServiceID: args.Service.Service,
|
||||
Status: api.HealthCritical,
|
||||
}
|
||||
var out struct{}
|
||||
assert.Nil(t, a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
t.Run("bc_no_query_value", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?passing", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
// Should be 0 health check for consul
|
||||
nodes := obj.(structs.CheckServiceNodes)
|
||||
assert.Len(nodes, 0)
|
||||
})
|
||||
|
||||
t.Run("passing_true", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?passing=true", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
// Should be 0 health check for consul
|
||||
nodes := obj.(structs.CheckServiceNodes)
|
||||
assert.Len(nodes, 0)
|
||||
})
|
||||
|
||||
t.Run("passing_false", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?passing=false", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Nil(err)
|
||||
assertIndex(t, resp)
|
||||
|
||||
// Should be 1
|
||||
nodes := obj.(structs.CheckServiceNodes)
|
||||
assert.Len(nodes, 1)
|
||||
})
|
||||
|
||||
t.Run("passing_bad", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf(
|
||||
"/v1/health/connect/%s?passing=nope-nope", args.Service.ProxyDestination), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.HealthConnectServiceNodes(resp, req)
|
||||
assert.Equal(400, resp.Code)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.Nil(err)
|
||||
assert.True(bytes.Contains(body, []byte("Invalid value for ?passing")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterNonPassing(t *testing.T) {
|
||||
t.Parallel()
|
||||
nodes := structs.CheckServiceNodes{
|
||||
|
|
|
@ -3,6 +3,7 @@ package agent
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/cache"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
@ -384,6 +386,13 @@ func (s *HTTPServer) Index(resp http.ResponseWriter, req *http.Request) {
|
|||
|
||||
// decodeBody is used to decode a JSON request body
|
||||
func decodeBody(req *http.Request, out interface{}, cb func(interface{}) error) error {
|
||||
// This generally only happens in tests since real HTTP requests set
|
||||
// a non-nil body with no content. We guard against it anyways to prevent
|
||||
// a panic. The EOF response is the same behavior as an empty reader.
|
||||
if req.Body == nil {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
var raw interface{}
|
||||
dec := json.NewDecoder(req.Body)
|
||||
if err := dec.Decode(&raw); err != nil {
|
||||
|
@ -409,6 +418,14 @@ func setTranslateAddr(resp http.ResponseWriter, active bool) {
|
|||
|
||||
// setIndex is used to set the index response header
|
||||
func setIndex(resp http.ResponseWriter, index uint64) {
|
||||
// If we ever return X-Consul-Index of 0 blocking clients will go into a busy
|
||||
// loop and hammer us since ?index=0 will never block. It's always safe to
|
||||
// return index=1 since the very first Raft write is always an internal one
|
||||
// writing the raft config for the cluster so no user-facing blocking query
|
||||
// will ever legitimately have an X-Consul-Index of 1.
|
||||
if index == 0 {
|
||||
index = 1
|
||||
}
|
||||
resp.Header().Set("X-Consul-Index", strconv.FormatUint(index, 10))
|
||||
}
|
||||
|
||||
|
@ -444,6 +461,15 @@ func setMeta(resp http.ResponseWriter, m *structs.QueryMeta) {
|
|||
setConsistency(resp, m.ConsistencyLevel)
|
||||
}
|
||||
|
||||
// setCacheMeta sets http response headers to indicate cache status.
|
||||
func setCacheMeta(resp http.ResponseWriter, m *cache.ResultMeta) {
|
||||
str := "MISS"
|
||||
if m != nil && m.Hit {
|
||||
str = "HIT"
|
||||
}
|
||||
resp.Header().Set("X-Cache", str)
|
||||
}
|
||||
|
||||
// setHeaders is used to set canonical response header fields
|
||||
func setHeaders(resp http.ResponseWriter, headers map[string]string) {
|
||||
for field, value := range headers {
|
||||
|
|
|
@ -29,16 +29,27 @@ func init() {
|
|||
registerEndpoint("/v1/agent/check/warn/", []string{"PUT"}, (*HTTPServer).AgentCheckWarn)
|
||||
registerEndpoint("/v1/agent/check/fail/", []string{"PUT"}, (*HTTPServer).AgentCheckFail)
|
||||
registerEndpoint("/v1/agent/check/update/", []string{"PUT"}, (*HTTPServer).AgentCheckUpdate)
|
||||
registerEndpoint("/v1/agent/connect/authorize", []string{"POST"}, (*HTTPServer).AgentConnectAuthorize)
|
||||
registerEndpoint("/v1/agent/connect/ca/roots", []string{"GET"}, (*HTTPServer).AgentConnectCARoots)
|
||||
registerEndpoint("/v1/agent/connect/ca/leaf/", []string{"GET"}, (*HTTPServer).AgentConnectCALeafCert)
|
||||
registerEndpoint("/v1/agent/connect/proxy/", []string{"GET"}, (*HTTPServer).AgentConnectProxyConfig)
|
||||
registerEndpoint("/v1/agent/service/register", []string{"PUT"}, (*HTTPServer).AgentRegisterService)
|
||||
registerEndpoint("/v1/agent/service/deregister/", []string{"PUT"}, (*HTTPServer).AgentDeregisterService)
|
||||
registerEndpoint("/v1/agent/service/maintenance/", []string{"PUT"}, (*HTTPServer).AgentServiceMaintenance)
|
||||
registerEndpoint("/v1/catalog/register", []string{"PUT"}, (*HTTPServer).CatalogRegister)
|
||||
registerEndpoint("/v1/catalog/connect/", []string{"GET"}, (*HTTPServer).CatalogConnectServiceNodes)
|
||||
registerEndpoint("/v1/catalog/deregister", []string{"PUT"}, (*HTTPServer).CatalogDeregister)
|
||||
registerEndpoint("/v1/catalog/datacenters", []string{"GET"}, (*HTTPServer).CatalogDatacenters)
|
||||
registerEndpoint("/v1/catalog/nodes", []string{"GET"}, (*HTTPServer).CatalogNodes)
|
||||
registerEndpoint("/v1/catalog/services", []string{"GET"}, (*HTTPServer).CatalogServices)
|
||||
registerEndpoint("/v1/catalog/service/", []string{"GET"}, (*HTTPServer).CatalogServiceNodes)
|
||||
registerEndpoint("/v1/catalog/node/", []string{"GET"}, (*HTTPServer).CatalogNodeServices)
|
||||
registerEndpoint("/v1/connect/ca/configuration", []string{"GET", "PUT"}, (*HTTPServer).ConnectCAConfiguration)
|
||||
registerEndpoint("/v1/connect/ca/roots", []string{"GET"}, (*HTTPServer).ConnectCARoots)
|
||||
registerEndpoint("/v1/connect/intentions", []string{"GET", "POST"}, (*HTTPServer).IntentionEndpoint)
|
||||
registerEndpoint("/v1/connect/intentions/match", []string{"GET"}, (*HTTPServer).IntentionMatch)
|
||||
registerEndpoint("/v1/connect/intentions/check", []string{"GET"}, (*HTTPServer).IntentionCheck)
|
||||
registerEndpoint("/v1/connect/intentions/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).IntentionSpecific)
|
||||
registerEndpoint("/v1/coordinate/datacenters", []string{"GET"}, (*HTTPServer).CoordinateDatacenters)
|
||||
registerEndpoint("/v1/coordinate/nodes", []string{"GET"}, (*HTTPServer).CoordinateNodes)
|
||||
registerEndpoint("/v1/coordinate/node/", []string{"GET"}, (*HTTPServer).CoordinateNode)
|
||||
|
@ -49,6 +60,7 @@ func init() {
|
|||
registerEndpoint("/v1/health/checks/", []string{"GET"}, (*HTTPServer).HealthServiceChecks)
|
||||
registerEndpoint("/v1/health/state/", []string{"GET"}, (*HTTPServer).HealthChecksInState)
|
||||
registerEndpoint("/v1/health/service/", []string{"GET"}, (*HTTPServer).HealthServiceNodes)
|
||||
registerEndpoint("/v1/health/connect/", []string{"GET"}, (*HTTPServer).HealthConnectServiceNodes)
|
||||
registerEndpoint("/v1/internal/ui/nodes", []string{"GET"}, (*HTTPServer).UINodes)
|
||||
registerEndpoint("/v1/internal/ui/node/", []string{"GET"}, (*HTTPServer).UINodeInfo)
|
||||
registerEndpoint("/v1/internal/ui/services", []string{"GET"}, (*HTTPServer).UIServices)
|
||||
|
|
|
@ -0,0 +1,302 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
// /v1/connection/intentions
|
||||
func (s *HTTPServer) IntentionEndpoint(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
return s.IntentionList(resp, req)
|
||||
|
||||
case "POST":
|
||||
return s.IntentionCreate(resp, req)
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "POST"}}
|
||||
}
|
||||
}
|
||||
|
||||
// GET /v1/connect/intentions
|
||||
func (s *HTTPServer) IntentionList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
var args structs.DCSpecificRequest
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply structs.IndexedIntentions
|
||||
if err := s.agent.RPC("Intention.List", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reply.Intentions, nil
|
||||
}
|
||||
|
||||
// POST /v1/connect/intentions
|
||||
func (s *HTTPServer) IntentionCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
args := structs.IntentionRequest{
|
||||
Op: structs.IntentionOpCreate,
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
if err := decodeBody(req, &args.Intention, nil); err != nil {
|
||||
return nil, fmt.Errorf("Failed to decode request body: %s", err)
|
||||
}
|
||||
|
||||
var reply string
|
||||
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return intentionCreateResponse{reply}, nil
|
||||
}
|
||||
|
||||
// GET /v1/connect/intentions/match
|
||||
func (s *HTTPServer) IntentionMatch(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Prepare args
|
||||
args := &structs.IntentionQueryRequest{Match: &structs.IntentionQueryMatch{}}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
|
||||
// Extract the "by" query parameter
|
||||
if by, ok := q["by"]; !ok || len(by) != 1 {
|
||||
return nil, fmt.Errorf("required query parameter 'by' not set")
|
||||
} else {
|
||||
switch v := structs.IntentionMatchType(by[0]); v {
|
||||
case structs.IntentionMatchSource, structs.IntentionMatchDestination:
|
||||
args.Match.Type = v
|
||||
default:
|
||||
return nil, fmt.Errorf("'by' parameter must be one of 'source' or 'destination'")
|
||||
}
|
||||
}
|
||||
|
||||
// Extract all the match names
|
||||
names, ok := q["name"]
|
||||
if !ok || len(names) == 0 {
|
||||
return nil, fmt.Errorf("required query parameter 'name' not set")
|
||||
}
|
||||
|
||||
// Build the entries in order. The order matters since that is the
|
||||
// order of the returned responses.
|
||||
args.Match.Entries = make([]structs.IntentionMatchEntry, len(names))
|
||||
for i, n := range names {
|
||||
entry, err := parseIntentionMatchEntry(n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("name %q is invalid: %s", n, err)
|
||||
}
|
||||
|
||||
args.Match.Entries[i] = entry
|
||||
}
|
||||
|
||||
var reply structs.IndexedIntentionMatches
|
||||
if err := s.agent.RPC("Intention.Match", args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We must have an identical count of matches
|
||||
if len(reply.Matches) != len(names) {
|
||||
return nil, fmt.Errorf("internal error: match response count didn't match input count")
|
||||
}
|
||||
|
||||
// Use empty list instead of nil.
|
||||
response := make(map[string]structs.Intentions)
|
||||
for i, ixns := range reply.Matches {
|
||||
response[names[i]] = ixns
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GET /v1/connect/intentions/check
|
||||
func (s *HTTPServer) IntentionCheck(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Prepare args
|
||||
args := &structs.IntentionQueryRequest{Check: &structs.IntentionQueryCheck{}}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
|
||||
// Set the source type if set
|
||||
args.Check.SourceType = structs.IntentionSourceConsul
|
||||
if sourceType, ok := q["source-type"]; ok && len(sourceType) > 0 {
|
||||
args.Check.SourceType = structs.IntentionSourceType(sourceType[0])
|
||||
}
|
||||
|
||||
// Extract the source/destination
|
||||
source, ok := q["source"]
|
||||
if !ok || len(source) != 1 {
|
||||
return nil, fmt.Errorf("required query parameter 'source' not set")
|
||||
}
|
||||
destination, ok := q["destination"]
|
||||
if !ok || len(destination) != 1 {
|
||||
return nil, fmt.Errorf("required query parameter 'destination' not set")
|
||||
}
|
||||
|
||||
// We parse them the same way as matches to extract namespace/name
|
||||
args.Check.SourceName = source[0]
|
||||
if args.Check.SourceType == structs.IntentionSourceConsul {
|
||||
entry, err := parseIntentionMatchEntry(source[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source %q is invalid: %s", source[0], err)
|
||||
}
|
||||
args.Check.SourceNS = entry.Namespace
|
||||
args.Check.SourceName = entry.Name
|
||||
}
|
||||
|
||||
// The destination is always in the Consul format
|
||||
entry, err := parseIntentionMatchEntry(destination[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("destination %q is invalid: %s", destination[0], err)
|
||||
}
|
||||
args.Check.DestinationNS = entry.Namespace
|
||||
args.Check.DestinationName = entry.Name
|
||||
|
||||
var reply structs.IntentionQueryCheckResponse
|
||||
if err := s.agent.RPC("Intention.Check", args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &reply, nil
|
||||
}
|
||||
|
||||
// IntentionSpecific handles the endpoint for /v1/connection/intentions/:id
|
||||
func (s *HTTPServer) IntentionSpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
id := strings.TrimPrefix(req.URL.Path, "/v1/connect/intentions/")
|
||||
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
return s.IntentionSpecificGet(id, resp, req)
|
||||
|
||||
case "PUT":
|
||||
return s.IntentionSpecificUpdate(id, resp, req)
|
||||
|
||||
case "DELETE":
|
||||
return s.IntentionSpecificDelete(id, resp, req)
|
||||
|
||||
default:
|
||||
return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}}
|
||||
}
|
||||
}
|
||||
|
||||
// GET /v1/connect/intentions/:id
|
||||
func (s *HTTPServer) IntentionSpecificGet(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
args := structs.IntentionQueryRequest{
|
||||
IntentionID: id,
|
||||
}
|
||||
if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var reply structs.IndexedIntentions
|
||||
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
|
||||
// We have to check the string since the RPC sheds the error type
|
||||
if err.Error() == consul.ErrIntentionNotFound.Error() {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprint(resp, err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// This shouldn't happen since the called API documents it shouldn't,
|
||||
// but we check since the alternative if it happens is a panic.
|
||||
if len(reply.Intentions) == 0 {
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return reply.Intentions[0], nil
|
||||
}
|
||||
|
||||
// PUT /v1/connect/intentions/:id
|
||||
func (s *HTTPServer) IntentionSpecificUpdate(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
args := structs.IntentionRequest{
|
||||
Op: structs.IntentionOpUpdate,
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
if err := decodeBody(req, &args.Intention, nil); err != nil {
|
||||
resp.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(resp, "Request decode failed: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Use the ID from the URL
|
||||
args.Intention.ID = id
|
||||
|
||||
var reply string
|
||||
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update uses the same create response
|
||||
return intentionCreateResponse{reply}, nil
|
||||
|
||||
}
|
||||
|
||||
// DELETE /v1/connect/intentions/:id
|
||||
func (s *HTTPServer) IntentionSpecificDelete(id string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
// Method is tested in IntentionEndpoint
|
||||
|
||||
args := structs.IntentionRequest{
|
||||
Op: structs.IntentionOpDelete,
|
||||
Intention: &structs.Intention{ID: id},
|
||||
}
|
||||
s.parseDC(req, &args.Datacenter)
|
||||
s.parseToken(req, &args.Token)
|
||||
|
||||
var reply string
|
||||
if err := s.agent.RPC("Intention.Apply", &args, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// intentionCreateResponse is the response structure for creating an intention.
|
||||
type intentionCreateResponse struct{ ID string }
|
||||
|
||||
// parseIntentionMatchEntry parses the query parameter for an intention
|
||||
// match query entry.
|
||||
func parseIntentionMatchEntry(input string) (structs.IntentionMatchEntry, error) {
|
||||
var result structs.IntentionMatchEntry
|
||||
result.Namespace = structs.IntentionDefaultNamespace
|
||||
|
||||
// TODO(mitchellh): when namespaces are introduced, set the default
|
||||
// namespace to be the namespace of the requestor.
|
||||
|
||||
// Get the index to the '/'. If it doesn't exist, we have just a name
|
||||
// so just set that and return.
|
||||
idx := strings.IndexByte(input, '/')
|
||||
if idx == -1 {
|
||||
result.Name = input
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result.Namespace = input[:idx]
|
||||
result.Name = input[idx+1:]
|
||||
if strings.IndexByte(result.Name, '/') != -1 {
|
||||
return result, fmt.Errorf("input can contain at most one '/'")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,502 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntentionsList_empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Make sure an empty list is non-nil.
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/intentions", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionList(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(structs.Intentions)
|
||||
assert.NotNil(value)
|
||||
assert.Len(value, 0)
|
||||
}
|
||||
|
||||
func TestIntentionsList_values(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Create some intentions, note we create the lowest precedence first to test
|
||||
// sorting.
|
||||
for _, v := range []string{"*", "foo", "bar"} {
|
||||
req := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: structs.TestIntention(t),
|
||||
}
|
||||
req.Intention.SourceName = v
|
||||
|
||||
var reply string
|
||||
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
|
||||
}
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET", "/v1/connect/intentions", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionList(resp, req)
|
||||
assert.NoError(err)
|
||||
|
||||
value := obj.(structs.Intentions)
|
||||
assert.Len(value, 3)
|
||||
|
||||
expected := []string{"bar", "foo", "*"}
|
||||
actual := []string{
|
||||
value[0].SourceName,
|
||||
value[1].SourceName,
|
||||
value[2].SourceName,
|
||||
}
|
||||
assert.Equal(expected, actual)
|
||||
}
|
||||
|
||||
func TestIntentionsMatch_basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Create some intentions
|
||||
{
|
||||
insert := [][]string{
|
||||
{"foo", "*", "foo", "*"},
|
||||
{"foo", "*", "foo", "bar"},
|
||||
{"foo", "*", "foo", "baz"}, // shouldn't match
|
||||
{"foo", "*", "bar", "bar"}, // shouldn't match
|
||||
{"foo", "*", "bar", "*"}, // shouldn't match
|
||||
{"foo", "*", "*", "*"},
|
||||
{"bar", "*", "foo", "bar"}, // duplicate destination different source
|
||||
}
|
||||
|
||||
for _, v := range insert {
|
||||
ixn := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: structs.TestIntention(t),
|
||||
}
|
||||
ixn.Intention.SourceNS = v[0]
|
||||
ixn.Intention.SourceName = v[1]
|
||||
ixn.Intention.DestinationNS = v[2]
|
||||
ixn.Intention.DestinationName = v[3]
|
||||
|
||||
// Create
|
||||
var reply string
|
||||
assert.Nil(a.RPC("Intention.Apply", &ixn, &reply))
|
||||
}
|
||||
}
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/match?by=destination&name=foo/bar", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionMatch(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(map[string]structs.Intentions)
|
||||
assert.Len(value, 1)
|
||||
|
||||
var actual [][]string
|
||||
expected := [][]string{
|
||||
{"bar", "*", "foo", "bar"},
|
||||
{"foo", "*", "foo", "bar"},
|
||||
{"foo", "*", "foo", "*"},
|
||||
{"foo", "*", "*", "*"},
|
||||
}
|
||||
for _, ixn := range value["foo/bar"] {
|
||||
actual = append(actual, []string{
|
||||
ixn.SourceNS,
|
||||
ixn.SourceName,
|
||||
ixn.DestinationNS,
|
||||
ixn.DestinationName,
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(expected, actual)
|
||||
}
|
||||
|
||||
func TestIntentionsMatch_noBy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/match?name=foo/bar", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionMatch(resp, req)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "by")
|
||||
assert.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsMatch_byInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/match?by=datacenter", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionMatch(resp, req)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "'by' parameter")
|
||||
assert.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsMatch_noName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/match?by=source", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionMatch(resp, req)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "'name' not set")
|
||||
assert.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsCheck_basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Create some intentions
|
||||
{
|
||||
insert := [][]string{
|
||||
{"foo", "*", "foo", "*"},
|
||||
{"foo", "*", "foo", "bar"},
|
||||
{"bar", "*", "foo", "bar"},
|
||||
}
|
||||
|
||||
for _, v := range insert {
|
||||
ixn := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: structs.TestIntention(t),
|
||||
}
|
||||
ixn.Intention.SourceNS = v[0]
|
||||
ixn.Intention.SourceName = v[1]
|
||||
ixn.Intention.DestinationNS = v[2]
|
||||
ixn.Intention.DestinationName = v[3]
|
||||
ixn.Intention.Action = structs.IntentionActionDeny
|
||||
|
||||
// Create
|
||||
var reply string
|
||||
require.Nil(a.RPC("Intention.Apply", &ixn, &reply))
|
||||
}
|
||||
}
|
||||
|
||||
// Request matching intention
|
||||
{
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/test?source=foo/bar&destination=foo/baz", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCheck(resp, req)
|
||||
require.Nil(err)
|
||||
value := obj.(*structs.IntentionQueryCheckResponse)
|
||||
require.False(value.Allowed)
|
||||
}
|
||||
|
||||
// Request non-matching intention
|
||||
{
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/test?source=foo/bar&destination=bar/qux", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCheck(resp, req)
|
||||
require.Nil(err)
|
||||
value := obj.(*structs.IntentionQueryCheckResponse)
|
||||
require.True(value.Allowed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntentionsCheck_noSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/test?destination=B", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCheck(resp, req)
|
||||
require.NotNil(err)
|
||||
require.Contains(err.Error(), "'source' not set")
|
||||
require.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsCheck_noDestination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Request
|
||||
req, _ := http.NewRequest("GET",
|
||||
"/v1/connect/intentions/test?source=B", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCheck(resp, req)
|
||||
require.NotNil(err)
|
||||
require.Contains(err.Error(), "'destination' not set")
|
||||
require.Nil(obj)
|
||||
}
|
||||
|
||||
func TestIntentionsCreate_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Make sure an empty list is non-nil.
|
||||
args := structs.TestIntention(t)
|
||||
args.SourceName = "foo"
|
||||
req, _ := http.NewRequest("POST", "/v1/connect/intentions", jsonReader(args))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionCreate(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(intentionCreateResponse)
|
||||
assert.NotEqual("", value.ID)
|
||||
|
||||
// Read the value
|
||||
{
|
||||
req := &structs.IntentionQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
IntentionID: value.ID,
|
||||
}
|
||||
var resp structs.IndexedIntentions
|
||||
assert.Nil(a.RPC("Intention.Get", req, &resp))
|
||||
assert.Len(resp.Intentions, 1)
|
||||
actual := resp.Intentions[0]
|
||||
assert.Equal("foo", actual.SourceName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntentionsCreate_noBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// Create with no body
|
||||
req, _ := http.NewRequest("POST", "/v1/connect/intentions", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.IntentionCreate(resp, req)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIntentionsSpecificGet_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// The intention
|
||||
ixn := structs.TestIntention(t)
|
||||
|
||||
// Create an intention directly
|
||||
var reply string
|
||||
{
|
||||
req := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: ixn,
|
||||
}
|
||||
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
|
||||
}
|
||||
|
||||
// Get the value
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf("/v1/connect/intentions/%s", reply), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionSpecific(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(*structs.Intention)
|
||||
assert.Equal(reply, value.ID)
|
||||
|
||||
ixn.ID = value.ID
|
||||
ixn.RaftIndex = value.RaftIndex
|
||||
ixn.CreatedAt, ixn.UpdatedAt = value.CreatedAt, value.UpdatedAt
|
||||
assert.Equal(ixn, value)
|
||||
}
|
||||
|
||||
func TestIntentionsSpecificUpdate_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// The intention
|
||||
ixn := structs.TestIntention(t)
|
||||
|
||||
// Create an intention directly
|
||||
var reply string
|
||||
{
|
||||
req := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: ixn,
|
||||
}
|
||||
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
|
||||
}
|
||||
|
||||
// Update the intention
|
||||
ixn.ID = "bogus"
|
||||
ixn.SourceName = "bar"
|
||||
req, _ := http.NewRequest("PUT", fmt.Sprintf("/v1/connect/intentions/%s", reply), jsonReader(ixn))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionSpecific(resp, req)
|
||||
assert.Nil(err)
|
||||
|
||||
value := obj.(intentionCreateResponse)
|
||||
assert.Equal(reply, value.ID)
|
||||
|
||||
// Read the value
|
||||
{
|
||||
req := &structs.IntentionQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
IntentionID: reply,
|
||||
}
|
||||
var resp structs.IndexedIntentions
|
||||
assert.Nil(a.RPC("Intention.Get", req, &resp))
|
||||
assert.Len(resp.Intentions, 1)
|
||||
actual := resp.Intentions[0]
|
||||
assert.Equal("bar", actual.SourceName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntentionsSpecificDelete_good(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := NewTestAgent(t.Name(), "")
|
||||
defer a.Shutdown()
|
||||
|
||||
// The intention
|
||||
ixn := structs.TestIntention(t)
|
||||
ixn.SourceName = "foo"
|
||||
|
||||
// Create an intention directly
|
||||
var reply string
|
||||
{
|
||||
req := structs.IntentionRequest{
|
||||
Datacenter: "dc1",
|
||||
Op: structs.IntentionOpCreate,
|
||||
Intention: ixn,
|
||||
}
|
||||
assert.Nil(a.RPC("Intention.Apply", &req, &reply))
|
||||
}
|
||||
|
||||
// Sanity check that the intention exists
|
||||
{
|
||||
req := &structs.IntentionQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
IntentionID: reply,
|
||||
}
|
||||
var resp structs.IndexedIntentions
|
||||
assert.Nil(a.RPC("Intention.Get", req, &resp))
|
||||
assert.Len(resp.Intentions, 1)
|
||||
actual := resp.Intentions[0]
|
||||
assert.Equal("foo", actual.SourceName)
|
||||
}
|
||||
|
||||
// Delete the intention
|
||||
req, _ := http.NewRequest("DELETE", fmt.Sprintf("/v1/connect/intentions/%s", reply), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.IntentionSpecific(resp, req)
|
||||
assert.Nil(err)
|
||||
assert.Equal(true, obj)
|
||||
|
||||
// Verify the intention is gone
|
||||
{
|
||||
req := &structs.IntentionQueryRequest{
|
||||
Datacenter: "dc1",
|
||||
IntentionID: reply,
|
||||
}
|
||||
var resp structs.IndexedIntentions
|
||||
err := a.RPC("Intention.Get", req, &resp)
|
||||
assert.NotNil(err)
|
||||
assert.Contains(err.Error(), "not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIntentionMatchEntry(t *testing.T) {
|
||||
cases := []struct {
|
||||
Input string
|
||||
Expected structs.IntentionMatchEntry
|
||||
Err bool
|
||||
}{
|
||||
{
|
||||
"foo",
|
||||
structs.IntentionMatchEntry{
|
||||
Namespace: structs.IntentionDefaultNamespace,
|
||||
Name: "foo",
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"foo/bar",
|
||||
structs.IntentionMatchEntry{
|
||||
Namespace: "foo",
|
||||
Name: "bar",
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"foo/bar/baz",
|
||||
structs.IntentionMatchEntry{},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.Input, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
actual, err := parseIntentionMatchEntry(tc.Input)
|
||||
assert.Equal(err != nil, tc.Err, err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tc.Expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ package local
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -10,6 +11,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
|
@ -27,6 +30,8 @@ type Config struct {
|
|||
NodeID types.NodeID
|
||||
NodeName string
|
||||
TaggedAddresses map[string]string
|
||||
ProxyBindMinPort int
|
||||
ProxyBindMaxPort int
|
||||
}
|
||||
|
||||
// ServiceState describes the state of a service record.
|
||||
|
@ -107,6 +112,32 @@ type rpc interface {
|
|||
RPC(method string, args interface{}, reply interface{}) error
|
||||
}
|
||||
|
||||
// ManagedProxy represents the local state for a registered proxy instance.
|
||||
type ManagedProxy struct {
|
||||
Proxy *structs.ConnectManagedProxy
|
||||
|
||||
// ProxyToken is a special local-only security token that grants the bearer
|
||||
// access to the proxy's config as well as allowing it to request certificates
|
||||
// on behalf of the target service. Certain connect endpoints will validate
|
||||
// against this token and if it matches will then use the target service's
|
||||
// registration token to actually authenticate the upstream RPC on behalf of
|
||||
// the service. This token is passed securely to the proxy process via ENV
|
||||
// vars and should never be exposed any other way. Unmanaged proxies will
|
||||
// never see this and need to use service-scoped ACL tokens distributed
|
||||
// externally. It is persisted in the local state to allow authenticating
|
||||
// running proxies after the agent restarts.
|
||||
//
|
||||
// TODO(banks): In theory we only need to persist this at all to _validate_
|
||||
// which means we could keep only a hash in memory and on disk and only pass
|
||||
// the actual token to the process on startup. That would require a bit of
|
||||
// refactoring though to have the required interaction with the proxy manager.
|
||||
ProxyToken string
|
||||
|
||||
// WatchCh is a close-only chan that is closed when the proxy is removed or
|
||||
// updated.
|
||||
WatchCh chan struct{}
|
||||
}
|
||||
|
||||
// State is used to represent the node's services,
|
||||
// and checks. We use it to perform anti-entropy with the
|
||||
// catalog representation
|
||||
|
@ -150,9 +181,23 @@ type State struct {
|
|||
|
||||
// tokens contains the ACL tokens
|
||||
tokens *token.Store
|
||||
|
||||
// managedProxies is a map of all managed connect proxies registered locally on
|
||||
// this agent. This is NOT kept in sync with servers since it's agent-local
|
||||
// config only. Proxy instances have separate service registrations in the
|
||||
// services map above which are kept in sync via anti-entropy. Un-managed
|
||||
// proxies (that registered themselves separately from the service
|
||||
// registration) do not appear here as the agent doesn't need to manage their
|
||||
// process nor config. The _do_ still exist in services above though as
|
||||
// services with Kind == connect-proxy.
|
||||
//
|
||||
// managedProxyHandlers is a map of registered channel listeners that
|
||||
// are sent a message each time a proxy changes via Add or RemoveProxy.
|
||||
managedProxies map[string]*ManagedProxy
|
||||
managedProxyHandlers map[chan<- struct{}]struct{}
|
||||
}
|
||||
|
||||
// NewLocalState creates a new local state for the agent.
|
||||
// NewState creates a new local state for the agent.
|
||||
func NewState(c Config, lg *log.Logger, tokens *token.Store) *State {
|
||||
l := &State{
|
||||
config: c,
|
||||
|
@ -161,6 +206,8 @@ func NewState(c Config, lg *log.Logger, tokens *token.Store) *State {
|
|||
checks: make(map[types.CheckID]*CheckState),
|
||||
metadata: make(map[string]string),
|
||||
tokens: tokens,
|
||||
managedProxies: make(map[string]*ManagedProxy),
|
||||
managedProxyHandlers: make(map[chan<- struct{}]struct{}),
|
||||
}
|
||||
l.SetDiscardCheckOutput(c.DiscardCheckOutput)
|
||||
return l
|
||||
|
@ -529,6 +576,204 @@ func (l *State) CriticalCheckStates() map[types.CheckID]*CheckState {
|
|||
return m
|
||||
}
|
||||
|
||||
// AddProxy is used to add a connect proxy entry to the local state. This
|
||||
// assumes the proxy's NodeService is already registered via Agent.AddService
|
||||
// (since that has to do other book keeping). The token passed here is the ACL
|
||||
// token the service used to register itself so must have write on service
|
||||
// record. AddProxy returns the newly added proxy and an error.
|
||||
//
|
||||
// The restoredProxyToken argument should only be used when restoring proxy
|
||||
// definitions from disk; new proxies must leave it blank to get a new token
|
||||
// assigned. We need to restore from disk to enable to continue authenticating
|
||||
// running proxies that already had that credential injected.
|
||||
func (l *State) AddProxy(proxy *structs.ConnectManagedProxy, token,
|
||||
restoredProxyToken string) (*ManagedProxy, error) {
|
||||
if proxy == nil {
|
||||
return nil, fmt.Errorf("no proxy")
|
||||
}
|
||||
|
||||
// Lookup the local service
|
||||
target := l.Service(proxy.TargetServiceID)
|
||||
if target == nil {
|
||||
return nil, fmt.Errorf("target service ID %s not registered",
|
||||
proxy.TargetServiceID)
|
||||
}
|
||||
|
||||
// Get bind info from config
|
||||
cfg, err := proxy.ParseConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Construct almost all of the NodeService that needs to be registered by the
|
||||
// caller outside of the lock.
|
||||
svc := &structs.NodeService{
|
||||
Kind: structs.ServiceKindConnectProxy,
|
||||
ID: target.ID + "-proxy",
|
||||
Service: target.ID + "-proxy",
|
||||
ProxyDestination: target.Service,
|
||||
Address: cfg.BindAddress,
|
||||
Port: cfg.BindPort,
|
||||
}
|
||||
|
||||
// Lock now. We can't lock earlier as l.Service would deadlock and shouldn't
|
||||
// anyway to minimise the critical section.
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
pToken := restoredProxyToken
|
||||
|
||||
// Does this proxy instance allready exist?
|
||||
if existing, ok := l.managedProxies[svc.ID]; ok {
|
||||
// Keep the existing proxy token so we don't have to restart proxy to
|
||||
// re-inject token.
|
||||
pToken = existing.ProxyToken
|
||||
// If the user didn't explicitly change the port, use the old one instead of
|
||||
// assigning new.
|
||||
if svc.Port < 1 {
|
||||
svc.Port = existing.Proxy.ProxyService.Port
|
||||
}
|
||||
} else if proxyService, ok := l.services[svc.ID]; ok {
|
||||
// The proxy-service already exists so keep the port that got assigned. This
|
||||
// happens on reload from disk since service definitions are reloaded first.
|
||||
svc.Port = proxyService.Service.Port
|
||||
}
|
||||
|
||||
// If this is a new instance, generate a token
|
||||
if pToken == "" {
|
||||
pToken, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate port if needed (min and max inclusive).
|
||||
rangeLen := l.config.ProxyBindMaxPort - l.config.ProxyBindMinPort + 1
|
||||
if svc.Port < 1 && l.config.ProxyBindMinPort > 0 && rangeLen > 0 {
|
||||
// This should be a really short list so don't bother optimising lookup yet.
|
||||
OUTER:
|
||||
for _, offset := range rand.Perm(rangeLen) {
|
||||
p := l.config.ProxyBindMinPort + offset
|
||||
// See if this port was already allocated to another proxy
|
||||
for _, other := range l.managedProxies {
|
||||
if other.Proxy.ProxyService.Port == p {
|
||||
// allready taken, skip to next random pick in the range
|
||||
continue OUTER
|
||||
}
|
||||
}
|
||||
// We made it through all existing proxies without a match so claim this one
|
||||
svc.Port = p
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no ports left (or auto ports disabled) fail
|
||||
if svc.Port < 1 {
|
||||
return nil, fmt.Errorf("no port provided for proxy bind_port and none "+
|
||||
" left in the allocated range [%d, %d]", l.config.ProxyBindMinPort,
|
||||
l.config.ProxyBindMaxPort)
|
||||
}
|
||||
|
||||
proxy.ProxyService = svc
|
||||
|
||||
// All set, add the proxy and return the service
|
||||
if old, ok := l.managedProxies[svc.ID]; ok {
|
||||
// Notify watchers of the existing proxy config that it's changing. Note
|
||||
// this is safe here even before the map is updated since we still hold the
|
||||
// state lock and the watcher can't re-read the new config until we return
|
||||
// anyway.
|
||||
close(old.WatchCh)
|
||||
}
|
||||
l.managedProxies[svc.ID] = &ManagedProxy{
|
||||
Proxy: proxy,
|
||||
ProxyToken: pToken,
|
||||
WatchCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Notify
|
||||
for ch := range l.managedProxyHandlers {
|
||||
// Do not block
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// No need to trigger sync as proxy state is local only.
|
||||
return l.managedProxies[svc.ID], nil
|
||||
}
|
||||
|
||||
// RemoveProxy is used to remove a proxy entry from the local state.
|
||||
// This returns the proxy that was removed.
|
||||
func (l *State) RemoveProxy(id string) (*ManagedProxy, error) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
p := l.managedProxies[id]
|
||||
if p == nil {
|
||||
return nil, fmt.Errorf("Proxy %s does not exist", id)
|
||||
}
|
||||
delete(l.managedProxies, id)
|
||||
|
||||
// Notify watchers of the existing proxy config that it's changed.
|
||||
close(p.WatchCh)
|
||||
|
||||
// Notify
|
||||
for ch := range l.managedProxyHandlers {
|
||||
// Do not block
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// No need to trigger sync as proxy state is local only.
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Proxy returns the local proxy state.
|
||||
func (l *State) Proxy(id string) *ManagedProxy {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
return l.managedProxies[id]
|
||||
}
|
||||
|
||||
// Proxies returns the locally registered proxies.
|
||||
func (l *State) Proxies() map[string]*ManagedProxy {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
m := make(map[string]*ManagedProxy)
|
||||
for id, p := range l.managedProxies {
|
||||
m[id] = p
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// NotifyProxy will register a channel to receive messages when the
|
||||
// configuration or set of proxies changes. This will not block on
|
||||
// channel send so ensure the channel has a buffer. Note that any buffer
|
||||
// size is generally fine since actual data is not sent over the channel,
|
||||
// so a dropped send due to a full buffer does not result in any loss of
|
||||
// data. The fact that a buffer already contains a notification means that
|
||||
// the receiver will still be notified that changes occurred.
|
||||
//
|
||||
// NOTE(mitchellh): This could be more generalized but for my use case I
|
||||
// only needed proxy events. In the future if it were to be generalized I
|
||||
// would add a new Notify method and remove the proxy-specific ones.
|
||||
func (l *State) NotifyProxy(ch chan<- struct{}) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.managedProxyHandlers[ch] = struct{}{}
|
||||
}
|
||||
|
||||
// StopNotifyProxy will deregister a channel receiving proxy notifications.
|
||||
// Pair this with all calls to NotifyProxy to clean up state.
|
||||
func (l *State) StopNotifyProxy(ch chan<- struct{}) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
delete(l.managedProxyHandlers, ch)
|
||||
}
|
||||
|
||||
// Metadata returns the local node metadata fields that the
|
||||
// agent is aware of and are being kept in sync with the server
|
||||
func (l *State) Metadata() map[string]string {
|
||||
|
|
|
@ -3,10 +3,16 @@ package local_test
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-memdb"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/agent/config"
|
||||
"github.com/hashicorp/consul/agent/local"
|
||||
|
@ -16,6 +22,7 @@ import (
|
|||
"github.com/hashicorp/consul/testutil/retry"
|
||||
"github.com/hashicorp/consul/types"
|
||||
"github.com/pascaldekloe/goe/verify"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAgentAntiEntropy_Services(t *testing.T) {
|
||||
|
@ -224,6 +231,145 @@ func TestAgentAntiEntropy_Services(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAgentAntiEntropy_Services_ConnectProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert := assert.New(t)
|
||||
a := &agent.TestAgent{Name: t.Name()}
|
||||
a.Start()
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register node info
|
||||
var out struct{}
|
||||
args := &structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: a.Config.NodeName,
|
||||
Address: "127.0.0.1",
|
||||
}
|
||||
|
||||
// Exists both same (noop)
|
||||
srv1 := &structs.NodeService{
|
||||
Kind: structs.ServiceKindConnectProxy,
|
||||
ID: "mysql-proxy",
|
||||
Service: "mysql-proxy",
|
||||
Port: 5000,
|
||||
ProxyDestination: "db",
|
||||
}
|
||||
a.State.AddService(srv1, "")
|
||||
args.Service = srv1
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
// Exists both, different (update)
|
||||
srv2 := &structs.NodeService{
|
||||
ID: "redis-proxy",
|
||||
Service: "redis-proxy",
|
||||
Port: 8000,
|
||||
Kind: structs.ServiceKindConnectProxy,
|
||||
ProxyDestination: "redis",
|
||||
}
|
||||
a.State.AddService(srv2, "")
|
||||
|
||||
srv2_mod := new(structs.NodeService)
|
||||
*srv2_mod = *srv2
|
||||
srv2_mod.Port = 9000
|
||||
args.Service = srv2_mod
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
// Exists local (create)
|
||||
srv3 := &structs.NodeService{
|
||||
ID: "web-proxy",
|
||||
Service: "web-proxy",
|
||||
Port: 80,
|
||||
Kind: structs.ServiceKindConnectProxy,
|
||||
ProxyDestination: "web",
|
||||
}
|
||||
a.State.AddService(srv3, "")
|
||||
|
||||
// Exists remote (delete)
|
||||
srv4 := &structs.NodeService{
|
||||
ID: "lb-proxy",
|
||||
Service: "lb-proxy",
|
||||
Port: 443,
|
||||
Kind: structs.ServiceKindConnectProxy,
|
||||
ProxyDestination: "lb",
|
||||
}
|
||||
args.Service = srv4
|
||||
assert.Nil(a.RPC("Catalog.Register", args, &out))
|
||||
|
||||
// Exists local, in sync, remote missing (create)
|
||||
srv5 := &structs.NodeService{
|
||||
ID: "cache-proxy",
|
||||
Service: "cache-proxy",
|
||||
Port: 11211,
|
||||
Kind: structs.ServiceKindConnectProxy,
|
||||
ProxyDestination: "cache-proxy",
|
||||
}
|
||||
a.State.SetServiceState(&local.ServiceState{
|
||||
Service: srv5,
|
||||
InSync: true,
|
||||
})
|
||||
|
||||
assert.Nil(a.State.SyncFull())
|
||||
|
||||
var services structs.IndexedNodeServices
|
||||
req := structs.NodeSpecificRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: a.Config.NodeName,
|
||||
}
|
||||
assert.Nil(a.RPC("Catalog.NodeServices", &req, &services))
|
||||
|
||||
// We should have 5 services (consul included)
|
||||
assert.Len(services.NodeServices.Services, 5)
|
||||
|
||||
// All the services should match
|
||||
for id, serv := range services.NodeServices.Services {
|
||||
serv.CreateIndex, serv.ModifyIndex = 0, 0
|
||||
switch id {
|
||||
case "mysql-proxy":
|
||||
assert.Equal(srv1, serv)
|
||||
case "redis-proxy":
|
||||
assert.Equal(srv2, serv)
|
||||
case "web-proxy":
|
||||
assert.Equal(srv3, serv)
|
||||
case "cache-proxy":
|
||||
assert.Equal(srv5, serv)
|
||||
case structs.ConsulServiceID:
|
||||
// ignore
|
||||
default:
|
||||
t.Fatalf("unexpected service: %v", id)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Nil(servicesInSync(a.State, 4))
|
||||
|
||||
// Remove one of the services
|
||||
a.State.RemoveService("cache-proxy")
|
||||
assert.Nil(a.State.SyncFull())
|
||||
assert.Nil(a.RPC("Catalog.NodeServices", &req, &services))
|
||||
|
||||
// We should have 4 services (consul included)
|
||||
assert.Len(services.NodeServices.Services, 4)
|
||||
|
||||
// All the services should match
|
||||
for id, serv := range services.NodeServices.Services {
|
||||
serv.CreateIndex, serv.ModifyIndex = 0, 0
|
||||
switch id {
|
||||
case "mysql-proxy":
|
||||
assert.Equal(srv1, serv)
|
||||
case "redis-proxy":
|
||||
assert.Equal(srv2, serv)
|
||||
case "web-proxy":
|
||||
assert.Equal(srv3, serv)
|
||||
case structs.ConsulServiceID:
|
||||
// ignore
|
||||
default:
|
||||
t.Fatalf("unexpected service: %v", id)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Nil(servicesInSync(a.State, 3))
|
||||
}
|
||||
|
||||
func TestAgentAntiEntropy_EnableTagOverride(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := &agent.TestAgent{Name: t.Name()}
|
||||
|
@ -1524,3 +1670,263 @@ func checksInSync(state *local.State, wantChecks int) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestStateProxyManagement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
state := local.NewState(local.Config{
|
||||
ProxyBindMinPort: 20000,
|
||||
ProxyBindMaxPort: 20001,
|
||||
}, log.New(os.Stderr, "", log.LstdFlags), &token.Store{})
|
||||
|
||||
// Stub state syncing
|
||||
state.TriggerSyncChanges = func() {}
|
||||
|
||||
p1 := structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
TargetServiceID: "web",
|
||||
}
|
||||
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
_, err := state.AddProxy(&p1, "fake-token", "")
|
||||
require.Error(err, "should fail as the target service isn't registered")
|
||||
|
||||
// Sanity check done, lets add a couple of target services to the state
|
||||
err = state.AddService(&structs.NodeService{
|
||||
Service: "web",
|
||||
}, "fake-token-web")
|
||||
require.NoError(err)
|
||||
err = state.AddService(&structs.NodeService{
|
||||
Service: "cache",
|
||||
}, "fake-token-cache")
|
||||
require.NoError(err)
|
||||
require.NoError(err)
|
||||
err = state.AddService(&structs.NodeService{
|
||||
Service: "db",
|
||||
}, "fake-token-db")
|
||||
require.NoError(err)
|
||||
|
||||
// Should work now
|
||||
pstate, err := state.AddProxy(&p1, "fake-token", "")
|
||||
require.NoError(err)
|
||||
|
||||
svc := pstate.Proxy.ProxyService
|
||||
assert.Equal("web-proxy", svc.ID)
|
||||
assert.Equal("web-proxy", svc.Service)
|
||||
assert.Equal(structs.ServiceKindConnectProxy, svc.Kind)
|
||||
assert.Equal("web", svc.ProxyDestination)
|
||||
assert.Equal("", svc.Address, "should have empty address by default")
|
||||
// Port is non-deterministic but could be either of 20000 or 20001
|
||||
assert.Contains([]int{20000, 20001}, svc.Port)
|
||||
|
||||
{
|
||||
// Re-registering same proxy again should not pick a random port but re-use
|
||||
// the assigned one. It should also keep the same proxy token since we don't
|
||||
// want to force restart for config change.
|
||||
pstateDup, err := state.AddProxy(&p1, "fake-token", "")
|
||||
require.NoError(err)
|
||||
svcDup := pstateDup.Proxy.ProxyService
|
||||
|
||||
assert.Equal("web-proxy", svcDup.ID)
|
||||
assert.Equal("web-proxy", svcDup.Service)
|
||||
assert.Equal(structs.ServiceKindConnectProxy, svcDup.Kind)
|
||||
assert.Equal("web", svcDup.ProxyDestination)
|
||||
assert.Equal("", svcDup.Address, "should have empty address by default")
|
||||
// Port must be same as before
|
||||
assert.Equal(svc.Port, svcDup.Port)
|
||||
// Same ProxyToken
|
||||
assert.Equal(pstate.ProxyToken, pstateDup.ProxyToken)
|
||||
}
|
||||
|
||||
// Let's register a notifier now
|
||||
notifyCh := make(chan struct{}, 1)
|
||||
state.NotifyProxy(notifyCh)
|
||||
defer state.StopNotifyProxy(notifyCh)
|
||||
assert.Empty(notifyCh)
|
||||
drainCh(notifyCh)
|
||||
|
||||
// Second proxy should claim other port
|
||||
p2 := p1
|
||||
p2.TargetServiceID = "cache"
|
||||
pstate2, err := state.AddProxy(&p2, "fake-token", "")
|
||||
require.NoError(err)
|
||||
svc2 := pstate2.Proxy.ProxyService
|
||||
assert.Contains([]int{20000, 20001}, svc2.Port)
|
||||
assert.NotEqual(svc.Port, svc2.Port)
|
||||
|
||||
// Should have a notification
|
||||
assert.NotEmpty(notifyCh)
|
||||
drainCh(notifyCh)
|
||||
|
||||
// Store this for later
|
||||
p2token := state.Proxy(svc2.ID).ProxyToken
|
||||
|
||||
// Third proxy should fail as all ports are used
|
||||
p3 := p1
|
||||
p3.TargetServiceID = "db"
|
||||
_, err = state.AddProxy(&p3, "fake-token", "")
|
||||
require.Error(err)
|
||||
|
||||
// Should have a notification but we'll do nothing so that the next
|
||||
// receive should block (we set cap == 1 above)
|
||||
|
||||
// But if we set a port explicitly it should be OK
|
||||
p3.Config = map[string]interface{}{
|
||||
"bind_port": 1234,
|
||||
"bind_address": "0.0.0.0",
|
||||
}
|
||||
pstate3, err := state.AddProxy(&p3, "fake-token", "")
|
||||
require.NoError(err)
|
||||
svc3 := pstate3.Proxy.ProxyService
|
||||
require.Equal("0.0.0.0", svc3.Address)
|
||||
require.Equal(1234, svc3.Port)
|
||||
|
||||
// Should have a notification
|
||||
assert.NotEmpty(notifyCh)
|
||||
drainCh(notifyCh)
|
||||
|
||||
// Update config of an already registered proxy should work
|
||||
p3updated := p3
|
||||
p3updated.Config["foo"] = "bar"
|
||||
// Setup multiple watchers who should all witness the change
|
||||
gotP3 := state.Proxy(svc3.ID)
|
||||
require.NotNil(gotP3)
|
||||
var ws memdb.WatchSet
|
||||
ws.Add(gotP3.WatchCh)
|
||||
pstate3, err = state.AddProxy(&p3updated, "fake-token", "")
|
||||
require.NoError(err)
|
||||
svc3 = pstate3.Proxy.ProxyService
|
||||
require.Equal("0.0.0.0", svc3.Address)
|
||||
require.Equal(1234, svc3.Port)
|
||||
gotProxy3 := state.Proxy(svc3.ID)
|
||||
require.NotNil(gotProxy3)
|
||||
require.Equal(p3updated.Config, gotProxy3.Proxy.Config)
|
||||
assert.False(ws.Watch(time.After(500*time.Millisecond)),
|
||||
"watch should have fired so ws.Watch should not timeout")
|
||||
|
||||
drainCh(notifyCh)
|
||||
|
||||
// Remove one of the auto-assigned proxies
|
||||
_, err = state.RemoveProxy(svc2.ID)
|
||||
require.NoError(err)
|
||||
|
||||
// Should have a notification
|
||||
assert.NotEmpty(notifyCh)
|
||||
drainCh(notifyCh)
|
||||
|
||||
// Should be able to create a new proxy for that service with the port (it
|
||||
// should have been "freed").
|
||||
p4 := p2
|
||||
pstate4, err := state.AddProxy(&p4, "fake-token", "")
|
||||
require.NoError(err)
|
||||
svc4 := pstate4.Proxy.ProxyService
|
||||
assert.Contains([]int{20000, 20001}, svc2.Port)
|
||||
assert.Equal(svc4.Port, svc2.Port, "should get the same port back that we freed")
|
||||
|
||||
// Remove a proxy that doesn't exist should error
|
||||
_, err = state.RemoveProxy("nope")
|
||||
require.Error(err)
|
||||
|
||||
assert.Equal(&p4, state.Proxy(p4.ProxyService.ID).Proxy,
|
||||
"should fetch the right proxy details")
|
||||
assert.Nil(state.Proxy("nope"))
|
||||
|
||||
proxies := state.Proxies()
|
||||
assert.Len(proxies, 3)
|
||||
assert.Equal(&p1, proxies[svc.ID].Proxy)
|
||||
assert.Equal(&p4, proxies[svc4.ID].Proxy)
|
||||
assert.Equal(&p3, proxies[svc3.ID].Proxy)
|
||||
|
||||
tokens := make([]string, 4)
|
||||
tokens[0] = state.Proxy(svc.ID).ProxyToken
|
||||
// p2 not registered anymore but lets make sure p4 got a new token when it
|
||||
// re-registered with same ID.
|
||||
tokens[1] = p2token
|
||||
tokens[2] = state.Proxy(svc2.ID).ProxyToken
|
||||
tokens[3] = state.Proxy(svc3.ID).ProxyToken
|
||||
|
||||
// Quick check all are distinct
|
||||
for i := 0; i < len(tokens)-1; i++ {
|
||||
assert.Len(tokens[i], 36) // Sanity check for UUIDish thing.
|
||||
for j := i + 1; j < len(tokens); j++ {
|
||||
assert.NotEqual(tokens[i], tokens[j], "tokens for proxy %d and %d match",
|
||||
i+1, j+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tests the logic for retaining tokens and ports through restore (i.e.
|
||||
// proxy-service already restored and token passed in externally)
|
||||
func TestStateProxyRestore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
state := local.NewState(local.Config{
|
||||
// Wide random range to make it very unlikely to pass by chance
|
||||
ProxyBindMinPort: 10000,
|
||||
ProxyBindMaxPort: 20000,
|
||||
}, log.New(os.Stderr, "", log.LstdFlags), &token.Store{})
|
||||
|
||||
// Stub state syncing
|
||||
state.TriggerSyncChanges = func() {}
|
||||
|
||||
webSvc := structs.NodeService{
|
||||
Service: "web",
|
||||
}
|
||||
|
||||
p1 := structs.ConnectManagedProxy{
|
||||
ExecMode: structs.ProxyExecModeDaemon,
|
||||
Command: []string{"consul", "connect", "proxy"},
|
||||
TargetServiceID: "web",
|
||||
}
|
||||
|
||||
p2 := p1
|
||||
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
// Add a target service
|
||||
require.NoError(state.AddService(&webSvc, "fake-token-web"))
|
||||
|
||||
// Add the proxy for first time to get the proper service definition to
|
||||
// register
|
||||
pstate, err := state.AddProxy(&p1, "fake-token", "")
|
||||
require.NoError(err)
|
||||
|
||||
// Now start again with a brand new state
|
||||
state2 := local.NewState(local.Config{
|
||||
// Wide random range to make it very unlikely to pass by chance
|
||||
ProxyBindMinPort: 10000,
|
||||
ProxyBindMaxPort: 20000,
|
||||
}, log.New(os.Stderr, "", log.LstdFlags), &token.Store{})
|
||||
|
||||
// Stub state syncing
|
||||
state2.TriggerSyncChanges = func() {}
|
||||
|
||||
// Register the target service
|
||||
require.NoError(state2.AddService(&webSvc, "fake-token-web"))
|
||||
|
||||
// "Restore" the proxy service
|
||||
require.NoError(state.AddService(p1.ProxyService, "fake-token-web"))
|
||||
|
||||
// Now we can AddProxy with the "restored" token
|
||||
pstate2, err := state.AddProxy(&p2, "fake-token", pstate.ProxyToken)
|
||||
require.NoError(err)
|
||||
|
||||
// Check it still has the same port and token as before
|
||||
assert.Equal(pstate.ProxyToken, pstate2.ProxyToken)
|
||||
assert.Equal(p1.ProxyService.Port, p2.ProxyService.Port)
|
||||
}
|
||||
|
||||
// drainCh drains a channel by reading messages until it would block.
|
||||
func drainCh(ch chan struct{}) {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
package local
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/consul/agent/token"
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
)
|
||||
|
||||
// TestState returns a configured *State for testing.
|
||||
func TestState(t testing.T) *State {
|
||||
result := NewState(Config{
|
||||
ProxyBindMinPort: 20000,
|
||||
ProxyBindMaxPort: 20500,
|
||||
}, log.New(os.Stderr, "", log.LstdFlags), &token.Store{})
|
||||
result.TriggerSyncChanges = func() {}
|
||||
return result
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue