open-consul/agent/consul/state/connect_ca_test.go

498 lines
11 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package state
import (
"reflect"
"testing"
"github.com/hashicorp/consul/proto/private/prototest"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/go-memdb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
)
// TestStore_CAConfig stores a CA configuration at index 0 and checks that
// reading it back yields index 0 and a configuration deeply equal to the one
// that was written.
func TestStore_CAConfig(t *testing.T) {
	store := testStateStore(t)

	want := &structs.CAConfiguration{
		Provider: "consul",
		Config: map[string]interface{}{
			"PrivateKey": "asdf",
			"RootCert":   "qwer",
		},
	}

	// Write the configuration.
	err := store.CASetConfig(0, want)
	if err != nil {
		t.Fatal(err)
	}

	// Read it back; every case below aborts the test, so at most one fires.
	gotIdx, got, err := store.CAConfig(nil)
	switch {
	case err != nil:
		t.Fatal(err)
	case gotIdx != 0:
		t.Fatalf("bad: %d", gotIdx)
	case !reflect.DeepEqual(want, got):
		t.Fatalf("bad: %#v, %#v", want, got)
	}
}
// TestStore_CAConfigCAS exercises check-and-set on the CA configuration: a CAS
// carrying a stale index must be rejected and leave the stored entry alone,
// while a CAS carrying the current index must replace it.
func TestStore_CAConfigCAS(t *testing.T) {
	store := testStateStore(t)
	original := &structs.CAConfiguration{
		Provider: "consul",
	}

	err := store.CASetConfig(0, original)
	if err != nil {
		t.Fatal(err)
	}

	// Write once more purely to move the index up by 1 ahead of the
	// check-and-set calls that follow.
	err = store.CASetConfig(1, original)
	if err != nil {
		t.Fatal(err)
	}

	// A CAS whose expected index (0) is lower than the entry's must fail.
	applied, err := store.CACheckAndSetConfig(2, 0, &structs.CAConfiguration{
		Provider: "static",
	})
	require.False(t, applied)
	testutil.RequireErrorContains(t, err, "ModifyIndex did not match existing")

	// The index and the stored entry must both be untouched.
	gotIdx, gotCfg, err := store.CAConfig(nil)
	switch {
	case err != nil:
		t.Fatal(err)
	case gotIdx != 1:
		t.Fatalf("bad: %d", gotIdx)
	case gotCfg.Provider != "consul":
		t.Fatalf("bad: %#v", gotCfg)
	}

	// Retry the CAS, this time with the correct index.
	applied, err = store.CACheckAndSetConfig(2, 1, &structs.CAConfiguration{
		Provider: "static",
	})
	if succeeded := applied && err == nil; !succeeded {
		t.Fatalf("expected (true, nil), got: (%v, %#v)", applied, err)
	}

	// The configuration must now reflect the update.
	gotIdx, gotCfg, err = store.CAConfig(nil)
	switch {
	case err != nil:
		t.Fatal(err)
	case gotIdx != 2:
		t.Fatalf("bad: %d", gotIdx)
	case gotCfg.Provider != "static":
		t.Fatalf("bad: %#v", gotCfg)
	}
}
// TestStore_CAConfig_Snapshot_Restore checks that a snapshot keeps the CA
// configuration as it was when the snapshot was taken, even if the live store
// is modified afterwards, and that restoring that snapshot into a fresh store
// reproduces both the configuration and its index.
func TestStore_CAConfig_Snapshot_Restore(t *testing.T) {
	s := testStateStore(t)
	before := &structs.CAConfiguration{
		Provider: "consul",
		Config: map[string]interface{}{
			"PrivateKey": "asdf",
			"RootCert":   "qwer",
		},
	}
	require.NoError(t, s.CASetConfig(99, before))

	snap := s.Snapshot()
	defer snap.Close()

	// Change the live store after the snapshot; the snapshot must not see it.
	after := &structs.CAConfiguration{
		Provider: "static",
		Config:   map[string]interface{}{},
	}
	require.NoError(t, s.CASetConfig(100, after))

	snapped, err := snap.CAConfig()
	require.NoError(t, err)
	require.Equal(t, before, snapped)

	// Restore the snapshotted value into a brand new store.
	s2 := testStateStore(t)
	restore := s2.Restore()
	require.NoError(t, restore.CAConfig(snapped))
	restore.Commit()

	idx, res, err := s2.CAConfig(nil)
	require.NoError(t, err)
	// EqualValues keeps the comparison independent of the index's exact
	// integer type.
	require.EqualValues(t, 99, idx, "bad index")
	require.Equal(t, before, res)
}
// Make sure we handle the case of a leftover blank CA config that
// got stuck in a snapshot, as in https://github.com/hashicorp/consul/issues/4954
//
// The test snapshots a store holding an empty configuration, restores that
// snapshot into a fresh store, and expects the fresh store to report no
// configuration at all (nil result, index 0).
func TestStore_CAConfig_Snapshot_Restore_BlankConfig(t *testing.T) {
	s := testStateStore(t)
	before := &structs.CAConfiguration{}
	require.NoError(t, s.CASetConfig(99, before))

	snap := s.Snapshot()
	defer snap.Close()

	snapped, err := snap.CAConfig()
	require.NoError(t, err)
	require.Equal(t, before, snapped)

	s2 := testStateStore(t)
	restore := s2.Restore()
	require.NoError(t, restore.CAConfig(snapped))
	restore.Commit()

	// The restored store must report no configuration at all.
	idx, result, err := s2.CAConfig(nil)
	require.NoError(t, err)
	require.EqualValues(t, 0, idx, "bad index")
	require.Nil(t, result, "should be nil")
}
// TestStore_CARootSetList stores a single CA root and checks that the table
// index advances, that a watch registered beforehand fires, and that listing
// the roots returns the stored root stamped with the expected Raft indexes.
func TestStore_CARootSetList(t *testing.T) {
	s := testStateStore(t)

	// Call list to populate the watch set
	ws := memdb.NewWatchSet()
	_, _, err := s.CARoots(ws)
	require.NoError(t, err)

	// Build a valid value; copy it so the expectation is independent of the
	// pointer handed to the store.
	ca1 := connect.TestCA(t, nil)
	expected := *ca1

	// Set
	ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
	require.NoError(t, err)
	assert.True(t, ok)

	// Make sure the index got updated.
	assert.Equal(t, uint64(1), s.maxIndex(tableConnectCARoots))
	assert.True(t, watchFired(ws), "watch fired")

	// Read it back out and verify it.
	expected.RaftIndex = structs.RaftIndex{
		CreateIndex: 1,
		ModifyIndex: 1,
	}
	ws = memdb.NewWatchSet()
	_, roots, err := s.CARoots(ws)
	require.NoError(t, err)
	// Must be fatal: indexing roots[0] below would panic on an empty list.
	require.Len(t, roots, 1)
	actual := roots[0]
	prototest.AssertDeepEqual(t, expected, *actual)
}
// TestStore_CARootSet_emptyID checks that storing a CA root with a blank ID is
// rejected with ErrMissingCARootID and leaves the store untouched: the table
// index stays at 0, the watch does not fire, and no roots are listed.
func TestStore_CARootSet_emptyID(t *testing.T) {
	s := testStateStore(t)

	// Call list to populate the watch set
	ws := memdb.NewWatchSet()
	_, _, err := s.CARoots(ws)
	require.NoError(t, err)

	// Build a root and blank out its ID.
	ca1 := connect.TestCA(t, nil)
	ca1.ID = ""

	// Set
	ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1})
	// Must be fatal: err.Error() below would panic on a nil error.
	require.Error(t, err)
	assert.Contains(t, err.Error(), ErrMissingCARootID.Error())
	assert.False(t, ok)

	// Make sure the index was NOT updated and the watch did not fire.
	assert.Equal(t, uint64(0), s.maxIndex(tableConnectCARoots))
	assert.False(t, watchFired(ws), "watch fired")

	// Listing must still return no roots.
	ws = memdb.NewWatchSet()
	_, roots, err := s.CARoots(ws)
	assert.NoError(t, err)
	assert.Len(t, roots, 0)
}
// TestStore_CARootSet_noActive checks that storing a set of CA roots in which
// none is marked active is rejected with an "exactly one active" error.
func TestStore_CARootSet_noActive(t *testing.T) {
	s := testStateStore(t)

	// Call list to populate the watch set
	ws := memdb.NewWatchSet()
	_, _, err := s.CARoots(ws)
	require.NoError(t, err)

	// Build two roots and mark both of them inactive.
	ca1 := connect.TestCA(t, nil)
	ca1.Active = false
	ca2 := connect.TestCA(t, nil)
	ca2.Active = false

	// Set
	ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
	// Must be fatal: err.Error() below would panic on a nil error.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "exactly one active")
	assert.False(t, ok)
}
// TestStore_CARootSet_multipleActive checks that storing a set of CA roots in
// which more than one is active is rejected with an "exactly one active"
// error.
func TestStore_CARootSet_multipleActive(t *testing.T) {
	s := testStateStore(t)

	// Call list to populate the watch set
	ws := memdb.NewWatchSet()
	_, _, err := s.CARoots(ws)
	require.NoError(t, err)

	// Build two roots and leave both exactly as connect.TestCA returns them,
	// so that the set contains more than one active root.
	ca1 := connect.TestCA(t, nil)
	ca2 := connect.TestCA(t, nil)

	// Set
	ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2})
	// Must be fatal: err.Error() below would panic on a nil error.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "exactly one active")
	assert.False(t, ok)
}
// TestStore_CARootActive_valid stores three CA roots of which only the second
// is active and checks that querying the active root returns that one.
func TestStore_CARootActive_valid(t *testing.T) {
	s := testStateStore(t)

	// Build three roots; only ca2 is left active.
	ca1 := connect.TestCA(t, nil)
	ca1.Active = false
	ca2 := connect.TestCA(t, nil)
	ca3 := connect.TestCA(t, nil)
	ca3.Active = false

	// Set
	ok, err := s.CARootSetCAS(1, 0, []*structs.CARoot{ca1, ca2, ca3})
	require.NoError(t, err)
	assert.True(t, ok)

	// Query
	ws := memdb.NewWatchSet()
	idx, res, err := s.CARootActive(ws)
	require.NoError(t, err)
	assert.Equal(t, uint64(1), idx)
	// Must be fatal: res.ID below would panic on a nil result.
	require.NotNil(t, res)
	assert.Equal(t, ca2.ID, res.ID)
}
// Test that querying the active CA on a store with no roots returns a nil
// root, index 0, and no error.
func TestStore_CARootActive_none(t *testing.T) {
	s := testStateStore(t)

	// Querying with no results returns nil.
	ws := memdb.NewWatchSet()
	idx, res, err := s.CARootActive(ws)
	assert.NoError(t, err)
	assert.Equal(t, uint64(0), idx)
	assert.Nil(t, res)
}
// TestStore_CARoot_Snapshot_Restore checks that a snapshot keeps the CA roots
// as they were when it was taken, even if the live store is changed
// afterwards, and that restoring the dumped roots into a fresh store yields
// the same roots.
func TestStore_CARoot_Snapshot_Restore(t *testing.T) {
	s := testStateStore(t)

	// Create some CA roots; only the first one stays active.
	roots := structs.CARoots{
		connect.TestCA(t, nil),
		connect.TestCA(t, nil),
		connect.TestCA(t, nil),
	}
	for _, r := range roots[1:] {
		r.Active = false
	}

	// Force the sort order of the UUIDs before we create them so the
	// order is deterministic.
	id := testUUID()
	roots[0].ID = "a" + id[1:]
	roots[1].ID = "b" + id[1:]
	roots[2].ID = "c" + id[1:]

	// Now create
	ok, err := s.CARootSetCAS(1, 0, roots)
	require.NoError(t, err)
	require.True(t, ok)

	// Snapshot the queries.
	snap := s.Snapshot()
	defer snap.Close()

	// Alter the real state store.
	ok, err = s.CARootSetCAS(2, 1, roots[:1])
	require.NoError(t, err)
	require.True(t, ok)

	// Verify the snapshot.
	assert.Equal(t, uint64(1), snap.LastIndex())
	dump, err := snap.CARoots()
	require.NoError(t, err)
	assert.Equal(t, roots, dump)

	// Restore the values into a new state store. A distinct name avoids
	// shadowing the original store.
	s2 := testStateStore(t)
	restore := s2.Restore()
	for _, r := range dump {
		require.NoError(t, restore.CARoot(r))
	}
	restore.Commit()

	// Read the restored values back out and verify that they match.
	idx, actual, err := s2.CARoots(nil)
	require.NoError(t, err)
	assert.Equal(t, uint64(2), idx)
	assert.Equal(t, roots, actual)
}
// TestStore_CABuiltinProvider writes two builtin-provider state entries and
// reads each back, then checks that provider serial numbers are handed out
// sequentially starting from 2.
func TestStore_CABuiltinProvider(t *testing.T) {
	s := testStateStore(t)

	// First entry, written at index 0.
	{
		expected := &structs.CAConsulProviderState{
			ID:         "foo",
			PrivateKey: "a",
			RootCert:   "b",
		}

		ok, err := s.CASetProviderState(0, expected)
		assert.NoError(t, err)
		assert.True(t, ok)

		idx, state, err := s.CAProviderState(expected.ID)
		assert.NoError(t, err)
		assert.Equal(t, uint64(0), idx)
		assert.Equal(t, expected, state)
	}

	// Second entry, written at index 1.
	{
		expected := &structs.CAConsulProviderState{
			ID:         "bar",
			PrivateKey: "c",
			RootCert:   "d",
		}

		ok, err := s.CASetProviderState(1, expected)
		assert.NoError(t, err)
		assert.True(t, ok)

		idx, state, err := s.CAProviderState(expected.ID)
		assert.NoError(t, err)
		assert.Equal(t, uint64(1), idx)
		assert.Equal(t, expected, state)
	}

	{
		// Since we've already written to the builtin provider table the serial
		// numbers will initialize from the max index of the provider table.
		// That's why this first serial is 2 and not 1.
		sn, err := s.CAIncrementProviderSerialNumber(10)
		assert.NoError(t, err)
		assert.Equal(t, uint64(2), sn)

		sn, err = s.CAIncrementProviderSerialNumber(10)
		assert.NoError(t, err)
		assert.Equal(t, uint64(3), sn)

		sn, err = s.CAIncrementProviderSerialNumber(10)
		assert.NoError(t, err)
		assert.Equal(t, uint64(4), sn)
	}
}
// TestStore_CABuiltinProvider_Snapshot_Restore checks that a snapshot keeps
// the builtin-provider state entries as they were when it was taken, even if
// the live store is modified afterwards, and that restoring them into a fresh
// store yields the original entries.
func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) {
	s := testStateStore(t)

	// Create multiple state entries, written at indexes 98 and 99.
	before := []*structs.CAConsulProviderState{
		{
			ID:         "bar",
			PrivateKey: "y",
			RootCert:   "z",
		},
		{
			ID:         "foo",
			PrivateKey: "a",
			RootCert:   "b",
		},
	}

	for i, state := range before {
		ok, err := s.CASetProviderState(uint64(98+i), state)
		require.NoError(t, err)
		require.True(t, ok)
	}

	// Take a snapshot.
	snap := s.Snapshot()
	defer snap.Close()

	// Modify the state store.
	after := &structs.CAConsulProviderState{
		ID:         "foo",
		PrivateKey: "c",
		RootCert:   "d",
	}
	ok, err := s.CASetProviderState(100, after)
	require.NoError(t, err)
	require.True(t, ok)

	// The snapshot must still hold the entries from before the modification.
	snapped, err := snap.CAProviderState()
	require.NoError(t, err)
	assert.Equal(t, before, snapped)

	// Restore onto a new state store.
	s2 := testStateStore(t)
	restore := s2.Restore()
	for _, entry := range snapped {
		require.NoError(t, restore.CAProviderState(entry))
	}
	restore.Commit()

	// Verify the restored values match those from before the snapshot. Both
	// lookups are expected to report index 99, the highest index written
	// before the snapshot.
	for _, state := range before {
		idx, res, err := s2.CAProviderState(state.ID)
		assert.NoError(t, err)
		assert.Equal(t, uint64(99), idx)
		assert.Equal(t, state, res)
	}
}