diff --git a/agent/consul/fsm/commands_oss_test.go b/agent/consul/fsm/commands_oss_test.go index a52e6d7b6..280bf5b38 100644 --- a/agent/consul/fsm/commands_oss_test.go +++ b/agent/consul/fsm/commands_oss_test.go @@ -1318,3 +1318,42 @@ func TestFSM_CARoots(t *testing.T) { assert.Len(roots, 2) } } + +func TestFSM_CABuiltinProvider(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + fsm, err := New(nil, os.Stderr) + assert.Nil(err) + + // Provider state. + expected := &structs.CAConsulProviderState{ + ID: "foo", + PrivateKey: "a", + RootCert: "b", + SerialIndex: 2, + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + } + + // Create a new request. + req := structs.CARequest{ + Op: structs.CAOpSetProviderState, + ProviderState: expected, + } + + { + buf, err := structs.Encode(structs.ConnectCARequestType, req) + assert.Nil(err) + assert.True(fsm.Apply(makeLog(buf)).(bool)) + } + + // Verify it's in the state store. + { + _, state, err := fsm.state.CAProviderState("foo") + assert.Nil(err) + assert.Equal(expected, state) + } +} diff --git a/agent/consul/state/connect_ca.go b/agent/consul/state/connect_ca.go index 7c4cea294..a7f51a52a 100644 --- a/agent/consul/state/connect_ca.go +++ b/agent/consul/state/connect_ca.go @@ -319,19 +319,19 @@ func (s *Store) CARootSetCAS(idx, cidx uint64, rs []*structs.CARoot) (bool, erro return true, nil } -// CAProviderState is used to pull the built-in provider state from the snapshot. -func (s *Snapshot) CAProviderState() (*structs.CAConsulProviderState, error) { - c, err := s.tx.First(caBuiltinProviderTableName, "id") +// CAProviderState is used to pull the built-in provider states from the snapshot. +func (s *Snapshot) CAProviderState() ([]*structs.CAConsulProviderState, error) { + ixns, err := s.tx.Get(caBuiltinProviderTableName, "id") if err != nil { return nil, err } - state, ok := c.(*structs.CAConsulProviderState) - if !ok { - return nil, nil + var ret []*structs.CAConsulProviderState + for wrapped := ixns.Next(); wrapped != nil; wrapped = ixns.Next() { + ret = append(ret, wrapped.(*structs.CAConsulProviderState)) } - return state, nil + return ret, nil } // CAProviderState is used when restoring from a snapshot. @@ -339,6 +339,9 @@ func (s *Restore) CAProviderState(state *structs.CAConsulProviderState) error { if err := s.tx.Insert(caBuiltinProviderTableName, state); err != nil { return fmt.Errorf("failed restoring built-in CA state: %s", err) } + if err := indexUpdateMaxTxn(s.tx, state.ModifyIndex, caBuiltinProviderTableName); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } return nil } @@ -365,27 +368,6 @@ func (s *Store) CAProviderState(id string) (uint64, *structs.CAConsulProviderSta return idx, state, nil } -// CAProviderStates is used to get the Consul CA provider state for the given ID. -func (s *Store) CAProviderStates() (uint64, []*structs.CAConsulProviderState, error) { - tx := s.db.Txn(false) - defer tx.Abort() - - // Get the index - idx := maxIndexTxn(tx, caBuiltinProviderTableName) - - // Get all - iter, err := tx.Get(caBuiltinProviderTableName, "id") - if err != nil { - return 0, nil, fmt.Errorf("failed CA provider state lookup: %s", err) - } - - var results []*structs.CAConsulProviderState - for v := iter.Next(); v != nil; v = iter.Next() { - results = append(results, v.(*structs.CAConsulProviderState)) - } - return idx, results, nil -} - // CASetProviderState is used to set the current built-in CA provider state. func (s *Store) CASetProviderState(idx uint64, state *structs.CAConsulProviderState) (bool, error) { tx := s.db.Txn(true) @@ -419,7 +401,8 @@ func (s *Store) CASetProviderState(idx uint64, state *structs.CAConsulProviderSt return true, nil } -// CADeleteProviderState is used to remove the Consul CA provider state for the given ID. +// CADeleteProviderState is used to remove the built-in Consul CA provider +// state for the given ID. func (s *Store) CADeleteProviderState(id string) error { tx := s.db.Txn(true) defer tx.Abort() diff --git a/agent/consul/state/connect_ca_test.go b/agent/consul/state/connect_ca_test.go index cd37f526b..4639c7f5a 100644 --- a/agent/consul/state/connect_ca_test.go +++ b/agent/consul/state/connect_ca_test.go @@ -349,3 +349,106 @@ func TestStore_CARoot_Snapshot_Restore(t *testing.T) { assert.Equal(roots, actual) }() } + +func TestStore_CABuiltinProvider(t *testing.T) { + assert := assert.New(t) + s := testStateStore(t) + + { + expected := &structs.CAConsulProviderState{ + ID: "foo", + PrivateKey: "a", + RootCert: "b", + SerialIndex: 1, + } + + ok, err := s.CASetProviderState(0, expected) + assert.NoError(err) + assert.True(ok) + + idx, state, err := s.CAProviderState(expected.ID) + assert.NoError(err) + assert.Equal(idx, uint64(0)) + assert.Equal(expected, state) + } + + { + expected := &structs.CAConsulProviderState{ + ID: "bar", + PrivateKey: "c", + RootCert: "d", + SerialIndex: 2, + } + + ok, err := s.CASetProviderState(1, expected) + assert.NoError(err) + assert.True(ok) + + idx, state, err := s.CAProviderState(expected.ID) + assert.NoError(err) + assert.Equal(idx, uint64(1)) + assert.Equal(expected, state) + } +} + +func TestStore_CABuiltinProvider_Snapshot_Restore(t *testing.T) { + assert := assert.New(t) + s := testStateStore(t) + + // Create multiple state entries. + before := []*structs.CAConsulProviderState{ + { + ID: "bar", + PrivateKey: "y", + RootCert: "z", + SerialIndex: 2, + }, + { + ID: "foo", + PrivateKey: "a", + RootCert: "b", + SerialIndex: 1, + }, + } + + for i, state := range before { + ok, err := s.CASetProviderState(uint64(98+i), state) + assert.NoError(err) + assert.True(ok) + } + + // Take a snapshot. + snap := s.Snapshot() + defer snap.Close() + + // Modify the state store. + after := &structs.CAConsulProviderState{ + ID: "foo", + PrivateKey: "c", + RootCert: "d", + SerialIndex: 1, + } + ok, err := s.CASetProviderState(100, after) + assert.NoError(err) + assert.True(ok) + + snapped, err := snap.CAProviderState() + assert.NoError(err) + assert.Equal(before, snapped) + + // Restore onto a new state store. + s2 := testStateStore(t) + restore := s2.Restore() + for _, entry := range snapped { + assert.NoError(restore.CAProviderState(entry)) + } + restore.Commit() + + // Verify the restored values match those from before the snapshot. + for _, state := range before { + idx, res, err := s2.CAProviderState(state.ID) + assert.NoError(err) + assert.Equal(idx, uint64(99)) + assert.Equal(state, res) + } +}