diff --git a/agent/consul/fsm/snapshot_oss.go b/agent/consul/fsm/snapshot_oss.go index b042c7831..f9eb18cc8 100644 --- a/agent/consul/fsm/snapshot_oss.go +++ b/agent/consul/fsm/snapshot_oss.go @@ -22,6 +22,7 @@ func init() { registerRestorer(structs.AutopilotRequestType, restoreAutopilot) registerRestorer(structs.IntentionRequestType, restoreIntention) registerRestorer(structs.ConnectCARequestType, restoreConnectCA) + registerRestorer(structs.ConnectCAProviderStateType, restoreConnectCAProviderState) } func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) error { @@ -52,6 +53,9 @@ func persistOSS(s *snapshot, sink raft.SnapshotSink, encoder *codec.Encoder) err if err := s.persistConnectCA(sink, encoder); err != nil { return err } + if err := s.persistConnectCAProviderState(sink, encoder); err != nil { + return err + } return nil } @@ -284,6 +288,24 @@ func (s *snapshot) persistConnectCA(sink raft.SnapshotSink, return nil } +func (s *snapshot) persistConnectCAProviderState(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + state, err := s.state.CAProviderState() + if err != nil { + return err + } + + for _, r := range state { + if _, err := sink.Write([]byte{byte(structs.ConnectCAProviderStateType)}); err != nil { + return err + } + if err := encoder.Encode(r); err != nil { + return err + } + } + return nil +} + func (s *snapshot) persistIntentions(sink raft.SnapshotSink, encoder *codec.Encoder) error { ixns, err := s.state.Intentions() @@ -430,3 +452,14 @@ func restoreConnectCA(header *snapshotHeader, restore *state.Restore, decoder *c } return nil } + +func restoreConnectCAProviderState(header *snapshotHeader, restore *state.Restore, decoder *codec.Decoder) error { + var req structs.CAConsulProviderState + if err := decoder.Decode(&req); err != nil { + return err + } + if err := restore.CAProviderState(&req); err != nil { + return err + } + return nil +} diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go index 971e6bbf5..9a6f3a355 100644 --- a/agent/consul/fsm/snapshot_oss_test.go +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -123,6 +123,14 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { assert.Nil(err) assert.True(ok) + ok, err = fsm.state.CASetProviderState(16, &structs.CAConsulProviderState{ + ID: "asdf", + PrivateKey: "foo", + RootCert: "bar", + }) + assert.Nil(err) + assert.True(ok) + // Snapshot snap, err := fsm.Snapshot() if err != nil { @@ -296,6 +304,12 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { assert.Nil(err) assert.Len(roots, 2) + // Verify provider state is restored. + _, state, err := fsm2.state.CAProviderState("asdf") + assert.Nil(err) + assert.Equal("foo", state.PrivateKey) + assert.Equal("bar", state.RootCert) + // Snapshot snap, err = fsm2.Snapshot() if err != nil { diff --git a/agent/structs/structs.go b/agent/structs/structs.go index 7b719405a..f5308b351 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -31,20 +31,21 @@ type RaftIndex struct { // These are serialized between Consul servers and stored in Consul snapshots, // so entries must only ever be added. const ( - RegisterRequestType MessageType = 0 - DeregisterRequestType = 1 - KVSRequestType = 2 - SessionRequestType = 3 - ACLRequestType = 4 - TombstoneRequestType = 5 - CoordinateBatchUpdateType = 6 - PreparedQueryRequestType = 7 - TxnRequestType = 8 - AutopilotRequestType = 9 - AreaRequestType = 10 - ACLBootstrapRequestType = 11 // FSM snapshots only. - IntentionRequestType = 12 - ConnectCARequestType = 13 + RegisterRequestType MessageType = 0 + DeregisterRequestType = 1 + KVSRequestType = 2 + SessionRequestType = 3 + ACLRequestType = 4 + TombstoneRequestType = 5 + CoordinateBatchUpdateType = 6 + PreparedQueryRequestType = 7 + TxnRequestType = 8 + AutopilotRequestType = 9 + AreaRequestType = 10 + ACLBootstrapRequestType = 11 // FSM snapshots only. + IntentionRequestType = 12 + ConnectCARequestType = 13 + ConnectCAProviderStateType = 14 ) const (