diff --git a/agent/consul/prepared_query_endpoint.go b/agent/consul/prepared_query_endpoint.go index 1c006ca1d..8873d4aad 100644 --- a/agent/consul/prepared_query_endpoint.go +++ b/agent/consul/prepared_query_endpoint.go @@ -354,7 +354,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest, } // Execute the query for the local DC. - if err := p.execute(query, reply); err != nil { + if err := p.execute(query, reply, args.Connect); err != nil { return err } @@ -450,7 +450,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest, // by the query setup. if len(reply.Nodes) == 0 { wrapper := &queryServerWrapper{p.srv} - if err := queryFailover(wrapper, query, args.Limit, args.QueryOptions, reply); err != nil { + if err := queryFailover(wrapper, query, args, reply); err != nil { return err } } @@ -479,7 +479,7 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe } // Run the query locally to see what we can find. - if err := p.execute(&args.Query, reply); err != nil { + if err := p.execute(&args.Query, reply, args.Connect); err != nil { return err } @@ -509,13 +509,14 @@ func (p *PreparedQuery) ExecuteRemote(args *structs.PreparedQueryExecuteRemoteRe // execute runs a prepared query in the local DC without any failover. We don't // apply any sorting options or ACL checks at this level - it should be done up above. func (p *PreparedQuery) execute(query *structs.PreparedQuery, - reply *structs.PreparedQueryExecuteResponse) error { + reply *structs.PreparedQueryExecuteResponse, + forceConnect bool) error { state := p.srv.fsm.State() // If we're requesting Connect-capable services, then switch the // lookup to be the Connect function. f := state.CheckServiceNodes - if query.Service.Connect { + if query.Service.Connect || forceConnect { f = state.CheckConnectServiceNodes } @@ -659,7 +660,7 @@ func (q *queryServerWrapper) ForwardDC(method, dc string, args interface{}, repl // queryFailover runs an algorithm to determine which DCs to try and then calls // them to try to locate alternative services. func queryFailover(q queryServer, query *structs.PreparedQuery, - limit int, options structs.QueryOptions, + args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error { // Pull the list of other DCs. This is sorted by RTT in case the user @@ -727,8 +728,9 @@ func queryFailover(q queryServer, query *structs.PreparedQuery, remote := &structs.PreparedQueryExecuteRemoteRequest{ Datacenter: dc, Query: *query, - Limit: limit, - QueryOptions: options, + Limit: args.Limit, + QueryOptions: args.QueryOptions, + Connect: args.Connect, } if err := q.ForwardDC("PreparedQuery.ExecuteRemote", dc, remote, reply); err != nil { q.GetLogger().Printf("[WARN] consul.prepared_query: Failed querying for service '%s' in datacenter '%s': %s", query.Service.Service, dc, err) diff --git a/agent/consul/prepared_query_endpoint_test.go b/agent/consul/prepared_query_endpoint_test.go index 5e16eff0f..7c97962c8 100644 --- a/agent/consul/prepared_query_endpoint_test.go +++ b/agent/consul/prepared_query_endpoint_test.go @@ -2699,6 +2699,37 @@ func TestPreparedQuery_Execute_ConnectExact(t *testing.T) { require.True(reply.QueryMeta.KnownLeader, "queried leader") } + // Run with the Connect setting specified on the request + { + req := structs.PreparedQueryExecuteRequest{ + Datacenter: "dc1", + QueryIDOrName: query.Query.ID, + Connect: true, + } + + var reply structs.PreparedQueryExecuteResponse + require.NoError(msgpackrpc.CallWithCodec( + codec, "PreparedQuery.Execute", &req, &reply)) + + // Result should have two because we should get the native AND + // the proxy (since the destination matches our service name). + require.Len(reply.Nodes, 2) + require.Equal(query.Query.Service.Service, reply.Service) + require.Equal(query.Query.DNS, reply.DNS) + require.True(reply.QueryMeta.KnownLeader, "queried leader") + + // Make sure the native is the first one + if !reply.Nodes[0].Service.Connect.Native { + reply.Nodes[0], reply.Nodes[1] = reply.Nodes[1], reply.Nodes[0] + } + + require.True(reply.Nodes[0].Service.Connect.Native, "native") + require.Equal(reply.Service, reply.Nodes[0].Service.Service) + + require.Equal(structs.ServiceKindConnectProxy, reply.Nodes[1].Service.Kind) + require.Equal(reply.Service, reply.Nodes[1].Service.ProxyDestination) + } + // Update the query query.Query.Service.Connect = true require.NoError(msgpackrpc.CallWithCodec( @@ -2943,7 +2974,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 { @@ -2959,7 +2990,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply) + err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply) if err == nil || !strings.Contains(err.Error(), "XXX") { t.Fatalf("bad: %v", err) } @@ -2976,7 +3007,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 0 || reply.Datacenter != "" || reply.Failovers != 0 { @@ -2999,7 +3030,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 3 || @@ -3027,7 +3058,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 3 || @@ -3048,7 +3079,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 0 || @@ -3077,7 +3108,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 3 || @@ -3106,7 +3137,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 3 || @@ -3135,7 +3166,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 3 || @@ -3170,7 +3201,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 3 || @@ -3202,7 +3233,7 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 0, structs.QueryOptions{}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{}, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 3 || @@ -3238,7 +3269,10 @@ func TestPreparedQuery_queryFailover(t *testing.T) { } var reply structs.PreparedQueryExecuteResponse - if err := queryFailover(mock, query, 5, structs.QueryOptions{RequireConsistent: true}, &reply); err != nil { + if err := queryFailover(mock, query, &structs.PreparedQueryExecuteRequest{ + Limit: 5, + QueryOptions: structs.QueryOptions{RequireConsistent: true}, + }, &reply); err != nil { t.Fatalf("err: %v", err) } if len(reply.Nodes) != 3 || diff --git a/agent/structs/prepared_query.go b/agent/structs/prepared_query.go index 842a9b716..bad1b1927 100644 --- a/agent/structs/prepared_query.go +++ b/agent/structs/prepared_query.go @@ -203,6 +203,12 @@ type PreparedQueryExecuteRequest struct { // Limit will trim the resulting list down to the given limit. Limit int + // Connect will force results to be Connect-enabled nodes for the + // matching services. This is equivalent in semantics exactly to + // setting "Connect" in the query template itself, but allows callers + // to use any prepared query in a Connect setting. + Connect bool + // Source is used to sort the results relative to a given node using // network coordinates. Source QuerySource @@ -234,6 +240,9 @@ type PreparedQueryExecuteRemoteRequest struct { // Limit will trim the resulting list down to the given limit. Limit int + // Connect is the same as ExecuteRequest. + Connect bool + // QueryOptions (unfortunately named here) controls the consistency // settings for the the service lookups. QueryOptions