Refactor keyring ops:

* changes some functions to return data instead of modifying pointer
  arguments
* renames globalRPC() to keyringRPCs() to make its purpose more clear
* restructures KeyringOperation() to make it more understandable
This commit is contained in:
Hans Hasselberg 2020-08-11 13:35:48 +02:00
parent 08b1fea379
commit e0297b6e99
4 changed files with 89 additions and 67 deletions

View File

@ -250,6 +250,11 @@ func (m *Internal) KeyringOperation(
args *structs.KeyringRequest, args *structs.KeyringRequest,
reply *structs.KeyringResponses) error { reply *structs.KeyringResponses) error {
// Error aggressively to be clear about LocalOnly behavior
if args.LocalOnly && args.Operation != structs.KeyringList {
return fmt.Errorf("argument error: LocalOnly can only be used for List operations")
}
// Check ACLs // Check ACLs
identity, rule, err := m.srv.ResolveTokenToIdentityAndAuthorizer(args.Token) identity, rule, err := m.srv.ResolveTokenToIdentityAndAuthorizer(args.Token)
if err != nil { if err != nil {
@ -277,44 +282,63 @@ func (m *Internal) KeyringOperation(
} }
} }
// Validate use of local-only if args.LocalOnly || args.Forwarded || m.srv.serfWAN == nil {
if args.LocalOnly && args.Operation != structs.KeyringList { // Handle operations that are localOnly, already forwarded or
// Error aggressively to be clear about LocalOnly behavior // there is no serfWAN. If any of this is the case this
return fmt.Errorf("argument error: LocalOnly can only be used for List operations") // operation shouldn't go out to other dcs or WAN pool.
} reply.Responses = append(reply.Responses, m.executeKeyringOpLAN(args)...)
} else {
// Handle not already forwarded, non-local operations.
// args.LocalOnly should always be false for non-GET requests // Marking this as forwarded because this is what we are about
if !args.LocalOnly { // to do. Prevents the same message from being fowarded by
// Only perform WAN keyring querying and RPC forwarding once // other servers.
if !args.Forwarded && m.srv.serfWAN != nil { args.Forwarded = true
args.Forwarded = true reply.Responses = append(reply.Responses, m.executeKeyringOpWAN(args))
m.executeKeyringOp(args, reply, true) reply.Responses = append(reply.Responses, m.executeKeyringOpLAN(args)...)
return m.srv.globalRPC("Internal.KeyringOperation", args, reply)
dcs := m.srv.router.GetRemoteDatacenters(m.srv.config.Datacenter)
responses, err := m.srv.keyringRPCs("Internal.KeyringOperation", args, dcs)
if err != nil {
return err
} }
reply.Add(responses)
} }
// Query the LAN keyring of this node's DC
m.executeKeyringOp(args, reply, false)
return nil return nil
} }
// executeKeyringOp executes the keyring-related operation in the request func (m *Internal) executeKeyringOpLAN(args *structs.KeyringRequest) []*structs.KeyringResponse {
// on either the WAN or LAN pools. responses := []*structs.KeyringResponse{}
func (m *Internal) executeKeyringOp( segments := m.srv.LANSegments()
args *structs.KeyringRequest, for name, segment := range segments {
reply *structs.KeyringResponses, mgr := segment.KeyManager()
wan bool) { serfResp, err := m.executeKeyringOpMgr(mgr, args)
resp := translateKeyResponseToKeyringResponse(serfResp, m.srv.config.Datacenter, err)
if wan { resp.Segment = name
mgr := m.srv.KeyManagerWAN() responses = append(responses, &resp)
m.executeKeyringOpMgr(mgr, args, reply, wan, "")
} else {
segments := m.srv.LANSegments()
for name, segment := range segments {
mgr := segment.KeyManager()
m.executeKeyringOpMgr(mgr, args, reply, wan, name)
}
} }
return responses
}
func (m *Internal) executeKeyringOpWAN(args *structs.KeyringRequest) *structs.KeyringResponse {
mgr := m.srv.KeyManagerWAN()
serfResp, err := m.executeKeyringOpMgr(mgr, args)
resp := translateKeyResponseToKeyringResponse(serfResp, m.srv.config.Datacenter, err)
resp.WAN = true
return &resp
}
func translateKeyResponseToKeyringResponse(keyresponse *serf.KeyResponse, datacenter string, err error) structs.KeyringResponse {
resp := structs.KeyringResponse{
Datacenter: datacenter,
Messages: keyresponse.Messages,
Keys: keyresponse.Keys,
NumNodes: keyresponse.NumNodes,
}
if err != nil {
resp.Error = err.Error()
}
return resp
} }
// executeKeyringOpMgr executes the appropriate keyring-related function based on // executeKeyringOpMgr executes the appropriate keyring-related function based on
@ -323,9 +347,7 @@ func (m *Internal) executeKeyringOp(
func (m *Internal) executeKeyringOpMgr( func (m *Internal) executeKeyringOpMgr(
mgr *serf.KeyManager, mgr *serf.KeyManager,
args *structs.KeyringRequest, args *structs.KeyringRequest,
reply *structs.KeyringResponses, ) (*serf.KeyResponse, error) {
wan bool,
segment string) {
var serfResp *serf.KeyResponse var serfResp *serf.KeyResponse
var err error var err error
@ -341,20 +363,7 @@ func (m *Internal) executeKeyringOpMgr(
serfResp, err = mgr.RemoveKeyWithOptions(args.Key, opts) serfResp, err = mgr.RemoveKeyWithOptions(args.Key, opts)
} }
errStr := "" return serfResp, err
if err != nil {
errStr = err.Error()
}
reply.Responses = append(reply.Responses, &structs.KeyringResponse{
WAN: wan,
Datacenter: m.srv.config.Datacenter,
Segment: segment,
Messages: serfResp.Messages,
Keys: serfResp.Keys,
NumNodes: serfResp.NumNodes,
Error: errStr,
})
} }
// aclAccessorID is used to convert an ACLToken's secretID to its accessorID for non- // aclAccessorID is used to convert an ACLToken's secretID to its accessorID for non-

