open-nomad/nomad/client_agent_endpoint.go
2023-04-10 15:36:59 +00:00

523 lines
13 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package nomad
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/hashicorp/go-hclog"
sframer "github.com/hashicorp/nomad/client/lib/streamframer"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/command/agent/host"
"github.com/hashicorp/nomad/command/agent/monitor"
"github.com/hashicorp/nomad/command/agent/pprof"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/go-msgpack/codec"
)
type Agent struct {
srv *Server
}
func NewAgentEndpoint(srv *Server) *Agent {
return &Agent{srv: srv}
}
func (a *Agent) register() {
a.srv.streamingRpcs.Register("Agent.Monitor", a.monitor)
}
func (a *Agent) Profile(args *structs.AgentPprofRequest, reply *structs.AgentPprofResponse) error {
authErr := a.srv.Authenticate(nil, args)
a.srv.MeasureRPCRate("agent", structs.RateMetricRead, args)
if authErr != nil {
return structs.ErrPermissionDenied
}
// Check ACL for agent write
aclObj, err := a.srv.ResolveACL(args)
if err != nil {
return err
} else if aclObj != nil && !aclObj.AllowAgentWrite() {
return structs.ErrPermissionDenied
}
// Forward to different region if necessary
// this would typically be done in a.srv.forward() but since
// we are targeting a specific server, not just the leader
// we must manually handle region forwarding here.
region := args.RequestRegion()
if region == "" {
return fmt.Errorf("missing target RPC")
}
if region != a.srv.config.Region {
// Mark that we are forwarding
args.SetForwarded()
return a.srv.forwardRegion(region, "Agent.Profile", args, reply)
}
// Targeting a node, forward request to node
if args.NodeID != "" {
return a.forwardProfileClient(args, reply)
}
// Handle serverID not equal to ours
if args.ServerID != "" {
serverToFwd, err := a.forwardFor(args.ServerID, region)
if err != nil {
return err
}
if serverToFwd != nil {
return a.srv.forwardServer(serverToFwd, "Agent.Profile", args, reply)
}
}
// If ACLs are disabled, EnableDebug must be enabled
if aclObj == nil && !a.srv.config.EnableDebug {
return structs.ErrPermissionDenied
}
// Process the request on this server
var resp []byte
var headers map[string]string
// Determine which profile to run and generate profile.
// Blocks for args.Seconds
// Our RPC endpoints currently don't support context
// or request cancellation so using server shutdownCtx as a
// best effort.
switch args.ReqType {
case pprof.CPUReq:
resp, headers, err = pprof.CPUProfile(a.srv.shutdownCtx, args.Seconds)
case pprof.CmdReq:
resp, headers, err = pprof.Cmdline()
case pprof.LookupReq:
resp, headers, err = pprof.Profile(args.Profile, args.Debug, args.GC)
case pprof.TraceReq:
resp, headers, err = pprof.Trace(a.srv.shutdownCtx, args.Seconds)
default:
err = structs.NewErrRPCCoded(404, "Unknown profile request type")
}
if err != nil {
if pprof.IsErrProfileNotFound(err) {
return structs.NewErrRPCCoded(404, err.Error())
}
return structs.NewErrRPCCoded(500, err.Error())
}
// Copy profile response to reply
reply.Payload = resp
reply.HTTPHeaders = headers
reply.AgentID = a.srv.serf.LocalMember().Name
return nil
}
func (a *Agent) monitor(conn io.ReadWriteCloser) {
defer conn.Close()
// Decode args
var args cstructs.MonitorRequest
decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
if err := decoder.Decode(&args); err != nil {
handleStreamResultError(err, pointer.Of(int64(500)), encoder)
return
}
authErr := a.srv.Authenticate(nil, &args)
a.srv.MeasureRPCRate("agent", structs.RateMetricRead, &args)
if authErr != nil {
handleStreamResultError(structs.ErrPermissionDenied, nil, encoder)
return
}
// Check agent read permissions
if aclObj, err := a.srv.ResolveACL(&args); err != nil {
handleStreamResultError(err, nil, encoder)
return
} else if aclObj != nil && !aclObj.AllowAgentRead() {
handleStreamResultError(structs.ErrPermissionDenied, pointer.Of(int64(403)), encoder)
return
}
logLevel := log.LevelFromString(args.LogLevel)
if args.LogLevel == "" {
logLevel = log.LevelFromString("INFO")
}
if logLevel == log.NoLevel {
handleStreamResultError(errors.New("Unknown log level"), pointer.Of(int64(400)), encoder)
return
}
// Targeting a node, forward request to node
if args.NodeID != "" {
a.forwardMonitorClient(conn, args, encoder, decoder)
// forwarded request has ended, return
return
}
region := args.RequestRegion()
if region == "" {
handleStreamResultError(fmt.Errorf("missing target RPC"), pointer.Of(int64(400)), encoder)
return
}
if region != a.srv.config.Region {
// Mark that we are forwarding
args.SetForwarded()
}
// Try to forward request to remote region/server
if args.ServerID != "" {
serverToFwd, err := a.forwardFor(args.ServerID, region)
if err != nil {
handleStreamResultError(err, pointer.Of(int64(400)), encoder)
return
}
if serverToFwd != nil {
a.forwardMonitorServer(conn, serverToFwd, args, encoder, decoder)
return
}
}
// NodeID was empty, ServerID was equal to this server, monitor this server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
monitor := monitor.New(512, a.srv.logger, &log.LoggerOptions{
Level: logLevel,
JSONFormat: args.LogJSON,
})
frames := make(chan *sframer.StreamFrame, 32)
errCh := make(chan error)
var buf bytes.Buffer
frameCodec := codec.NewEncoder(&buf, structs.JsonHandle)
framer := sframer.NewStreamFramer(frames, 1*time.Second, 200*time.Millisecond, 1024)
framer.Run()
defer framer.Destroy()
// goroutine to detect remote side closing
go func() {
if _, err := conn.Read(nil); err != nil {
// One end of the pipe explicitly closed, exit
cancel()
return
}
<-ctx.Done()
}()
logCh := monitor.Start()
defer monitor.Stop()
initialOffset := int64(0)
// receive logs and build frames
go func() {
defer framer.Destroy()
LOOP:
for {
select {
case log := <-logCh:
if err := framer.Send("", "log", log, initialOffset); err != nil {
select {
case errCh <- err:
case <-ctx.Done():
}
break LOOP
}
case <-ctx.Done():
break LOOP
}
}
}()
var streamErr error
OUTER:
for {
select {
case frame, ok := <-frames:
if !ok {
// frame may have been closed when an error
// occurred. Check once more for an error.
select {
case streamErr = <-errCh:
// There was a pending error!
default:
// No error, continue on
}
break OUTER
}
var resp cstructs.StreamErrWrapper
if args.PlainText {
resp.Payload = frame.Data
} else {
if err := frameCodec.Encode(frame); err != nil {
streamErr = err
break OUTER
}
resp.Payload = buf.Bytes()
buf.Reset()
}
if err := encoder.Encode(resp); err != nil {
streamErr = err
break OUTER
}
encoder.Reset(conn)
case <-ctx.Done():
break OUTER
}
}
if streamErr != nil {
handleStreamResultError(streamErr, pointer.Of(int64(500)), encoder)
return
}
}
// forwardFor returns a serverParts for a request to be forwarded to.
// A response of nil, nil indicates that the current server is equal to the
// serverID and region so the request should not be forwarded.
func (a *Agent) forwardFor(serverID, region string) (*serverParts, error) {
var target *serverParts
var err error
if serverID == "leader" {
isLeader, remoteLeader := a.srv.getLeader()
if !isLeader && remoteLeader != nil {
target = remoteLeader
} else if !isLeader && remoteLeader == nil {
return nil, structs.ErrNoLeader
} else if isLeader {
// This server is current leader do not forward
return nil, nil
}
} else {
target, err = a.srv.getServer(region, serverID)
if err != nil {
return nil, err
}
}
// Unable to find a server
if target == nil {
return nil, fmt.Errorf("unknown nomad server %s", serverID)
}
// ServerID is this current server,
// No need to forward request
if target.Name == a.srv.LocalMember().Name {
return nil, nil
}
return target, nil
}
func (a *Agent) forwardMonitorClient(conn io.ReadWriteCloser, args cstructs.MonitorRequest, encoder *codec.Encoder, decoder *codec.Decoder) {
// Get the Connection to the client either by fowarding to another server
// or creating direct stream
state, srv, err := a.findClientConn(args.NodeID)
if err != nil {
handleStreamResultError(err, pointer.Of(int64(500)), encoder)
return
}
var clientConn net.Conn
if state == nil {
conn, err := a.srv.streamingRpc(srv, "Agent.Monitor")
if err != nil {
handleStreamResultError(err, nil, encoder)
return
}
clientConn = conn
} else {
stream, err := NodeStreamingRpc(state.Session, "Agent.Monitor")
if err != nil {
handleStreamResultError(err, nil, encoder)
return
}
clientConn = stream
}
defer clientConn.Close()
// Send the Request
outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
handleStreamResultError(err, nil, encoder)
return
}
structs.Bridge(conn, clientConn)
}
func (a *Agent) forwardMonitorServer(conn io.ReadWriteCloser, server *serverParts, args cstructs.MonitorRequest, encoder *codec.Encoder, decoder *codec.Decoder) {
// empty ServerID to prevent forwarding loop
args.ServerID = ""
serverConn, err := a.srv.streamingRpc(server, "Agent.Monitor")
if err != nil {
handleStreamResultError(err, pointer.Of(int64(500)), encoder)
return
}
defer serverConn.Close()
// Send the Request
outEncoder := codec.NewEncoder(serverConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
handleStreamResultError(err, pointer.Of(int64(500)), encoder)
return
}
structs.Bridge(conn, serverConn)
}
func (a *Agent) forwardProfileClient(args *structs.AgentPprofRequest, reply *structs.AgentPprofResponse) error {
state, srv, err := a.findClientConn(args.NodeID)
if err != nil {
return err
}
if srv != nil {
return a.srv.forwardServer(srv, "Agent.Profile", args, reply)
}
// NodeRpc
rpcErr := NodeRpc(state.Session, "Agent.Profile", args, reply)
if rpcErr != nil {
return rpcErr
}
return nil
}
// Host returns data about the agent's host system for the `debug` command.
func (a *Agent) Host(args *structs.HostDataRequest, reply *structs.HostDataResponse) error {
authErr := a.srv.Authenticate(nil, args)
a.srv.MeasureRPCRate("agent", structs.RateMetricRead, args)
if authErr != nil {
return structs.ErrPermissionDenied
}
aclObj, err := a.srv.ResolveACL(args)
if err != nil {
return err
}
if (aclObj != nil && !aclObj.AllowAgentRead()) ||
(aclObj == nil && !a.srv.config.EnableDebug) {
return structs.ErrPermissionDenied
}
// Forward to different region if necessary
// this would typically be done in a.srv.forward() but since
// we are targeting a specific server, not just the leader
// we must manually handle region forwarding here.
region := args.RequestRegion()
if region == "" {
return fmt.Errorf("missing target RPC")
}
if region != a.srv.config.Region {
// Mark that we are forwarding
args.SetForwarded()
return a.srv.forwardRegion(region, "Agent.Host", args, reply)
}
// Targeting a client node, forward request to node
if args.NodeID != "" {
client, srv, err := a.findClientConn(args.NodeID)
if err != nil {
return err
}
if srv != nil {
return a.srv.forwardServer(srv, "Agent.Host", args, reply)
}
return NodeRpc(client.Session, "Agent.Host", args, reply)
}
// Handle serverID not equal to ours
if args.ServerID != "" {
srv, err := a.forwardFor(args.ServerID, region)
if err != nil {
return err
}
if srv != nil {
return a.srv.forwardServer(srv, "Agent.Host", args, reply)
}
}
data, err := host.MakeHostData()
if err != nil {
return err
}
reply.AgentID = a.srv.serf.LocalMember().Name
reply.HostData = data
return nil
}
// findClientConn is a helper that returns a connection to the client node or, if the client
// is connected to a different server, a serverParts describing the server to which the
// client bound RPC should be forwarded.
func (a *Agent) findClientConn(nodeID string) (*nodeConnState, *serverParts, error) {
snap, err := a.srv.State().Snapshot()
if err != nil {
return nil, nil, structs.NewErrRPCCoded(500, err.Error())
}
node, err := snap.NodeByID(nil, nodeID)
if err != nil {
return nil, nil, structs.NewErrRPCCoded(500, err.Error())
}
if node == nil {
err := fmt.Errorf("Unknown node %q", nodeID)
return nil, nil, structs.NewErrRPCCoded(404, err.Error())
}
if err := nodeSupportsRpc(node); err != nil {
return nil, nil, structs.NewErrRPCCoded(400, err.Error())
}
// Get the Connection to the client either by fowarding to another server
// or creating direct stream
state, ok := a.srv.getNodeConn(nodeID)
if ok {
return state, nil, nil
}
// Determine the server that has a connection to the node
srv, err := a.srv.serverWithNodeConn(nodeID, a.srv.Region())
if err != nil {
code := 500
if structs.IsErrNoNodeConn(err) {
code = 404
}
return nil, nil, structs.NewErrRPCCoded(code, err.Error())
}
return nil, srv, nil
}