328 lines
8.6 KiB
Go
328 lines
8.6 KiB
Go
package balancer
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"net/url"
|
|
"sort"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/resolver"
|
|
"google.golang.org/grpc/resolver/manual"
|
|
"google.golang.org/grpc/stats"
|
|
"google.golang.org/grpc/status"
|
|
|
|
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
|
"github.com/hashicorp/consul/sdk/testutil"
|
|
"github.com/hashicorp/consul/sdk/testutil/retry"
|
|
)
|
|
|
|
// TestBalancer exercises the balancer end-to-end against two real in-process
// gRPC servers. Each subtest builds and registers its own balancer (named after
// the subtest via t.Name()) and its own stub resolver, then issues requests
// through a client connection that uses that balancer.
func TestBalancer(t *testing.T) {
	t.Run("remains pinned to the same server", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, _ := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
		balancerBuilder.Register()

		conn := dial(t, target, balancerBuilder)
		client := testservice.NewSimpleClient(conn)

		// Every request must be answered by the server that answered the first.
		var serverName string
		for i := 0; i < 5; i++ {
			rsp, err := client.Something(ctx, &testservice.Req{})
			require.NoError(t, err)

			if i == 0 {
				serverName = rsp.ServerName
			} else {
				require.Equal(t, serverName, rsp.ServerName)
			}
		}

		pinnedServer, otherServer := splitServers(t, serverName, server1, server2)

		require.Equal(t, 1,
			pinnedServer.openConnections(),
			"pinned server should have 1 connection",
		)
		require.Zero(t,
			otherServer.openConnections(),
			"other server should have no connections",
		)
	})

	t.Run("switches server on-error", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, _ := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
		balancerBuilder.Register()

		conn := dial(t, target, balancerBuilder)
		client := testservice.NewSimpleClient(conn)

		// Figure out which server we're talking to now, and which we should switch to.
		rsp, err := client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)

		initialServer, otherServer := splitServers(t, rsp.ServerName, server1, server2)

		// Next request should fail (we don't have retries configured).
		initialServer.err = status.Error(codes.ResourceExhausted, "rate limit exceeded")
		_, err = client.Something(ctx, &testservice.Req{})
		require.Error(t, err)

		// Following request should succeed (against the other server).
		rsp, err = client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)
		require.Equal(t, otherServer.name, rsp.ServerName)

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				initialServer.openConnections(),
				"connection to previous server should have been torn down",
			)
		})
	})

	t.Run("rebalance changes the server", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, _ := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
		balancerBuilder.Register()

		// Provide a custom prioritizer that causes Rebalance to choose whichever
		// server didn't get our first request.
		var otherServer *server
		balancerBuilder.shuffler = func(addrs []resolver.Address) {
			sort.Slice(addrs, func(a, b int) bool {
				// otherServer is only known after the first request; don't
				// dereference it if the shuffler happens to run earlier.
				return otherServer != nil && addrs[a].Addr == otherServer.addr
			})
		}

		conn := dial(t, target, balancerBuilder)
		client := testservice.NewSimpleClient(conn)

		// Figure out which server we're talking to now.
		rsp, err := client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)

		// Assign (not declare) otherServer so the shuffler closure above sees it.
		var initialServer *server
		initialServer, otherServer = splitServers(t, rsp.ServerName, server1, server2)

		// Trigger a rebalance.
		targetURL, err := url.Parse(target)
		require.NoError(t, err)
		balancerBuilder.Rebalance(resolver.Target{URL: *targetURL})

		// Following request should hit the other server.
		rsp, err = client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)
		require.Equal(t, otherServer.name, rsp.ServerName)

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				initialServer.openConnections(),
				"connection to previous server should have been torn down",
			)
		})
	})

	t.Run("resolver removes the server", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		server1 := runServer(t, "server1")
		server2 := runServer(t, "server2")

		target, res := stubResolver(t, server1, server2)

		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
		balancerBuilder.Register()

		conn := dial(t, target, balancerBuilder)
		client := testservice.NewSimpleClient(conn)

		// Figure out which server we're talking to now.
		rsp, err := client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)

		initialServer, otherServer := splitServers(t, rsp.ServerName, server1, server2)

		// Remove the server's address.
		res.UpdateState(resolver.State{
			Addresses: []resolver.Address{
				{Addr: otherServer.addr},
			},
		})

		// Following request should hit the other server.
		rsp, err = client.Something(ctx, &testservice.Req{})
		require.NoError(t, err)
		require.Equal(t, otherServer.name, rsp.ServerName)

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				initialServer.openConnections(),
				"connection to previous server should have been torn down",
			)
		})

		// Remove the other server too.
		res.UpdateState(resolver.State{
			Addresses: []resolver.Address{},
		})

		_, err = client.Something(ctx, &testservice.Req{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "resolver produced no addresses")

		retry.Run(t, func(r *retry.R) {
			require.Zero(r,
				otherServer.openConnections(),
				"connection to other server should have been torn down",
			)
		})
	})
}