View File

@ -635,22 +635,17 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{
return nil return nil
} }
// globalRPC is used to forward an RPC request to one server in each datacenter. // keyringRPCs is used to forward an RPC request to a server in each dc. This
// This will only error for RPC-related errors. Otherwise, application-level // will only error for RPC-related errors. Otherwise, application-level errors
// errors can be sent in the response objects. // can be sent in the response objects.
func (s *Server) globalRPC(method string, args interface{}, func (s *Server) keyringRPCs(method string, args interface{}, dcs []string) (*structs.KeyringResponses, error) {
reply structs.CompoundResponse) error {
// Make a new request into each datacenter errorCh := make(chan error, len(dcs))
dcs := s.router.GetDatacenters() respCh := make(chan *structs.KeyringResponses, len(dcs))
replies, total := 0, len(dcs)
errorCh := make(chan error, total)
respCh := make(chan interface{}, total)
for _, dc := range dcs { for _, dc := range dcs {
go func(dc string) { go func(dc string) {
rr := reply.New() rr := &structs.KeyringResponses{}
if err := s.forwardDC(method, dc, args, &rr); err != nil { if err := s.forwardDC(method, dc, args, &rr); err != nil {
errorCh <- err errorCh <- err
return return
@ -659,16 +654,16 @@ func (s *Server) globalRPC(method string, args interface{},
}(dc) }(dc)
} }
for replies < total { responses := &structs.KeyringResponses{}
for i := 0; i < len(dcs); i++ {
select { select {
case err := <-errorCh: case err := <-errorCh:
return err return nil, err
case rr := <-respCh: case rr := <-respCh:
reply.Add(rr) responses.Add(rr)
replies++
} }
} }
return nil return responses, nil
} }
type raftEncoder func(structs.MessageType, interface{}) ([]byte, error) type raftEncoder func(structs.MessageType, interface{}) ([]byte, error)

View File

@ -1286,7 +1286,7 @@ func (r *fakeGlobalResp) New() interface{} {
return struct{}{} return struct{}{}
} }
func TestServer_globalRPCErrors(t *testing.T) { func TestServer_keyringRPCs(t *testing.T) {
t.Parallel() t.Parallel()
dir1, s1 := testServerDC(t, "dc1") dir1, s1 := testServerDC(t, "dc1")
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
@ -1298,7 +1298,7 @@ func TestServer_globalRPCErrors(t *testing.T) {
}) })
// Check that an error from a remote DC is returned // Check that an error from a remote DC is returned
err := s1.globalRPC("Bad.Method", nil, &fakeGlobalResp{}) _, err := s1.keyringRPCs("Bad.Method", nil, []string{s1.config.Datacenter})
if err == nil { if err == nil {
t.Fatalf("should have errored") t.Fatalf("should have errored")
} }

View File

@ -406,6 +406,24 @@ func (r *Router) GetDatacenters() []string {
return dcs return dcs
} }
// GetRemoteDatacenters returns a list of remote datacenters known to the router, sorted by
// name.
func (r *Router) GetRemoteDatacenters(local string) []string {
r.RLock()
defer r.RUnlock()
dcs := make([]string, 0, len(r.managers))
for dc := range r.managers {
if dc == local {
continue
}
dcs = append(dcs, dc)
}
sort.Strings(dcs)
return dcs
}
// HasDatacenter checks whether dc is defined in WAN // HasDatacenter checks whether dc is defined in WAN
func (r *Router) HasDatacenter(dc string) bool { func (r *Router) HasDatacenter(dc string) bool {
r.RLock() r.RLock()