From 04b6bd637a1797ea19996c6e4933267129c8eb73 Mon Sep 17 00:00:00 2001 From: Kyle Havlovitz Date: Thu, 23 Apr 2020 16:16:04 -0700 Subject: [PATCH] Filter wildcard gateway services to match listener protocol This now requires some type of protocol setting in ingress gateway tests to ensure the services are not filtered out. - small refactor to add a max(x, y) function - Use internal configEntryTxn function and add MaxUint64 to lib --- agent/consul/state/catalog.go | 62 ++-- agent/consul/state/catalog_test.go | 502 +++++++++++++++++++++++------ agent/consul/state/config_entry.go | 44 +++ lib/math.go | 7 + lib/math_test.go | 11 + 5 files changed, 509 insertions(+), 117 deletions(-) diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index f30705c9f..3c35f2fa2 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -7,6 +7,7 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/types" "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-uuid" @@ -2006,9 +2007,7 @@ func (s *Store) CheckIngressServiceNodes(ws memdb.WatchSet, serviceName string, // TODO(ingress) : Deal with incorporating index from mapping table // Watch for index changes to the gateway nodes idx, chans := s.maxIndexAndWatchChsForServiceNodes(tx, nodes, false, entMeta) - if idx > maxIdx { - maxIdx = idx - } + maxIdx = lib.MaxUint64(maxIdx, idx) for _, ch := range chans { ws.Add(ch) } @@ -2026,10 +2025,7 @@ func (s *Store) CheckIngressServiceNodes(ws memdb.WatchSet, serviceName string, if err != nil { return 0, nil, err } - if idx > maxIdx { - maxIdx = idx - } - + maxIdx = lib.MaxUint64(maxIdx, idx) results = append(results, n...) } return maxIdx, results, nil @@ -2087,9 +2083,7 @@ func (s *Store) checkServiceNodesTxn(tx *memdb.Txn, ws memdb.WatchSet, serviceNa if err != nil { return 0, nil, fmt.Errorf("failed gateway nodes lookup: %v", err) } - if idx < gwIdx { - idx = gwIdx - } + idx = lib.MaxUint64(idx, gwIdx) for i := 0; i < len(nodes); i++ { results = append(results, nodes[i]) serviceNames[nodes[i].ServiceName] = struct{}{} @@ -2117,9 +2111,7 @@ func (s *Store) checkServiceNodesTxn(tx *memdb.Txn, ws memdb.WatchSet, serviceNa // below is always true. svcIdx, svcCh := s.maxIndexAndWatchChForService(tx, svcName, true, true, entMeta) // Take the max index represented - if idx < svcIdx { - idx = svcIdx - } + idx = lib.MaxUint64(idx, svcIdx) if svcCh != nil { // Watch the service-specific index for changes in liu of all iradix nodes // for checks etc. @@ -2139,9 +2131,7 @@ func (s *Store) checkServiceNodesTxn(tx *memdb.Txn, ws memdb.WatchSet, serviceNa // be returned as we can't use the optimization in this case (and don't need // to as there is only one chan to watch anyway). svcIdx, _ := s.maxIndexAndWatchChForService(tx, serviceName, false, true, entMeta) - if idx < svcIdx { - idx = svcIdx - } + idx = lib.MaxUint64(idx, svcIdx) } // Create a nil watchset to pass below, we'll only pass the real one if we @@ -2202,6 +2192,7 @@ func (s *Store) CheckServiceTagNodes(ws memdb.WatchSet, serviceName string, tags func (s *Store) GatewayServices(ws memdb.WatchSet, gateway string, entMeta *structs.EnterpriseMeta) (uint64, structs.GatewayServices, error) { tx := s.db.Txn(false) defer tx.Abort() + var maxIdx uint64 iter, err := s.gatewayServices(tx, gateway, entMeta) if err != nil { @@ -2214,12 +2205,19 @@ func (s *Store) GatewayServices(ws memdb.WatchSet, gateway string, entMeta *stru svc := service.(*structs.GatewayService) if svc.Service.ID != structs.WildcardSpecifier { - results = append(results, svc) + idx, matches, err := s.checkProtocolMatch(tx, ws, svc) + if err != nil { + return 0, nil, fmt.Errorf("failed checking protocol: %s", err) + } + maxIdx = lib.MaxUint64(maxIdx, idx) + if matches { + results = append(results, svc) + } } } idx := maxIndexTxn(tx, gatewayServicesTableName) - return idx, results, nil + return lib.MaxUint64(maxIdx, idx), results, nil } // parseCheckServiceNodes is used to parse through a given set of services, @@ -2727,10 +2725,7 @@ func (s *Store) serviceGatewayNodes(tx *memdb.Txn, ws memdb.WatchSet, service st if mapping.GatewayKind != kind { continue } - - if mapping.ModifyIndex > maxIdx { - maxIdx = mapping.ModifyIndex - } + maxIdx = lib.MaxUint64(maxIdx, mapping.ModifyIndex) // Look up nodes for gateway gwServices, err := s.catalogServiceNodeList(tx, mapping.Gateway.ID, "service", &mapping.Gateway.EnterpriseMeta) @@ -2749,9 +2744,7 @@ func (s *Store) serviceGatewayNodes(tx *memdb.Txn, ws memdb.WatchSet, service st // This prevents the index from sliding back in case all instances of the service are deregistered svcIdx := s.maxIndexForService(tx, mapping.Gateway.ID, exists, false, &mapping.Service.EnterpriseMeta) - if maxIdx < svcIdx { - maxIdx = svcIdx - } + maxIdx = lib.MaxUint64(maxIdx, svcIdx) // Ensure that blocking queries wake up if the gateway-service mapping exists, but the gateway does not exist yet if !exists { @@ -2760,3 +2753,22 @@ func (s *Store) serviceGatewayNodes(tx *memdb.Txn, ws memdb.WatchSet, service st } return maxIdx, ret, nil } + +// checkProtocolMatch filters out any GatewayService entries added from a wildcard with a protocol +// that doesn't match the one configured in their discovery chain. +func (s *Store) checkProtocolMatch( + tx *memdb.Txn, + ws memdb.WatchSet, + svc *structs.GatewayService, +) (uint64, bool, error) { + if svc.GatewayKind != structs.ServiceKindIngressGateway || !svc.FromWildcard { + return 0, true, nil + } + + idx, protocol, err := s.protocolForService(tx, ws, svc.Service) + if err != nil { + return 0, false, err + } + + return idx, svc.Protocol == protocol, nil +} diff --git a/agent/consul/state/catalog_test.go b/agent/consul/state/catalog_test.go index 5150300c3..5d96d03aa 100644 --- a/agent/consul/state/catalog_test.go +++ b/agent/consul/state/catalog_test.go @@ -4920,7 +4920,7 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) { t.Run("check service1 ingress gateway", func(t *testing.T) { idx, results, err := s.CheckIngressServiceNodes(ws, "service1", nil) require.NoError(err) - require.Equal(uint64(14), idx) + require.Equal(uint64(15), idx) // Multiple instances of the ingress2 service require.Len(results, 4) @@ -4939,7 +4939,7 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) { t.Run("check service2 ingress gateway", func(t *testing.T) { idx, results, err := s.CheckIngressServiceNodes(ws, "service2", nil) require.NoError(err) - require.Equal(uint64(14), idx) + require.Equal(uint64(15), idx) require.Len(results, 2) ids := make(map[string]struct{}) @@ -4957,7 +4957,7 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) { ws := memdb.NewWatchSet() idx, results, err := s.CheckIngressServiceNodes(ws, "service3", nil) require.NoError(err) - require.Equal(uint64(14), idx) + require.Equal(uint64(15), idx) require.Len(results, 1) require.Equal("wildcardIngress", results[0].Service.ID) }) @@ -4968,17 +4968,17 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) { idx, results, err := s.CheckIngressServiceNodes(ws, "service1", nil) require.NoError(err) - require.Equal(uint64(14), idx) + require.Equal(uint64(15), idx) require.Len(results, 3) idx, results, err = s.CheckIngressServiceNodes(ws, "service2", nil) require.NoError(err) - require.Equal(uint64(14), idx) + require.Equal(uint64(15), idx) require.Len(results, 1) idx, results, err = s.CheckIngressServiceNodes(ws, "service3", nil) require.NoError(err) - require.Equal(uint64(14), idx) + require.Equal(uint64(15), idx) // TODO(ingress): index goes backward when deleting last config entry // require.Equal(uint64(11), idx) require.Len(results, 0) @@ -4988,88 +4988,176 @@ func TestStateStore_CheckIngressServiceNodes(t *testing.T) { func TestStateStore_GatewayServices_Ingress(t *testing.T) { s := testStateStore(t) ws := setupIngressState(t, s) - require := require.New(t) t.Run("ingress1 gateway services", func(t *testing.T) { + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("ingress1", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 1111, + Protocol: "http", + Hosts: []string{"test.example.com"}, + RaftIndex: structs.RaftIndex{ + CreateIndex: 13, + ModifyIndex: 13, + }, + }, + { + Gateway: structs.NewServiceID("ingress1", nil), + Service: structs.NewServiceID("service2", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 2222, + Protocol: "http", + RaftIndex: structs.RaftIndex{ + CreateIndex: 13, + ModifyIndex: 13, + }, + }, + } idx, results, err := s.GatewayServices(ws, "ingress1", nil) - require.NoError(err) - require.Equal(uint64(15), idx) - require.Len(results, 2) - require.Equal("ingress1", results[0].Gateway.ID) - require.Equal("service1", results[0].Service.ID) - require.Len(results[0].Hosts, 1) - require.Equal(1111, results[0].Port) - require.Equal("ingress1", results[1].Gateway.ID) - require.Equal("service2", results[1].Service.ID) - require.Equal(2222, results[1].Port) + require.NoError(t, err) + require.Equal(t, uint64(16), idx) + require.ElementsMatch(t, results, expected) }) t.Run("ingress2 gateway services", func(t *testing.T) { + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("ingress2", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 3333, + Protocol: "http", + RaftIndex: structs.RaftIndex{ + CreateIndex: 14, + ModifyIndex: 14, + }, + }, + } idx, results, err := s.GatewayServices(ws, "ingress2", nil) - require.NoError(err) - require.Equal(uint64(15), idx) - require.Len(results, 1) - require.Equal("ingress2", results[0].Gateway.ID) - require.Equal("service1", results[0].Service.ID) - require.Equal(3333, results[0].Port) + require.NoError(t, err) + require.Equal(t, uint64(16), idx) + require.ElementsMatch(t, results, expected) }) t.Run("No gatway services associated", func(t *testing.T) { idx, results, err := s.GatewayServices(ws, "nothingIngress", nil) - require.NoError(err) - require.Equal(uint64(15), idx) - require.Len(results, 0) + require.NoError(t, err) + require.Equal(t, uint64(16), idx) + require.Len(t, results, 0) }) t.Run("wildcard gateway services", func(t *testing.T) { - ws = memdb.NewWatchSet() + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("wildcardIngress", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 12, + ModifyIndex: 12, + }, + }, + { + Gateway: structs.NewServiceID("wildcardIngress", nil), + Service: structs.NewServiceID("service2", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 12, + ModifyIndex: 12, + }, + }, + { + Gateway: structs.NewServiceID("wildcardIngress", nil), + Service: structs.NewServiceID("service3", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 12, + ModifyIndex: 12, + }, + }, + } idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) - require.NoError(err) - require.Equal(uint64(15), idx) - require.Len(results, 3) - require.Equal("wildcardIngress", results[0].Gateway.ID) - require.Equal("service1", results[0].Service.ID) - require.Equal(4444, results[0].Port) - require.Equal("wildcardIngress", results[1].Gateway.ID) - require.Equal("service2", results[1].Service.ID) - require.Equal(4444, results[1].Port) - require.Equal("wildcardIngress", results[2].Gateway.ID) - require.Equal("service3", results[2].Service.ID) - require.Equal(4444, results[2].Port) + require.NoError(t, err) + require.Equal(t, uint64(16), idx) + require.ElementsMatch(t, results, expected) }) t.Run("gateway with duplicate service", func(t *testing.T) { + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("ingress3", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 5555, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 15, + ModifyIndex: 15, + }, + }, + { + Gateway: structs.NewServiceID("ingress3", nil), + Service: structs.NewServiceID("service2", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 5555, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 15, + ModifyIndex: 15, + }, + }, + { + Gateway: structs.NewServiceID("ingress3", nil), + Service: structs.NewServiceID("service3", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 5555, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 15, + ModifyIndex: 15, + }, + }, + { + Gateway: structs.NewServiceID("ingress3", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 6666, + Protocol: "http", + RaftIndex: structs.RaftIndex{ + CreateIndex: 15, + ModifyIndex: 15, + }, + }, + } idx, results, err := s.GatewayServices(ws, "ingress3", nil) - require.NoError(err) - require.Equal(uint64(15), idx) - require.Len(results, 4) - require.Equal("ingress3", results[0].Gateway.ID) - require.Equal("service1", results[0].Service.ID) - require.Equal(6666, results[0].Port) - require.Equal("tcp", results[0].Protocol) - require.Equal("ingress3", results[1].Gateway.ID) - require.Equal("service1", results[1].Service.ID) - require.Equal(5555, results[1].Port) - require.Equal("http", results[1].Protocol) - require.Equal("ingress3", results[2].Gateway.ID) - require.Equal("service2", results[2].Service.ID) - require.Equal(5555, results[2].Port) - require.Equal("http", results[2].Protocol) - require.Equal("ingress3", results[3].Gateway.ID) - require.Equal("service3", results[3].Service.ID) - require.Equal(5555, results[3].Port) - require.Equal("http", results[3].Protocol) + require.NoError(t, err) + require.Equal(t, uint64(16), idx) + require.ElementsMatch(t, results, expected) }) t.Run("deregistering a service", func(t *testing.T) { - require.Nil(s.DeleteService(18, "node1", "service1", nil)) - require.True(watchFired(ws)) + require.Nil(t, s.DeleteService(18, "node1", "service1", nil)) + require.True(t, watchFired(ws)) ws = memdb.NewWatchSet() idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) - require.NoError(err) - require.Equal(uint64(18), idx) - require.Len(results, 2) + require.NoError(t, err) + require.Equal(t, uint64(18), idx) + require.Len(t, results, 2) }) // TODO(ingress): This test case fails right now because of a @@ -5087,14 +5175,14 @@ func TestStateStore_GatewayServices_Ingress(t *testing.T) { // }) t.Run("deleting a wildcard config entry", func(t *testing.T) { - require.Nil(s.DeleteConfigEntry(19, "ingress-gateway", "wildcardIngress", nil)) - require.True(watchFired(ws)) + require.Nil(t, s.DeleteConfigEntry(19, "ingress-gateway", "wildcardIngress", nil)) + require.True(t, watchFired(ws)) ws = memdb.NewWatchSet() idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) - require.NoError(err) - require.Equal(uint64(19), idx) - require.Len(results, 0) + require.NoError(t, err) + require.Equal(t, uint64(19), idx) + require.Len(t, results, 0) }) t.Run("updating a config entry with zero listeners", func(t *testing.T) { @@ -5103,13 +5191,13 @@ func TestStateStore_GatewayServices_Ingress(t *testing.T) { Name: "ingress1", Listeners: []structs.IngressListener{}, } - require.Nil(s.EnsureConfigEntry(20, ingress1, nil)) - require.True(watchFired(ws)) + require.Nil(t, s.EnsureConfigEntry(20, ingress1, nil)) + require.True(t, watchFired(ws)) idx, results, err := s.GatewayServices(ws, "ingress1", nil) - require.NoError(err) - require.Equal(uint64(20), idx) - require.Len(results, 0) + require.NoError(t, err) + require.Equal(t, uint64(20), idx) + require.Len(t, results, 0) }) } @@ -5122,21 +5210,21 @@ func TestStateStore_GatewayServices_WildcardAssociation(t *testing.T) { t.Run("base case for wildcard", func(t *testing.T) { idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) require.NoError(err) - require.Equal(uint64(15), idx) + require.Equal(uint64(16), idx) require.Len(results, 3) }) t.Run("do not associate ingress services with gateway", func(t *testing.T) { - testRegisterIngressService(t, s, 15, "node1", "testIngress") + testRegisterIngressService(t, s, 17, "node1", "testIngress") require.False(watchFired(ws)) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) require.NoError(err) - require.Equal(uint64(15), idx) + require.Equal(uint64(16), idx) require.Len(results, 3) }) t.Run("do not associate terminating-gateway services with gateway", func(t *testing.T) { - require.Nil(s.EnsureService(16, "node1", + require.Nil(s.EnsureService(18, "node1", &structs.NodeService{ Kind: structs.ServiceKindTerminatingGateway, ID: "gateway", Service: "gateway", Port: 443, }, @@ -5144,31 +5232,249 @@ func TestStateStore_GatewayServices_WildcardAssociation(t *testing.T) { require.False(watchFired(ws)) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) require.NoError(err) - require.Equal(uint64(15), idx) + require.Equal(uint64(16), idx) require.Len(results, 3) }) t.Run("do not associate connect-proxy services with gateway", func(t *testing.T) { - testRegisterSidecarProxy(t, s, 17, "node1", "web") + testRegisterSidecarProxy(t, s, 19, "node1", "web") require.False(watchFired(ws)) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) require.NoError(err) - require.Equal(uint64(15), idx) + require.Equal(uint64(16), idx) require.Len(results, 3) }) t.Run("do not associate consul services with gateway", func(t *testing.T) { - require.Nil(s.EnsureService(18, "node1", + require.Nil(s.EnsureService(20, "node1", &structs.NodeService{ID: "consul", Service: "consul", Tags: nil}, )) require.False(watchFired(ws)) idx, results, err := s.GatewayServices(ws, "wildcardIngress", nil) require.NoError(err) - require.Equal(uint64(15), idx) + require.Equal(uint64(16), idx) require.Len(results, 3) }) } +func TestStateStore_GatewayServices_IngressProtocolFiltering(t *testing.T) { + s := testStateStore(t) + + t.Run("setup", func(t *testing.T) { + ingress1 := &structs.IngressGatewayConfigEntry{ + Kind: "ingress-gateway", + Name: "ingress1", + Listeners: []structs.IngressListener{ + { + Port: 4444, + Protocol: "http", + Services: []structs.IngressService{ + { + Name: "*", + }, + }, + }, + }, + } + + testRegisterNode(t, s, 0, "node1") + testRegisterService(t, s, 1, "node1", "service1") + testRegisterService(t, s, 2, "node1", "service2") + assert.NoError(t, s.EnsureConfigEntry(4, ingress1, nil)) + }) + + t.Run("no services from default tcp protocol", func(t *testing.T) { + require := require.New(t) + idx, results, err := s.GatewayServices(nil, "ingress1", nil) + require.NoError(err) + require.Equal(uint64(4), idx) + require.Len(results, 0) + }) + + t.Run("service-defaults", func(t *testing.T) { + require := require.New(t) + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("ingress1", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 4, + ModifyIndex: 4, + }, + }, + } + + svcDefaults := &structs.ServiceConfigEntry{ + Name: "service1", + Kind: structs.ServiceDefaults, + Protocol: "http", + } + assert.NoError(t, s.EnsureConfigEntry(5, svcDefaults, nil)) + idx, results, err := s.GatewayServices(nil, "ingress1", nil) + require.NoError(err) + require.Equal(uint64(5), idx) + require.ElementsMatch(results, expected) + }) + + t.Run("proxy-defaults", func(t *testing.T) { + require := require.New(t) + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("ingress1", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 4, + ModifyIndex: 4, + }, + }, + { + Gateway: structs.NewServiceID("ingress1", nil), + Service: structs.NewServiceID("service2", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 4, + ModifyIndex: 4, + }, + }, + } + + proxyDefaults := &structs.ProxyConfigEntry{ + Name: structs.ProxyConfigGlobal, + Kind: structs.ProxyDefaults, + Config: map[string]interface{}{ + "protocol": "http", + }, + } + assert.NoError(t, s.EnsureConfigEntry(6, proxyDefaults, nil)) + + idx, results, err := s.GatewayServices(nil, "ingress1", nil) + require.NoError(err) + require.Equal(uint64(6), idx) + require.ElementsMatch(results, expected) + }) + + t.Run("service-defaults overrides proxy-defaults", func(t *testing.T) { + require := require.New(t) + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("ingress1", nil), + Service: structs.NewServiceID("service2", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "http", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 4, + ModifyIndex: 4, + }, + }, + } + + svcDefaults := &structs.ServiceConfigEntry{ + Name: "service1", + Kind: structs.ServiceDefaults, + Protocol: "grpc", + } + assert.NoError(t, s.EnsureConfigEntry(7, svcDefaults, nil)) + + idx, results, err := s.GatewayServices(nil, "ingress1", nil) + require.NoError(err) + require.Equal(uint64(7), idx) + require.ElementsMatch(results, expected) + }) + + t.Run("change listener protocol and expect different filter", func(t *testing.T) { + require := require.New(t) + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("ingress1", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "grpc", + FromWildcard: true, + RaftIndex: structs.RaftIndex{ + CreateIndex: 8, + ModifyIndex: 8, + }, + }, + } + + ingress1 := &structs.IngressGatewayConfigEntry{ + Kind: "ingress-gateway", + Name: "ingress1", + Listeners: []structs.IngressListener{ + { + Port: 4444, + Protocol: "grpc", + Services: []structs.IngressService{ + { + Name: "*", + }, + }, + }, + }, + } + assert.NoError(t, s.EnsureConfigEntry(8, ingress1, nil)) + + idx, results, err := s.GatewayServices(nil, "ingress1", nil) + require.NoError(err) + require.Equal(uint64(8), idx) + require.ElementsMatch(results, expected) + }) + + // Relies on service defaults for service1 being set to grpc above + t.Run("only filters on gateway services from wildcards", func(t *testing.T) { + require := require.New(t) + expected := structs.GatewayServices{ + { + Gateway: structs.NewServiceID("ingress1", nil), + Service: structs.NewServiceID("service1", nil), + GatewayKind: structs.ServiceKindIngressGateway, + Port: 4444, + Protocol: "http", + RaftIndex: structs.RaftIndex{ + CreateIndex: 8, + ModifyIndex: 8, + }, + }, + } + + ingress1 := &structs.IngressGatewayConfigEntry{ + Kind: "ingress-gateway", + Name: "ingress1", + Listeners: []structs.IngressListener{ + { + Port: 4444, + Protocol: "http", + Services: []structs.IngressService{ + { + Name: "service1", + }, + }, + }, + }, + } + assert.NoError(t, s.EnsureConfigEntry(8, ingress1, nil)) + + idx, results, err := s.GatewayServices(nil, "ingress1", nil) + require.NoError(err) + require.Equal(uint64(8), idx) + require.ElementsMatch(results, expected) + }) +} + func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { // Querying with no matches gives an empty response ws := memdb.NewWatchSet() @@ -5191,6 +5497,16 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { testRegisterService(t, s, 9, "node2", "service2") testRegisterService(t, s, 10, "node2", "service3") + // Default protocol to http + proxyDefaults := &structs.ProxyConfigEntry{ + Name: structs.ProxyConfigGlobal, + Kind: structs.ProxyDefaults, + Config: map[string]interface{}{ + "protocol": "http", + }, + } + assert.NoError(t, s.EnsureConfigEntry(11, proxyDefaults, nil)) + // Register some ingress config entries. wildcardIngress := &structs.IngressGatewayConfigEntry{ Kind: "ingress-gateway", @@ -5198,7 +5514,7 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { Listeners: []structs.IngressListener{ { Port: 4444, - Protocol: "tcp", + Protocol: "http", Services: []structs.IngressService{ { Name: "*", @@ -5207,7 +5523,7 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { }, }, } - assert.NoError(t, s.EnsureConfigEntry(11, wildcardIngress, nil)) + assert.NoError(t, s.EnsureConfigEntry(12, wildcardIngress, nil)) ingress1 := &structs.IngressGatewayConfigEntry{ Kind: "ingress-gateway", @@ -5215,7 +5531,7 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { Listeners: []structs.IngressListener{ { Port: 1111, - Protocol: "tcp", + Protocol: "http", Services: []structs.IngressService{ { Name: "service1", @@ -5225,7 +5541,7 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { }, { Port: 2222, - Protocol: "tcp", + Protocol: "http", Services: []structs.IngressService{ { Name: "service2", @@ -5234,7 +5550,8 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { }, }, } - assert.NoError(t, s.EnsureConfigEntry(12, ingress1, nil)) + assert.NoError(t, s.EnsureConfigEntry(13, ingress1, nil)) + assert.True(t, watchFired(ws)) ingress2 := &structs.IngressGatewayConfigEntry{ Kind: "ingress-gateway", @@ -5242,7 +5559,7 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { Listeners: []structs.IngressListener{ { Port: 3333, - Protocol: "tcp", + Protocol: "http", Services: []structs.IngressService{ { Name: "service1", @@ -5251,7 +5568,8 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { }, }, } - assert.NoError(t, s.EnsureConfigEntry(13, ingress2, nil)) + assert.NoError(t, s.EnsureConfigEntry(14, ingress2, nil)) + assert.True(t, watchFired(ws)) ingress3 := &structs.IngressGatewayConfigEntry{ Kind: "ingress-gateway", @@ -5268,7 +5586,7 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { }, { Port: 6666, - Protocol: "tcp", + Protocol: "http", Services: []structs.IngressService{ { Name: "service1", @@ -5277,7 +5595,7 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { }, }, } - assert.NoError(t, s.EnsureConfigEntry(14, ingress3, nil)) + assert.NoError(t, s.EnsureConfigEntry(15, ingress3, nil)) assert.True(t, watchFired(ws)) nothingIngress := &structs.IngressGatewayConfigEntry{ @@ -5285,7 +5603,7 @@ func setupIngressState(t *testing.T, s *Store) memdb.WatchSet { Name: "nothingIngress", Listeners: []structs.IngressListener{}, } - assert.NoError(t, s.EnsureConfigEntry(15, nothingIngress, nil)) + assert.NoError(t, s.EnsureConfigEntry(16, nothingIngress, nil)) assert.True(t, watchFired(ws)) return ws diff --git a/agent/consul/state/config_entry.go b/agent/consul/state/config_entry.go index 2ad84ae45..1169a38a2 100644 --- a/agent/consul/state/config_entry.go +++ b/agent/consul/state/config_entry.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/consul/agent/consul/discoverychain" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/lib" memdb "github.com/hashicorp/go-memdb" ) @@ -832,3 +833,46 @@ func (s *Store) configEntryWithOverridesTxn( return s.configEntryTxn(tx, ws, kind, name, entMeta) } + +// protocolForService returns the service graph protocol associated to the +// provided service, checking all relevant config entries. +func (s *Store) protocolForService( + tx *memdb.Txn, + ws memdb.WatchSet, + svc structs.ServiceID, +) (uint64, string, error) { + // Get the global proxy defaults (for default protocol) + maxIdx, proxyConfig, err := s.configEntryTxn(tx, ws, structs.ProxyDefaults, structs.ProxyConfigGlobal, &svc.EnterpriseMeta) + if err != nil { + return 0, "", err + } + + // Check the wildcard entry's protocol against the discovery chain protocol for the service. + idx, serviceDefaults, err := s.configEntryTxn(tx, ws, structs.ServiceDefaults, svc.ID, &svc.EnterpriseMeta) + if err != nil { + return 0, "", err + } + maxIdx = lib.MaxUint64(maxIdx, idx) + + entries := structs.NewDiscoveryChainConfigEntries() + if proxyConfig != nil { + entries.AddEntries(proxyConfig) + } + if serviceDefaults != nil { + entries.AddEntries(serviceDefaults) + } + req := discoverychain.CompileRequest{ + ServiceName: svc.ID, + EvaluateInNamespace: svc.NamespaceOrDefault(), + EvaluateInDatacenter: "dc1", + // Use a dummy trust domain since that won't affect the protocol here. + EvaluateInTrustDomain: "b6fc9da3-03d4-4b5a-9134-c045e9b20152.consul", + UseInDatacenter: "dc1", + Entries: entries, + } + chain, err := discoverychain.Compile(req) + if err != nil { + return 0, "", err + } + return maxIdx, chain.Protocol, nil +} diff --git a/lib/math.go b/lib/math.go index 1d0b6dc0f..0cfc2ad28 100644 --- a/lib/math.go +++ b/lib/math.go @@ -20,3 +20,10 @@ func MinInt(a, b int) int { } return a } + +func MaxUint64(a, b uint64) uint64 { + if a > b { + return a + } + return b +} diff --git a/lib/math_test.go b/lib/math_test.go index 4640bf4f2..e0e52f123 100644 --- a/lib/math_test.go +++ b/lib/math_test.go @@ -36,3 +36,14 @@ func TestMathMinInt(t *testing.T) { } } } + +func TestMathMaxUint64(t *testing.T) { + tests := [][3]uint64{{1, 2, 2}, {0, 1, 1}, {2, 0, 2}} + for _, test := range tests { + expected := test[2] + actual := lib.MaxUint64(test[0], test[1]) + if expected != actual { + t.Fatalf("expected %d, got %d", expected, actual) + } + } +}