// splitServers returns the server whose name equals serverName, followed by the
// remaining one. It fails the test immediately if neither server matches. The
// alternative would be to return nil pointers that panic later with a far less
// useful message.
func splitServers(t *testing.T, serverName string, a, b *server) (match, other *server) {
	t.Helper()

	switch serverName {
	case a.name:
		return a, b
	case b.name:
		return b, a
	default:
		t.Fatalf("response came from unexpected server %q (want %q or %q)", serverName, a.name, b.name)
		return nil, nil // unreachable: Fatalf stops the test goroutine
	}
}
|
|
|
|
// stubResolver creates a manual resolver under a freshly generated scheme. Its
// initial state lists the address of every given server, and it is registered
// globally with gRPC; the registration is removed during test cleanup. It
// returns a dial target for that scheme along with the resolver, so callers can
// push later address updates via UpdateState.
func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
	t.Helper()

	// Generate a distinct scheme per call so resolvers registered by different
	// tests don't share a registry key.
	scheme := fmt.Sprintf("consul-%d-%d", time.Now().UnixNano(), rand.Int())

	addrs := make([]resolver.Address, 0, len(servers))
	for _, srv := range servers {
		addrs = append(addrs, resolver.Address{Addr: srv.addr})
	}

	rb := manual.NewBuilderWithScheme(scheme)
	rb.InitialState(resolver.State{Addresses: addrs})

	resolver.Register(rb)
	t.Cleanup(func() {
		resolver.UnregisterForTesting(scheme)
	})

	return fmt.Sprintf("%s://", scheme), rb
}
|
|
|
|
// runServer starts a gRPC server on an ephemeral loopback port. The server
// serves the testservice Simple API under the given name and uses the returned
// *server as its stats handler, so open connections are counted. The gRPC
// server is stopped during test cleanup. The returned shutdown func is guarded
// by a sync.Once, so calling it more than once is harmless.
func runServer(t *testing.T, name string) *server {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	srv := &server{name: name, addr: listener.Addr().String()}

	grpcServer := grpc.NewServer(grpc.StatsHandler(srv))
	testservice.RegisterSimpleServer(grpcServer, srv)

	// Serve blocks until the server stops; its return value is not inspected.
	go grpcServer.Serve(listener)

	var stopOnce sync.Once
	srv.shutdown = func() {
		stopOnce.Do(grpcServer.Stop)
	}
	t.Cleanup(srv.shutdown)

	return srv
}
|
|
|
|
type server struct {
|
|
name string
|
|
addr string
|
|
err error
|
|
|
|
c int32
|
|
shutdown func()
|
|
}
|
|
|
|
func (s *server) openConnections() int { return int(atomic.LoadInt32(&s.c)) }
|
|
|
|
// HandleRPC discards per-RPC statistics; only connection events are tracked.
func (s *server) HandleRPC(_ context.Context, _ stats.RPCStats) {}

// TagConn attaches nothing and hands back ctx as-is.
func (s *server) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
	return ctx
}

// TagRPC attaches nothing and hands back ctx as-is.
func (s *server) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
	return ctx
}
|
|
|
|
// HandleConn keeps the open-connection counter in step with connection events.
// It adds one on *stats.ConnBegin, subtracts one on *stats.ConnEnd, and ignores
// any other event.
func (s *server) HandleConn(_ context.Context, cs stats.ConnStats) {
	var delta int32
	switch cs.(type) {
	case *stats.ConnBegin:
		delta = 1
	case *stats.ConnEnd:
		delta = -1
	default:
		return
	}
	atomic.AddInt32(&s.c, delta)
}
|
|
|
|
// Flow is a no-op: it sends nothing and reports success.
func (s *server) Flow(_ *testservice.Req, _ testservice.Simple_FlowServer) error {
	return nil
}
|
|
|
|
// Something fails with s.err when one has been injected; otherwise it replies
// with the server's name so callers can tell which backend handled the call.
func (s *server) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) {
	if injected := s.err; injected != nil {
		return nil, injected
	}
	resp := &testservice.Resp{ServerName: s.name}
	return resp, nil
}
|
|
|
|
// dial opens an insecure client connection to target. The default service
// config selects the balancer registered by builder (via builder.Name()). The
// connection is closed during test cleanup, and any close error is logged
// rather than failing the test.
func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn {
	t.Helper()

	conn, err := grpc.Dial(
		target,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(
			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()),
		),
	)
	// Check the error before registering cleanup. grpc.Dial returns a nil conn
	// on failure, and calling Close on it from a cleanup would panic and mask
	// the real error.
	require.NoError(t, err)

	t.Cleanup(func() {
		if err := conn.Close(); err != nil {
			t.Logf("error closing connection: %v", err)
		}
	})
	return conn
}
|