diff --git a/agent/consul/fsm/snapshot_oss.go b/agent/consul/fsm/snapshot_oss.go index 1dde3ab0b..b042c7831 100644 --- a/agent/consul/fsm/snapshot_oss.go +++ b/agent/consul/fsm/snapshot_oss.go @@ -21,6 +21,7 @@ func init() { registerRestorer(structs.PreparedQueryRequestType, restorePreparedQuery) registerRestorer(structs.AutopilotRequestType, restoreAutopilot) registerRestorer(structs.IntentionRequestType, restoreIntention) + registerRestorer(structs.ConnectCARequestType, restoreConnectCA) } func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error { @@ -48,6 +49,9 @@ func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) err if err := s.persistIntentions(sink, encoder); err != nil { return err } + if err := s.persistConnectCA(sink, encoder); err != nil { + return err + } return nil } @@ -262,6 +266,24 @@ func (s *snapshot) persistAutopilot(sink raft.SnapshotSink, return nil } +func (s *snapshot) persistConnectCA(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + roots, err := s.state.CARoots() + if err != nil { + return err + } + + for _, r := range roots { + if _, err := sink.Write([]byte{byte(structs.ConnectCARequestType)}); err != nil { + return err + } + if err := encoder.Encode(r); err != nil { + return err + } + } + return nil +} + func (s *snapshot) persistIntentions(sink raft.SnapshotSink, encoder *codec.Encoder) error { ixns, err := s.state.Intentions() @@ -397,3 +419,14 @@ func restoreIntention(header *snapshotHeader, restore *state.Restore, decoder *c } return nil } + +func restoreConnectCA(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.CARoot + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.CARoot(&req); err != nil { + return err + } + return nil +} diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go index 63f1ab1d3..971e6bbf5 100644 --- a/agent/consul/fsm/snapshot_oss_test.go +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/consul/autopilot" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/structs" @@ -110,6 +111,18 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { } assert.Nil(fsm.state.IntentionSet(14, ixn)) + // CA Roots + roots := []*structs.CARoot{ + connect.TestCA(t, nil), + connect.TestCA(t, nil), + } + for _, r := range roots[1:] { + r.Active = false + } + ok, err := fsm.state.CARootSetCAS(15, 0, roots) + assert.Nil(err) + assert.True(ok) + // Snapshot snap, err := fsm.Snapshot() if err != nil { @@ -278,6 +291,11 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { assert.Len(ixns, 1) assert.Equal(ixn, ixns[0]) + // Verify CA roots are restored. + _, roots, err = fsm2.state.CARoots(nil) + assert.Nil(err) + assert.Len(roots, 2) + // Snapshot snap, err = fsm2.Snapshot() if err != nil { diff --git a/agent/consul/state/connect_ca.go b/agent/consul/state/connect_ca.go index 3b66a07c6..05313ce2e 100644 --- a/agent/consul/state/connect_ca.go +++ b/agent/consul/state/connect_ca.go @@ -33,6 +33,34 @@ func init() { registerSchema(caRootTableSchema) } +// CARoots is used to pull all the CA roots for the snapshot. +func (s *Snapshot) CARoots() (structs.CARoots, error) { + ixns, err := s.tx.Get(caRootTableName, "id") + if err != nil { + return nil, err + } + + var ret structs.CARoots + for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() { + ret = append(ret, wrapped.(*structs.CARoot)) + } + + return ret, nil +} + +// CARoots is used when restoring from a snapshot. +func (s *Restore) CARoot(r *structs.CARoot) error { + // Insert + if err := s.tx.Insert(caRootTableName, r); err != nil { + return fmt.Errorf("failed restoring CA root: %s", err) + } + if err := indexUpdateMaxTxn(s.tx, r.ModifyIndex, caRootTableName); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + return nil +} + // CARoots returns the list of all CA roots. func (s *Store) CARoots(ws memdb.WatchSet) (uint64, structs.CARoots, error) { tx := s.db.Txn(false) diff --git a/agent/consul/state/connect_ca_test.go b/agent/consul/state/connect_ca_test.go index 14b5caf54..bbbac0f0f 100644 --- a/agent/consul/state/connect_ca_test.go +++ b/agent/consul/state/connect_ca_test.go @@ -113,92 +113,60 @@ func TestStore_CARootActive_none(t *testing.T) { assert.Nil(err) } -/* -func TestStore_Intention_Snapshot_Restore(t *testing.T) { +func TestStore_CARoot_Snapshot_Restore(t *testing.T) { assert := assert.New(t) s := testStateStore(t) // Create some intentions. - ixns := structs.Intentions{ - &structs.Intention{ - DestinationName: "foo", - }, - &structs.Intention{ - DestinationName: "bar", - }, - &structs.Intention{ - DestinationName: "baz", - }, + roots := structs.CARoots{ + connect.TestCA(t, nil), + connect.TestCA(t, nil), + connect.TestCA(t, nil), + } + for _, r := range roots[1:] { + r.Active = false } // Force the sort order of the UUIDs before we create them so the // order is deterministic. id := testUUID() - ixns[0].ID = "a" + id[1:] - ixns[1].ID = "b" + id[1:] - ixns[2].ID = "c" + id[1:] + roots[0].ID = "a" + id[1:] + roots[1].ID = "b" + id[1:] + roots[2].ID = "c" + id[1:] // Now create - for i, ixn := range ixns { - assert.Nil(s.IntentionSet(uint64(4+i), ixn)) - } + ok, err := s.CARootSetCAS(1, 0, roots) + assert.Nil(err) + assert.True(ok) // Snapshot the queries. snap := s.Snapshot() defer snap.Close() // Alter the real state store. - assert.Nil(s.IntentionDelete(7, ixns[0].ID)) + ok, err = s.CARootSetCAS(2, 1, roots[:1]) + assert.Nil(err) + assert.True(ok) // Verify the snapshot. - assert.Equal(snap.LastIndex(), uint64(6)) - expected := structs.Intentions{ - &structs.Intention{ - ID: ixns[0].ID, - DestinationName: "foo", - Meta: map[string]string{}, - RaftIndex: structs.RaftIndex{ - CreateIndex: 4, - ModifyIndex: 4, - }, - }, - &structs.Intention{ - ID: ixns[1].ID, - DestinationName: "bar", - Meta: map[string]string{}, - RaftIndex: structs.RaftIndex{ - CreateIndex: 5, - ModifyIndex: 5, - }, - }, - &structs.Intention{ - ID: ixns[2].ID, - DestinationName: "baz", - Meta: map[string]string{}, - RaftIndex: structs.RaftIndex{ - CreateIndex: 6, - ModifyIndex: 6, - }, - }, - } - dump, err := snap.Intentions() + assert.Equal(snap.LastIndex(), uint64(1)) + dump, err := snap.CARoots() assert.Nil(err) - assert.Equal(expected, dump) + assert.Equal(roots, dump) // Restore the values into a new state store. func() { s := testStateStore(t) restore := s.Restore() - for _, ixn := range dump { - assert.Nil(restore.Intention(ixn)) + for _, r := range dump { + assert.Nil(restore.CARoot(r)) } restore.Commit() // Read the restored values back out and verify that they match. - idx, actual, err := s.Intentions(nil) + idx, actual, err := s.CARoots(nil) assert.Nil(err) - assert.Equal(idx, uint64(6)) - assert.Equal(expected, actual) + assert.Equal(idx, uint64(2)) + assert.Equal(roots, actual) }() } -*/