diff --git a/agent/consul/state/peering.go b/agent/consul/state/peering.go index e48684923..a3529cbda 100644 --- a/agent/consul/state/peering.go +++ b/agent/consul/state/peering.go @@ -213,6 +213,13 @@ func (s *Store) PeeringWrite(idx uint64, p *pbpeering.Peering) error { return fmt.Errorf("cannot write to peering that is marked for deletion") } + if p.State == pbpeering.PeeringState_UNDEFINED { + p.State = existing.State + } + // TODO(peering): Confirm behavior when /peering/token is called more than once. + // We may need to avoid clobbering existing values. + p.ImportedServiceCount = existing.ImportedServiceCount + p.ExportedServiceCount = existing.ExportedServiceCount p.CreateIndex = existing.CreateIndex p.ModifyIndex = idx } else { diff --git a/agent/peering_endpoint_test.go b/agent/peering_endpoint_test.go index 0b82e6399..05b8646e9 100644 --- a/agent/peering_endpoint_test.go +++ b/agent/peering_endpoint_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/hashicorp/consul/agent/structs" @@ -148,6 +149,71 @@ func TestHTTP_Peering_GenerateToken(t *testing.T) { }) } +// Test for GenerateToken calls at various points in a peer's lifecycle +func TestHTTP_Peering_GenerateToken_EdgeCases(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + t.Parallel() + + a := NewTestAgent(t, "") + testrpc.WaitForTestAgent(t, a.RPC, "dc1") + + body := &pbpeering.GenerateTokenRequest{ + PeerName: "peering-a", + } + + bodyBytes, err := json.Marshal(body) + require.NoError(t, err) + + getPeering := func(t *testing.T) *api.Peering { + t.Helper() + // Check state of peering + req, err := http.NewRequest("GET", "/v1/peering/peering-a", bytes.NewReader(bodyBytes)) + require.NoError(t, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String()) + + var p *api.Peering + require.NoError(t, json.NewDecoder(resp.Body).Decode(&p)) + return p + } + + { + // Call once + req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes)) + require.NoError(t, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String()) + // Assertions tested in TestHTTP_Peering_GenerateToken + } + + if !t.Run("generate token called again", func(t *testing.T) { + before := getPeering(t) + require.Equal(t, api.PeeringStatePending, before.State) + + // Call again + req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes)) + require.NoError(t, err) + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String()) + + after := getPeering(t) + assert.NotEqual(t, before.ModifyIndex, after.ModifyIndex) + // blank out modify index so we can compare rest of struct + before.ModifyIndex, after.ModifyIndex = 0, 0 + assert.Equal(t, before, after) + + }) { + t.FailNow() + } + +} + func TestHTTP_Peering_Establish(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") diff --git a/agent/rpc/peering/service.go b/agent/rpc/peering/service.go index c7e0b861c..72add67b0 100644 --- a/agent/rpc/peering/service.go +++ b/agent/rpc/peering/service.go @@ -211,6 +211,53 @@ func (s *Server) GenerateToken( return nil, err } + var peering *pbpeering.Peering + + // This loop ensures at most one retry in the case of a race condition. + for canRetry := true; canRetry; canRetry = false { + peering, err = s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault()) + if err != nil { + return nil, err + } + + if peering == nil { + id, err := lib.GenerateUUID(s.Backend.CheckPeeringUUID) + if err != nil { + return resp, err + } + peering = &pbpeering.Peering{ + ID: id, + Name: req.PeerName, + Meta: req.Meta, + + // PartitionOrEmpty is used to avoid writing "default" in OSS. + Partition: entMeta.PartitionOrEmpty(), + } + } else { + // validate that this peer name is not being used as a dialer already + if err := validatePeer(peering, false); err != nil { + return nil, err + } + } + writeReq := pbpeering.PeeringWriteRequest{ + Peering: peering, + } + if err := s.Backend.PeeringWrite(&writeReq); err != nil { + // There's a possible race where two servers call Generate Token at the + // same time with the same peer name for the first time. They both + // generate an ID and try to insert and only one wins. This detects the + // collision and forces the loser to discard its generated ID and use + // the one from the other server. + if strings.Contains(err.Error(), "A peering already exists with the name") { + // retry to fetch existing peering + continue + } + return nil, fmt.Errorf("failed to write peering: %w", err) + } + // write succeeded, break loop early + break + } + ca, err := s.Backend.GetAgentCACertificates() if err != nil { return nil, err @@ -227,57 +274,6 @@ func (s *Server) GenerateToken( } } - peeringOrNil, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault()) - if err != nil { - return nil, err - } - - // validate that this peer name is not being used as a dialer already - if err = validatePeer(peeringOrNil, false); err != nil { - return nil, err - } - - canRetry := true -RETRY_ONCE: - id, err := s.getExistingOrCreateNewPeerID(req.PeerName, entMeta.PartitionOrDefault()) - if err != nil { - return nil, err - } - writeReq := pbpeering.PeeringWriteRequest{ - Peering: &pbpeering.Peering{ - ID: id, - Name: req.PeerName, - Meta: req.Meta, - - // PartitionOrEmpty is used to avoid writing "default" in OSS. - Partition: entMeta.PartitionOrEmpty(), - }, - } - if err := s.Backend.PeeringWrite(&writeReq); err != nil { - // There's a possible race where two servers call Generate Token at the - // same time with the same peer name for the first time. They both - // generate an ID and try to insert and only one wins. This detects the - // collision and forces the loser to discard its generated ID and use - // the one from the other server. - if canRetry && strings.Contains(err.Error(), "A peering already exists with the name") { - canRetry = false - goto RETRY_ONCE - } - return nil, fmt.Errorf("failed to write peering: %w", err) - } - - q := state.Query{ - Value: strings.ToLower(req.PeerName), - EnterpriseMeta: *entMeta, - } - _, peering, err := s.Backend.Store().PeeringRead(nil, q) - if err != nil { - return nil, err - } - if peering == nil { - return nil, fmt.Errorf("peering was deleted while token generation request was in flight") - } - tok := structs.PeeringToken{ // Store the UUID so that we can do a global search when handling inbound streams. PeerID: peering.ID, @@ -345,24 +341,24 @@ func (s *Server) Establish( return nil, err } - peeringOrNil, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault()) + peering, err := s.getExistingPeering(req.PeerName, entMeta.PartitionOrDefault()) if err != nil { return nil, err } - // validate that this peer name is not being used as an acceptor already - if err = validatePeer(peeringOrNil, true); err != nil { - return nil, err - } - var id string - if peeringOrNil != nil { - id = peeringOrNil.ID - } else { + if peering == nil { id, err = lib.GenerateUUID(s.Backend.CheckPeeringUUID) if err != nil { return nil, err } + } else { + id = peering.ID + } + + // validate that this peer name is not being used as an acceptor already + if err := validatePeer(peering, true); err != nil { + return nil, err } // convert ServiceAddress values to strings @@ -392,10 +388,10 @@ func (s *Server) Establish( Partition: entMeta.PartitionOrEmpty(), }, } - if err = s.Backend.PeeringWrite(writeReq); err != nil { + if err := s.Backend.PeeringWrite(writeReq); err != nil { return nil, fmt.Errorf("failed to write peering: %w", err) } - // resp.Status == 0 + // TODO(peering): low prio: consider adding response details return resp, nil } @@ -564,10 +560,19 @@ func (s *Server) PeeringWrite(ctx context.Context, req *pbpeering.PeeringWriteRe return nil, fmt.Errorf("missing required peering body") } - id, err := s.getExistingOrCreateNewPeerID(req.Peering.Name, entMeta.PartitionOrDefault()) + var id string + peering, err := s.getExistingPeering(req.Peering.Name, entMeta.PartitionOrDefault()) if err != nil { return nil, err } + if peering == nil { + id, err = lib.GenerateUUID(s.Backend.CheckPeeringUUID) + if err != nil { + return nil, err + } + } else { + id = peering.ID + } req.Peering.ID = id err = s.Backend.PeeringWrite(req) @@ -759,22 +764,6 @@ func (s *Server) TrustBundleListByService(ctx context.Context, req *pbpeering.Tr return &pbpeering.TrustBundleListByServiceResponse{Index: idx, Bundles: bundles}, nil } -func (s *Server) getExistingOrCreateNewPeerID(peerName, partition string) (string, error) { - peeringOrNil, err := s.getExistingPeering(peerName, partition) - if err != nil { - return "", err - } - if peeringOrNil != nil { - return peeringOrNil.ID, nil - } - - id, err := lib.GenerateUUID(s.Backend.CheckPeeringUUID) - if err != nil { - return "", err - } - return id, nil -} - func (s *Server) getExistingPeering(peerName, partition string) (*pbpeering.Peering, error) { q := state.Query{ Value: strings.ToLower(peerName), @@ -793,9 +782,9 @@ func (s *Server) getExistingPeering(peerName, partition string) (*pbpeering.Peer // // We define a DIALER as a peering that has server addresses (or a peering that is created via the Establish endpoint) // Conversely, we define an ACCEPTOR as a peering that is created via the GenerateToken endpoint -func validatePeer(peering *pbpeering.Peering, allowedToDial bool) error { - if peering != nil && peering.ShouldDial() != allowedToDial { - if allowedToDial { +func validatePeer(peering *pbpeering.Peering, shouldDial bool) error { + if peering != nil && peering.ShouldDial() != shouldDial { + if shouldDial { return fmt.Errorf("cannot create peering with name: %q; there is an existing peering expecting to be dialed", peering.Name) } else { return fmt.Errorf("cannot create peering with name: %q; there is already an established peering", peering.Name)