Forwarding

Alex Dadgar 2018-01-29 22:01:42 -08:00
parent c6827dc63d
commit 6c1fa878ea
6 changed files with 200 additions and 39 deletions

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"net"
"strings"
"time"
@ -40,6 +41,7 @@ func (f *FileSystem) handleStreamResultError(err error, code *int64, encoder *co
// Logs is used to stream an allocation's task logs from a client node.
func (f *FileSystem) Logs(conn io.ReadWriteCloser) {
defer conn.Close()
defer metrics.MeasureSince([]string{"nomad", "file_system", "logs"}, time.Now())
// Decode the arguments
var args cstructs.FsLogsRequest
@ -51,17 +53,50 @@ func (f *FileSystem) Logs(conn io.ReadWriteCloser) {
return
}
// TODO
// We only allow stale reads since the only potentially stale information is
// the Node registration and the cost is fairly high for adding another hop
// in the forwarding chain.
//args.QueryOptions.AllowStale = true
// Check if we need to forward to a different region
if r := args.RequestRegion(); r != f.srv.Region() {
// Request the allocation from the target region
allocReq := &structs.AllocSpecificRequest{
AllocID: args.AllocID,
QueryOptions: args.QueryOptions,
}
var allocResp structs.SingleAllocResponse
if err := f.srv.forwardRegion(r, "Alloc.GetAlloc", allocReq, &allocResp); err != nil {
f.handleStreamResultError(err, nil, encoder)
return
}
// Potentially forward to a different region.
//if done, err := f.srv.forward("FileSystem.Logs", args, args, reply); done {
//return err
//}
defer metrics.MeasureSince([]string{"nomad", "file_system", "logs"}, time.Now())
if allocResp.Alloc == nil {
f.handleStreamResultError(fmt.Errorf("unknown allocation %q", args.AllocID), nil, encoder)
return
}
// Determine the Server that has a connection to the node.
srv, err := f.srv.serverWithNodeConn(allocResp.Alloc.NodeID, r)
if err != nil {
f.handleStreamResultError(err, nil, encoder)
return
}
// Get a connection to the server
srvConn, err := f.srv.streamingRpc(srv, "FileSystem.Logs")
if err != nil {
f.handleStreamResultError(err, nil, encoder)
return
}
defer srvConn.Close()
// Send the request.
outEncoder := codec.NewEncoder(srvConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
f.handleStreamResultError(err, nil, encoder)
return
}
Bridge(conn, srvConn)
return
}
// Check node read permissions
if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil {
@ -100,35 +135,43 @@ func (f *FileSystem) Logs(conn io.ReadWriteCloser) {
}
nodeID := alloc.NodeID
// Get the connection to the client
// Get the connection to the client either by forwarding to another server
// or creating a direct stream
var clientConn net.Conn
state, ok := f.srv.getNodeConn(nodeID)
if !ok {
// Determine the Server that has a connection to the node.
//srv, err := f.srv.serverWithNodeConn(nodeID)
//if err != nil {
//f.handleStreamResultError(err, nil, encoder)
//return
//}
srv, err := f.srv.serverWithNodeConn(nodeID, f.srv.Region())
if err != nil {
f.handleStreamResultError(err, nil, encoder)
return
}
// TODO Forward streaming
//return s.srv.forwardServer(srv, "ClientStats.Stats", args, reply)
return
}
// Get a connection to the server
conn, err := f.srv.streamingRpc(srv, "FileSystem.Logs")
if err != nil {
f.handleStreamResultError(err, nil, encoder)
return
}
stream, err := NodeStreamingRpc(state.Session, "FileSystem.Logs")
if err != nil {
f.handleStreamResultError(err, nil, encoder)
return
clientConn = conn
} else {
stream, err := NodeStreamingRpc(state.Session, "FileSystem.Logs")
if err != nil {
f.handleStreamResultError(err, nil, encoder)
return
}
clientConn = stream
}
defer stream.Close()
defer clientConn.Close()
// Send the request.
outEncoder := codec.NewEncoder(stream, structs.MsgpackHandle)
outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
f.handleStreamResultError(err, nil, encoder)
return
}
Bridge(conn, stream)
Bridge(conn, clientConn)
return
}
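
Both code paths hand the user's connection and the upstream connection (either the remote server's streaming connection or the node's session stream) to Bridge. The helper itself is not part of this diff; the following is a minimal, hypothetical sketch of such a bidirectional bridge, assuming it simply copies bytes both ways and closes both ends once either side finishes. The real helper may differ in naming and error handling.

package example

import (
    "io"
    "sync"
)

// bridge is a stand-in for the Bridge helper used above: it splices two
// connections together by copying in both directions, and tears the pair
// down once either direction stops.
func bridge(a, b io.ReadWriteCloser) {
    var wg sync.WaitGroup
    wg.Add(2)

    copyAndClose := func(dst io.WriteCloser, src io.ReadCloser) {
        defer wg.Done()
        io.Copy(dst, src) // runs until src closes or errors
        dst.Close()
        src.Close()
    }

    go copyAndClose(a, b) // upstream -> user
    go copyAndClose(b, a) // user -> upstream
    wg.Wait()
}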

View File

@ -75,10 +75,7 @@ func (s *Server) removeNodeConn(ctx *RPCContext) {
// ErrNoNodeConn is returned if all peers in the target region could be queried
// but did not have a connection to the node. Otherwise, if a connection could
// not be found and there were RPC errors, an error is returned.
func (s *Server) serverWithNodeConn(nodeID string) (*serverParts, error) {
s.peerLock.RLock()
defer s.peerLock.RUnlock()
func (s *Server) serverWithNodeConn(nodeID, region string) (*serverParts, error) {
// We skip ourselves.
selfAddr := s.LocalMember().Addr.String()
@ -90,14 +87,38 @@ func (s *Server) serverWithNodeConn(nodeID string) (*serverParts, error) {
},
}
// Select the list of servers to check based on what region we are querying
s.peerLock.RLock()
var rawTargets []*serverParts
if region == s.Region() {
rawTargets = make([]*serverParts, 0, len(s.localPeers))
for _, srv := range s.localPeers {
rawTargets = append(rawTargets, srv)
}
} else {
peers, ok := s.peers[region]
if !ok {
s.peerLock.RUnlock()
return nil, structs.ErrNoRegionPath
}
rawTargets = peers
}
targets := make([]*serverParts, 0, len(rawTargets))
for _, target := range rawTargets {
targets = append(targets, target.Copy())
}
s.peerLock.RUnlock()
// connections is used to store the servers that have connections to the
// requested node.
var mostRecentServer *serverParts
var mostRecent time.Time
var rpcErr multierror.Error
for addr, server := range s.localPeers {
if string(addr) == selfAddr {
for _, server := range targets {
if server.Addr.String() == selfAddr {
continue
}
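
The hunk cuts off before the query loop finishes, but the mostRecentServer/mostRecent pair above (and the TestServerWithNodeConn_Path_Newest test later in this commit) give away the selection rule: each candidate server is asked whether it currently holds a connection to the node, and the candidate whose connection was established most recently wins. A self-contained sketch of that rule follows; the candidate type and its fields are placeholders for the real RPC response, not Nomad's actual API.

package example

import "time"

// candidate is a placeholder for a peer server together with its answer to
// "do you have a connection to this node, and since when?".
type candidate struct {
    name        string
    connected   bool
    established time.Time
}

// pickMostRecent returns the name of the candidate holding the freshest
// connection to the node, or "" if no candidate is connected.
func pickMostRecent(cands []candidate) string {
    var best string
    var bestTime time.Time
    for _, c := range cands {
        if c.connected && c.established.After(bestTime) {
            best = c.name
            bestTime = c.established
        }
    }
    return best
}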

View File

@ -23,11 +23,24 @@ func TestServerWithNodeConn_NoPath(t *testing.T) {
testutil.WaitForLeader(t, s2.RPC)
nodeID := uuid.Generate()
srv, err := s1.serverWithNodeConn(nodeID)
srv, err := s1.serverWithNodeConn(nodeID, s1.Region())
require.Nil(srv)
require.EqualError(err, structs.ErrNoNodeConn.Error())
}
func TestServerWithNodeConn_NoPath_Region(t *testing.T) {
t.Parallel()
require := require.New(t)
s1 := TestServer(t, nil)
defer s1.Shutdown()
testutil.WaitForLeader(t, s1.RPC)
nodeID := uuid.Generate()
srv, err := s1.serverWithNodeConn(nodeID, "fake-region")
require.Nil(srv)
require.EqualError(err, structs.ErrNoRegionPath.Error())
}
func TestServerWithNodeConn_Path(t *testing.T) {
t.Parallel()
require := require.New(t)
@ -47,7 +60,32 @@ func TestServerWithNodeConn_Path(t *testing.T) {
NodeID: nodeID,
})
srv, err := s1.serverWithNodeConn(nodeID)
srv, err := s1.serverWithNodeConn(nodeID, s1.Region())
require.NotNil(srv)
require.Equal(srv.Addr.String(), s2.config.RPCAddr.String())
require.Nil(err)
}
func TestServerWithNodeConn_Path_Region(t *testing.T) {
t.Parallel()
require := require.New(t)
s1 := TestServer(t, nil)
defer s1.Shutdown()
s2 := TestServer(t, func(c *Config) {
c.Region = "two"
})
defer s2.Shutdown()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
testutil.WaitForLeader(t, s2.RPC)
// Create a fake connection for the node on server 2
nodeID := uuid.Generate()
s2.addNodeConn(&RPCContext{
NodeID: nodeID,
})
srv, err := s1.serverWithNodeConn(nodeID, s2.Region())
require.NotNil(srv)
require.Equal(srv.Addr.String(), s2.config.RPCAddr.String())
require.Nil(err)
@ -80,7 +118,7 @@ func TestServerWithNodeConn_Path_Newest(t *testing.T) {
NodeID: nodeID,
})
srv, err := s1.serverWithNodeConn(nodeID)
srv, err := s1.serverWithNodeConn(nodeID, s1.Region())
require.NotNil(srv)
require.Equal(srv.Addr.String(), s3.config.RPCAddr.String())
require.Nil(err)
@ -113,7 +151,7 @@ func TestServerWithNodeConn_PathAndErr(t *testing.T) {
// Shutdown the RPC layer for server 3
s3.rpcListener.Close()
srv, err := s1.serverWithNodeConn(nodeID)
srv, err := s1.serverWithNodeConn(nodeID, s1.Region())
require.NotNil(srv)
require.Equal(srv.Addr.String(), s2.config.RPCAddr.String())
require.Nil(err)
@ -140,7 +178,7 @@ func TestServerWithNodeConn_NoPathAndErr(t *testing.T) {
// Shutdown the RPC layer for server 3
s3.rpcListener.Close()
srv, err := s1.serverWithNodeConn(uuid.Generate())
srv, err := s1.serverWithNodeConn(uuid.Generate(), s1.Region())
require.Nil(srv)
require.NotNil(err)
require.Contains(err.Error(), "failed querying")

View File

@ -59,7 +59,7 @@ func (s *ClientStats) Stats(args *structs.ClientStatsRequest, reply *structs.Cli
}
// Determine the Server that has a connection to the node.
srv, err := s.srv.serverWithNodeConn(args.NodeID)
srv, err := s.srv.serverWithNodeConn(args.NodeID, s.srv.Region())
if err != nil {
return err
}
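
ClientStats.Stats is a plain request/response RPC, so once a server with a connection to the node is found there is nothing to bridge: the whole call can be forwarded to that server. Judging by the forwardServer call that appears commented out earlier in this diff, the continuation of this path (not shown in the hunk) presumably looks like:

// Likely continuation: forward the RPC to the server that holds the
// node connection.
return s.srv.forwardServer(srv, "ClientStats.Stats", args, reply)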

View File

@ -391,6 +391,59 @@ func (s *Server) forwardRegion(region, method string, args interface{}, reply in
return s.connPool.RPC(region, server.Addr, server.MajorVersion, method, args, reply)
}
// streamingRpc creates a connection to the given server and conducts the
// initial handshake, returning the connection or an error. It is the caller's
// responsibility to close the connection if there is no returned error.
func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, error) {
// Try to dial the server
conn, err := net.DialTimeout("tcp", server.Addr.String(), 10*time.Second)
if err != nil {
return nil, err
}
// Cast to TCPConn
if tcp, ok := conn.(*net.TCPConn); ok {
tcp.SetKeepAlive(true)
tcp.SetNoDelay(true)
}
// TODO TLS
// Check if TLS is enabled
//if p.tlsWrap != nil {
//// Switch the connection into TLS mode
//if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
//conn.Close()
//return nil, err
//}
//// Wrap the connection in a TLS client
//tlsConn, err := p.tlsWrap(region, conn)
//if err != nil {
//conn.Close()
//return nil, err
//}
//conn = tlsConn
//}
// Write the multiplex byte to set the mode
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
conn.Close()
return nil, err
}
// Send the header
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
header := structs.StreamingRpcHeader{
Method: method,
}
if err := encoder.Encode(header); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
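
streamingRpc is the dialing half of a small handshake: write one mode byte (pool.RpcStreaming) so the listener knows the connection will carry a streaming RPC, then send a msgpack-encoded StreamingRpcHeader naming the method. The accepting side is not part of this diff; the sketch below shows what a receiver could look like, assuming the mode byte has already been consumed by the connection demultiplexer and that handlers are kept in a plain map keyed by method name. The map and function names are illustrative, not Nomad's actual dispatch code.

// handleStreamingConn is a hypothetical receiver for connections opened by
// streamingRpc: decode the header, look up the named method, and hand the
// connection over to its handler.
func handleStreamingConn(conn net.Conn, handlers map[string]func(io.ReadWriteCloser)) {
    decoder := codec.NewDecoder(conn, structs.MsgpackHandle)

    var header structs.StreamingRpcHeader
    if err := decoder.Decode(&header); err != nil {
        conn.Close()
        return
    }

    handler, ok := handlers[header.Method]
    if !ok {
        conn.Close()
        return
    }

    // The handler owns the connection from here on; FileSystem.Logs, for
    // example, decodes its arguments from it and may bridge it onward.
    handler(conn)
}
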
// raftApplyFuture is used to encode a message, run it through raft, and return the Raft future.
func (s *Server) raftApplyFuture(t structs.MessageType, msg interface{}) (raft.ApplyFuture, error) {
buf, err := structs.Encode(t, msg)

View File

@ -43,6 +43,12 @@ func (s *serverParts) String() string {
s.Name, s.Addr, s.Datacenter)
}
func (s *serverParts) Copy() *serverParts {
ns := new(serverParts)
*ns = *s
return ns
}
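
Copy exists so serverWithNodeConn can snapshot the peer entries and release peerLock before issuing any RPCs; the snapshot stays usable even if the peer maps change underneath it. A small, self-contained illustration of that copy-under-read-lock pattern (all names here are placeholders):

package example

import "sync"

type part struct{ name string }

type registry struct {
    mu    sync.RWMutex
    peers map[string]*part
}

// snapshot copies the current entries while holding the read lock, so the
// caller can work with them (e.g. make RPCs) after the lock is released
// without racing concurrent updates to the shared map.
func (r *registry) snapshot() []*part {
    r.mu.RLock()
    defer r.mu.RUnlock()

    out := make([]*part, 0, len(r.peers))
    for _, p := range r.peers {
        cp := *p // copy the value, not the shared pointer
        out = append(out, &cp)
    }
    return out
}
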
// isNomadServer returns whether a member is a Nomad server, along with a
// struct containing its important components.
func isNomadServer(m serf.Member) (bool, *serverParts) {