agent: make the RPC endpoint overwrite mechanism more transparent

This patch hides the RPC handler overwrite mechanism from the
rest of the code so that it works in all cases and that there
is no cooperation required from the tested code, i.e. we can
drop a.getEndpoint().
This commit is contained in:
Frank Schroeder 2017-06-16 09:54:09 +02:00 committed by Frank Schröder
parent 27adc31672
commit d3ab99244b
8 changed files with 68 additions and 83 deletions

View File

@ -174,7 +174,7 @@ func (m *aclManager) lookupACL(a *Agent, id string) (acl.ACL, error) {
args.ETag = cached.ETag args.ETag = cached.ETag
} }
var reply structs.ACLPolicy var reply structs.ACLPolicy
err := a.RPC(a.getEndpoint("ACL")+".GetPolicy", &args, &reply) err := a.RPC("ACL.GetPolicy", &args, &reply)
if err != nil { if err != nil {
if strings.Contains(err.Error(), aclDisabled) { if strings.Contains(err.Error(), aclDisabled) {
a.logger.Printf("[DEBUG] agent: ACLs disabled on servers, will check again after %s", a.config.ACLDisabledTTL) a.logger.Printf("[DEBUG] agent: ACLs disabled on servers, will check again after %s", a.config.ACLDisabledTTL)

View File

@ -47,7 +47,7 @@ func TestACL_Version8(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{} m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -70,7 +70,7 @@ func TestACL_Disabled(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{} m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -123,7 +123,7 @@ func TestACL_Special_IDs(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{} m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -176,7 +176,7 @@ func TestACL_Down_Deny(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{} m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -206,7 +206,7 @@ func TestACL_Down_Allow(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{} m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -236,7 +236,7 @@ func TestACL_Down_Extend(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{} m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -313,7 +313,7 @@ func TestACL_Cache(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{} m := MockServer{}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -495,7 +495,7 @@ func TestACL_vetServiceRegister(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{catalogPolicy} m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -541,7 +541,7 @@ func TestACL_vetServiceUpdate(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{catalogPolicy} m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -577,7 +577,7 @@ func TestACL_vetCheckRegister(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{catalogPolicy} m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -660,7 +660,7 @@ func TestACL_vetCheckUpdate(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{catalogPolicy} m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -716,7 +716,7 @@ func TestACL_filterMembers(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{catalogPolicy} m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -752,7 +752,7 @@ func TestACL_filterServices(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{catalogPolicy} m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -783,7 +783,7 @@ func TestACL_filterChecks(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockServer{catalogPolicy} m := MockServer{catalogPolicy}
if err := a.InjectEndpoint("ACL", &m); err != nil { if err := a.RegisterEndpoint("ACL", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -14,7 +14,6 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -146,9 +145,9 @@ type Agent struct {
// attempts. // attempts.
retryJoinCh chan error retryJoinCh chan error
// endpoints lets you override RPC endpoints for testing. Not all // endpoints maps unique RPC endpoint names to common ones
// agent methods use this, so use with care and never override // to allow overriding of RPC handlers since the golang
// outside of a unit test. // net/rpc server does not allow this.
endpoints map[string]string endpoints map[string]string
endpointsLock sync.RWMutex endpointsLock sync.RWMutex
@ -1068,9 +1067,34 @@ LOAD:
return nil return nil
} }
// RegisterEndpoint registers a handler for the consul RPC server
// under a unique name while making it accessible under the provided
// name. This allows overwriting handlers for the golang net/rpc
// service which does not allow this.
func (a *Agent) RegisterEndpoint(name string, handler interface{}) error {
srv, ok := a.delegate.(*consul.Server)
if !ok {
panic("agent must be a server")
}
realname := fmt.Sprintf("%s-%d", name, time.Now().UnixNano())
a.endpointsLock.Lock()
a.endpoints[name] = realname
a.endpointsLock.Unlock()
return srv.RegisterEndpoint(realname, handler)
}
// RPC is used to make an RPC call to the Consul servers // RPC is used to make an RPC call to the Consul servers
// This allows the agent to implement the Consul.Interface // This allows the agent to implement the Consul.Interface
func (a *Agent) RPC(method string, args interface{}, reply interface{}) error { func (a *Agent) RPC(method string, args interface{}, reply interface{}) error {
a.endpointsLock.Lock()
// fast path: only translate if there are overrides
if len(a.endpoints) > 0 {
p := strings.SplitN(method, ".", 2)
if e := a.endpoints[p[0]]; e != "" {
method = e + "." + p[1]
}
}
a.endpointsLock.Unlock()
return a.delegate.RPC(method, args, reply) return a.delegate.RPC(method, args, reply)
} }
@ -2255,37 +2279,6 @@ func (a *Agent) DisableNodeMaintenance() {
a.logger.Printf("[INFO] agent: Node left maintenance mode") a.logger.Printf("[INFO] agent: Node left maintenance mode")
} }
// InjectEndpoint overrides the given endpoint with a substitute one. Note
// that not all agent methods use this mechanism, and that is should only
// be used for testing.
func (a *Agent) InjectEndpoint(endpoint string, handler interface{}) error {
srv, ok := a.delegate.(*consul.Server)
if !ok {
return fmt.Errorf("agent must be a server")
}
if err := srv.InjectEndpoint(handler); err != nil {
return err
}
name := reflect.Indirect(reflect.ValueOf(handler)).Type().Name()
a.endpointsLock.Lock()
a.endpoints[endpoint] = name
a.endpointsLock.Unlock()
a.logger.Printf("[WARN] agent: endpoint injected; this should only be used for testing")
return nil
}
// getEndpoint returns the endpoint name to use for the given endpoint,
// which may be overridden.
func (a *Agent) getEndpoint(endpoint string) string {
a.endpointsLock.RLock()
defer a.endpointsLock.RUnlock()
if override, ok := a.endpoints[endpoint]; ok {
return override
}
return endpoint
}
func (a *Agent) ReloadConfig(newCfg *Config) (bool, error) { func (a *Agent) ReloadConfig(newCfg *Config) (bool, error) {
var errs error var errs error

View File

@ -977,10 +977,10 @@ func (s *Server) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
return nil return nil
} }
// InjectEndpoint is used to substitute an endpoint for testing. // RegisterEndpoint is used to substitute an endpoint for testing.
func (s *Server) InjectEndpoint(endpoint interface{}) error { func (s *Server) RegisterEndpoint(name string, handler interface{}) error {
s.logger.Printf("[WARN] consul: endpoint injected; this should only be used for testing") s.logger.Printf("[WARN] consul: endpoint injected; this should only be used for testing")
return s.rpcServer.Register(endpoint) return s.rpcServer.RegisterName(name, handler)
} }
// Stats is used to return statistics for debugging and insight // Stats is used to return statistics for debugging and insight

View File

@ -695,10 +695,9 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req,
// likely work in practice, like 10*maxUDPAnswerLimit which should help // likely work in practice, like 10*maxUDPAnswerLimit which should help
// reduce bandwidth if there are thousands of nodes available. // reduce bandwidth if there are thousands of nodes available.
endpoint := d.agent.getEndpoint(preparedQueryEndpoint)
var out structs.PreparedQueryExecuteResponse var out structs.PreparedQueryExecuteResponse
RPC: RPC:
if err := d.agent.RPC(endpoint+".Execute", &args, &out); err != nil { if err := d.agent.RPC("PreparedQuery.Execute", &args, &out); err != nil {
// If they give a bogus query name, treat that as a name error, // If they give a bogus query name, treat that as a name error,
// not a full on server error. We have to use a string compare // not a full on server error. We have to use a string compare
// here since the RPC layer loses the type information. // here since the RPC layer loses the type information.

View File

@ -3932,7 +3932,7 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -4013,7 +4013,7 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -37,8 +37,7 @@ func (s *HTTPServer) preparedQueryCreate(resp http.ResponseWriter, req *http.Req
} }
var reply string var reply string
endpoint := s.agent.getEndpoint(preparedQueryEndpoint) if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
return nil, err return nil, err
} }
return preparedQueryCreateResponse{reply}, nil return preparedQueryCreateResponse{reply}, nil
@ -52,8 +51,7 @@ func (s *HTTPServer) preparedQueryList(resp http.ResponseWriter, req *http.Reque
} }
var reply structs.IndexedPreparedQueries var reply structs.IndexedPreparedQueries
endpoint := s.agent.getEndpoint(preparedQueryEndpoint) if err := s.agent.RPC("PreparedQuery.List", &args, &reply); err != nil {
if err := s.agent.RPC(endpoint+".List", &args, &reply); err != nil {
return nil, err return nil, err
} }
@ -110,8 +108,7 @@ func (s *HTTPServer) preparedQueryExecute(id string, resp http.ResponseWriter, r
} }
var reply structs.PreparedQueryExecuteResponse var reply structs.PreparedQueryExecuteResponse
endpoint := s.agent.getEndpoint(preparedQueryEndpoint) if err := s.agent.RPC("PreparedQuery.Execute", &args, &reply); err != nil {
if err := s.agent.RPC(endpoint+".Execute", &args, &reply); err != nil {
// We have to check the string since the RPC sheds // We have to check the string since the RPC sheds
// the specific error type. // the specific error type.
if err.Error() == consul.ErrQueryNotFound.Error() { if err.Error() == consul.ErrQueryNotFound.Error() {
@ -155,8 +152,7 @@ func (s *HTTPServer) preparedQueryExplain(id string, resp http.ResponseWriter, r
} }
var reply structs.PreparedQueryExplainResponse var reply structs.PreparedQueryExplainResponse
endpoint := s.agent.getEndpoint(preparedQueryEndpoint) if err := s.agent.RPC("PreparedQuery.Explain", &args, &reply); err != nil {
if err := s.agent.RPC(endpoint+".Explain", &args, &reply); err != nil {
// We have to check the string since the RPC sheds // We have to check the string since the RPC sheds
// the specific error type. // the specific error type.
if err.Error() == consul.ErrQueryNotFound.Error() { if err.Error() == consul.ErrQueryNotFound.Error() {
@ -179,8 +175,7 @@ func (s *HTTPServer) preparedQueryGet(id string, resp http.ResponseWriter, req *
} }
var reply structs.IndexedPreparedQueries var reply structs.IndexedPreparedQueries
endpoint := s.agent.getEndpoint(preparedQueryEndpoint) if err := s.agent.RPC("PreparedQuery.Get", &args, &reply); err != nil {
if err := s.agent.RPC(endpoint+".Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds // We have to check the string since the RPC sheds
// the specific error type. // the specific error type.
if err.Error() == consul.ErrQueryNotFound.Error() { if err.Error() == consul.ErrQueryNotFound.Error() {
@ -212,8 +207,7 @@ func (s *HTTPServer) preparedQueryUpdate(id string, resp http.ResponseWriter, re
args.Query.ID = id args.Query.ID = id
var reply string var reply string
endpoint := s.agent.getEndpoint(preparedQueryEndpoint) if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
return nil, err return nil, err
} }
return nil, nil return nil, nil
@ -231,8 +225,7 @@ func (s *HTTPServer) preparedQueryDelete(id string, resp http.ResponseWriter, re
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)
var reply string var reply string
endpoint := s.agent.getEndpoint(preparedQueryEndpoint) if err := s.agent.RPC("PreparedQuery.Apply", &args, &reply); err != nil {
if err := s.agent.RPC(endpoint+".Apply", &args, &reply); err != nil {
return nil, err return nil, err
} }
return nil, nil return nil, nil

View File

@ -74,7 +74,7 @@ func TestPreparedQuery_Create(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -159,7 +159,7 @@ func TestPreparedQuery_List(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -192,7 +192,7 @@ func TestPreparedQuery_List(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -242,7 +242,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -275,7 +275,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -331,7 +331,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -365,7 +365,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -415,7 +415,7 @@ func TestPreparedQuery_Execute(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -479,7 +479,7 @@ func TestPreparedQuery_Explain(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -552,7 +552,7 @@ func TestPreparedQuery_Get(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -617,7 +617,7 @@ func TestPreparedQuery_Update(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -695,7 +695,7 @@ func TestPreparedQuery_Delete(t *testing.T) {
defer a.Shutdown() defer a.Shutdown()
m := MockPreparedQuery{} m := MockPreparedQuery{}
if err := a.InjectEndpoint("PreparedQuery", &m); err != nil { if err := a.RegisterEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }