Preserve PeeringState on upsert (#13666)
Fixes a bug where if the generate token is called twice, the second call upserts the zero-value (undefined) of PeeringState.
This commit is contained in:
parent
706808dd84
commit
1f8ae56951
|
@ -213,6 +213,13 @@ func (s *Store) PeeringWrite(idx uint64, p *pbpeering.Peering) error {
|
|||
return fmt.Errorf("cannot write to peering that is marked for deletion")
|
||||
}
|
||||
|
||||
if p.State == pbpeering.PeeringState_UNDEFINED {
|
||||
p.State = existing.State
|
||||
}
|
||||
// TODO(peering): Confirm behavior when /peering/token is called more than once.
|
||||
// We may need to avoid clobbering existing values.
|
||||
p.ImportedServiceCount = existing.ImportedServiceCount
|
||||
p.ExportedServiceCount = existing.ExportedServiceCount
|
||||
p.CreateIndex = existing.CreateIndex
|
||||
p.ModifyIndex = idx
|
||||
} else {
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
|
@ -148,6 +149,71 @@ func TestHTTP_Peering_GenerateToken(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Test for GenerateToken calls at various points in a peer's lifecycle
|
||||
func TestHTTP_Peering_GenerateToken_EdgeCases(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("too slow for testing.Short")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
a := NewTestAgent(t, "")
|
||||
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
|
||||
|
||||
body := &pbpeering.GenerateTokenRequest{
|
||||
PeerName: "peering-a",
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
getPeering := func(t *testing.T) *api.Peering {
|
||||
t.Helper()
|
||||
// Check state of peering
|
||||
req, err := http.NewRequest("GET", "/v1/peering/peering-a", bytes.NewReader(bodyBytes))
|
||||
require.NoError(t, err)
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.h.ServeHTTP(resp, req)
|
||||
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
|
||||
|
||||
var p *api.Peering
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&p))
|
||||
return p
|
||||
}
|
||||
|
||||
{
|
||||
// Call once
|
||||
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
|
||||
require.NoError(t, err)
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.h.ServeHTTP(resp, req)
|
||||
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
|
||||
// Assertions tested in TestHTTP_Peering_GenerateToken
|
||||
}
|
||||
|
||||
if !t.Run("generate token called again", func(t *testing.T) {
|
||||
before := getPeering(t)
|
||||
require.Equal(t, api.PeeringStatePending, before.State)
|
||||
|
||||
// Call again
|
||||
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
|
||||
require.NoError(t, err)
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.h.ServeHTTP(resp, req)
|
||||
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
|
||||
|
||||
after := getPeering(t)
|
||||
assert.NotEqual(t, before.ModifyIndex, after.ModifyIndex)
|
||||
// blank out modify index so we can compare rest of struct
|
||||
before.ModifyIndex, after.ModifyIndex = 0, 0
|
||||
assert.Equal(t, before, after)
|
||||
|
||||
}) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestHTTP_Peering_Establish(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("too slow for testing.Short")
|
||||
|
|
|
@ -211,6 +211,53 @@ func (s *Server) GenerateToken(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var peering *pbpeering.Peering
|
||||
|
||||
// This loop ensures at most one retry in the case of a race condition.
|
||||
for canRetry := true; canRetry; canRetry = false {
|
||||
peering, err = s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if peering == nil {
|
||||
id, err := lib.GenerateUUID(s.Backend.CheckPeeringUUID)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
peering = &pbpeering.Peering{
|
||||
ID: id,
|
||||
Name: req.PeerName,
|
||||
Meta: req.Meta,
|
||||
|
||||
// PartitionOrEmpty is used to avoid writing "default" in OSS.
|
||||
Partition: entMeta.PartitionOrEmpty(),
|
||||
}
|
||||
} else {
|
||||
// validate that this peer name is not being used as a dialer already
|
||||
if err := validatePeer(peering, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
writeReq := pbpeering.PeeringWriteRequest{
|
||||
Peering: peering,
|
||||
}
|
||||
if err := s.Backend.PeeringWrite(&writeReq); err != nil {
|
||||
// There's a possible race where two servers call Generate Token at the
|
||||
// same time with the same peer name for the first time. They both
|
||||
// generate an ID and try to insert and only one wins. This detects the
|
||||
// collision and forces the loser to discard its generated ID and use
|
||||
// the one from the other server.
|
||||
if strings.Contains(err.Error(), "A peering already exists with the name") {
|
||||
// retry to fetch existing peering
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("failed to write peering: %w", err)
|
||||
}
|
||||
// write succeeded, break loop early
|
||||
break
|
||||
}
|
||||
|
||||
ca, err := s.Backend.GetAgentCACertificates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -227,57 +274,6 @@ func (s *Server) GenerateToken(
|
|||
}
|
||||
}
|
||||
|
||||
peeringOrNil, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// validate that this peer name is not being used as a dialer already
|
||||
if err = validatePeer(peeringOrNil, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
canRetry := true
|
||||
RETRY_ONCE:
|
||||
id, err := s.getExistingOrCreateNewPeerID(req.PeerName, entMeta.PartitionOrDefault())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writeReq := pbpeering.PeeringWriteRequest{
|
||||
Peering: &pbpeering.Peering{
|
||||
ID: id,
|
||||
Name: req.PeerName,
|
||||
Meta: req.Meta,
|
||||
|
||||
// PartitionOrEmpty is used to avoid writing "default" in OSS.
|
||||
Partition: entMeta.PartitionOrEmpty(),
|
||||
},
|
||||
}
|
||||
if err := s.Backend.PeeringWrite(&writeReq); err != nil {
|
||||
// There's a possible race where two servers call Generate Token at the
|
||||
// same time with the same peer name for the first time. They both
|
||||
// generate an ID and try to insert and only one wins. This detects the
|
||||
// collision and forces the loser to discard its generated ID and use
|
||||
// the one from the other server.
|
||||
if canRetry && strings.Contains(err.Error(), "A peering already exists with the name") {
|
||||
canRetry = false
|
||||
goto RETRY_ONCE
|
||||
}
|
||||
return nil, fmt.Errorf("failed to write peering: %w", err)
|
||||
}
|
||||
|
||||
q := state.Query{
|
||||
Value: strings.ToLower(req.PeerName),
|
||||
EnterpriseMeta: *entMeta,
|
||||
}
|
||||
_, peering, err := s.Backend.Store().PeeringRead(nil, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if peering == nil {
|
||||
return nil, fmt.Errorf("peering was deleted while token generation request was in flight")
|
||||
}
|
||||
|
||||
tok := structs.PeeringToken{
|
||||
// Store the UUID so that we can do a global search when handling inbound streams.
|
||||
PeerID: peering.ID,
|
||||
|
@ -345,24 +341,24 @@ func (s *Server) Establish(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
peeringOrNil, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault())
|
||||
peering, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// validate that this peer name is not being used as an acceptor already
|
||||
if err = validatePeer(peeringOrNil, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id string
|
||||
if peeringOrNil != nil {
|
||||
id = peeringOrNil.ID
|
||||
} else {
|
||||
if peering == nil {
|
||||
id, err = lib.GenerateUUID(s.Backend.CheckPeeringUUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
id = peering.ID
|
||||
}
|
||||
|
||||
// validate that this peer name is not being used as an acceptor already
|
||||
if err := validatePeer(peering, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// convert ServiceAddress values to strings
|
||||
|
@ -392,10 +388,10 @@ func (s *Server) Establish(
|
|||
Partition: entMeta.PartitionOrEmpty(),
|
||||
},
|
||||
}
|
||||
if err = s.Backend.PeeringWrite(writeReq); err != nil {
|
||||
if err := s.Backend.PeeringWrite(writeReq); err != nil {
|
||||
return nil, fmt.Errorf("failed to write peering: %w", err)
|
||||
}
|
||||
// resp.Status == 0
|
||||
// TODO(peering): low prio: consider adding response details
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
@ -564,10 +560,19 @@ func (s *Server) PeeringWrite(ctx context.Context, req *pbpeering.PeeringWriteRe
|
|||
return nil, fmt.Errorf("missing required peering body")
|
||||
}
|
||||
|
||||
id, err := s.getExistingOrCreateNewPeerID(req.Peering.Name, entMeta.PartitionOrDefault())
|
||||
var id string
|
||||
peering, err := s.getExistingPeering(req.Peering.Name, entMeta.PartitionOrDefault())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if peering == nil {
|
||||
id, err = lib.GenerateUUID(s.Backend.CheckPeeringUUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
id = peering.ID
|
||||
}
|
||||
req.Peering.ID = id
|
||||
|
||||
err = s.Backend.PeeringWrite(req)
|
||||
|
@ -759,22 +764,6 @@ func (s *Server) TrustBundleListByService(ctx context.Context, req *pbpeering.Tr
|
|||
return &pbpeering.TrustBundleListByServiceResponse{Index: idx, Bundles: bundles}, nil
|
||||
}
|
||||
|
||||
func (s *Server) getExistingOrCreateNewPeerID(peerName, partition string) (string, error) {
|
||||
peeringOrNil, err := s.getExistingPeering(peerName, partition)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if peeringOrNil != nil {
|
||||
return peeringOrNil.ID, nil
|
||||
}
|
||||
|
||||
id, err := lib.GenerateUUID(s.Backend.CheckPeeringUUID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (s *Server) getExistingPeering(peerName, partition string) (*pbpeering.Peering, error) {
|
||||
q := state.Query{
|
||||
Value: strings.ToLower(peerName),
|
||||
|
@ -793,9 +782,9 @@ func (s *Server) getExistingPeering(peerName, partition string) (*pbpeering.Peer
|
|||
//
|
||||
// We define a DIALER as a peering that has server addresses (or a peering that is created via the Establish endpoint)
|
||||
// Conversely, we define an ACCEPTOR as a peering that is created via the GenerateToken endpoint
|
||||
func validatePeer(peering *pbpeering.Peering, allowedToDial bool) error {
|
||||
if peering != nil && peering.ShouldDial() != allowedToDial {
|
||||
if allowedToDial {
|
||||
func validatePeer(peering *pbpeering.Peering, shouldDial bool) error {
|
||||
if peering != nil && peering.ShouldDial() != shouldDial {
|
||||
if shouldDial {
|
||||
return fmt.Errorf("cannot create peering with name: %q; there is an existing peering expecting to be dialed", peering.Name)
|
||||
} else {
|
||||
return fmt.Errorf("cannot create peering with name: %q; there is already an established peering", peering.Name)
|
||||
|
|
Loading…
Reference in a new issue