diff --git a/.changelog/15864.txt b/.changelog/15864.txt new file mode 100644 index 000000000..b91ffba97 --- /dev/null +++ b/.changelog/15864.txt @@ -0,0 +1,3 @@ +```release-note:improvement +client: added http api access for tasks via unix socket +``` diff --git a/client/allocrunner/interfaces/task_lifecycle.go b/client/allocrunner/interfaces/task_lifecycle.go index 1bf61bd5a..3ea51c4a9 100644 --- a/client/allocrunner/interfaces/task_lifecycle.go +++ b/client/allocrunner/interfaces/task_lifecycle.go @@ -33,8 +33,6 @@ import ( +-----------+ *Kill (forces terminal) - -Link: http://stable.ascii-flow.appspot.com/#Draw4489375405966393064/1824429135 */ // TaskHook is a lifecycle hook into the life cycle of a task runner. @@ -186,6 +184,9 @@ type TaskStopRequest struct { // ExistingState is previously set hook data and should only be // read. Stop hooks cannot alter state. ExistingState map[string]string + + // TaskDir contains the task's directory tree on the host + TaskDir *allocdir.TaskDir } type TaskStopResponse struct{} diff --git a/client/allocrunner/taskrunner/api_hook.go b/client/allocrunner/taskrunner/api_hook.go new file mode 100644 index 000000000..003d5fecd --- /dev/null +++ b/client/allocrunner/taskrunner/api_hook.go @@ -0,0 +1,119 @@ +package taskrunner + +import ( + "context" + "errors" + "net" + "net/http" + "os" + "path/filepath" + "sync" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/helper/users" +) + +// apiHook exposes the Task API. The Task API allows task's to access the Nomad +// HTTP API without having to discover and connect to an agent's address. +// Instead a unix socket is provided in a standard location. To prevent access +// by untrusted workloads the Task API always requires authentication even when +// ACLs are disabled. +// +// The Task API hook largely soft-fails as there are a number of ways creating +// the unix socket could fail (the most common one being path length +// restrictions), and it is assumed most tasks won't require access to the Task +// API anyway. Tasks that do require access are expected to crash and get +// rescheduled should they land on a client who Task API hook soft-fails. +type apiHook struct { + shutdownCtx context.Context + srv config.APIListenerRegistrar + logger hclog.Logger + + // Lock listener as it is updated from multiple hooks. + lock sync.Mutex + + // Listener is the unix domain socket of the task api for this taks. + ln net.Listener +} + +func newAPIHook(shutdownCtx context.Context, srv config.APIListenerRegistrar, logger hclog.Logger) *apiHook { + h := &apiHook{ + shutdownCtx: shutdownCtx, + srv: srv, + } + h.logger = logger.Named(h.Name()) + return h +} + +func (*apiHook) Name() string { + return "api" +} + +func (h *apiHook) Prestart(_ context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { + h.lock.Lock() + defer h.lock.Unlock() + + if h.ln != nil { + // Listener already set. Task is probably restarting. + return nil + } + + udsPath := apiSocketPath(req.TaskDir) + udsln, err := users.SocketFileFor(h.logger, udsPath, req.Task.User) + if err != nil { + // Soft-fail and let the task fail if it requires the task api. + h.logger.Warn("error creating task api socket", "path", udsPath, "error", err) + return nil + } + + go func() { + // Cannot use Prestart's context as it is closed after all prestart hooks + // have been closed, but we do want to try to cleanup on shutdown. + if err := h.srv.Serve(h.shutdownCtx, udsln); err != nil { + if errors.Is(err, http.ErrServerClosed) { + return + } + if errors.Is(err, net.ErrClosed) { + return + } + h.logger.Error("error serving task api", "error", err) + } + }() + + h.ln = udsln + return nil +} + +func (h *apiHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error { + h.lock.Lock() + defer h.lock.Unlock() + + if h.ln != nil { + if err := h.ln.Close(); err != nil { + if !errors.Is(err, net.ErrClosed) { + h.logger.Debug("error closing task listener: %v", err) + } + } + h.ln = nil + } + + // Best-effort at cleaining things up. Alloc dir cleanup will remove it if + // this fails for any reason. + _ = os.RemoveAll(apiSocketPath(req.TaskDir)) + + return nil +} + +// apiSocketPath returns the path to the Task API socket. +// +// The path needs to be as short as possible because of the low limits on the +// sun_path char array imposed by the syscall used to create unix sockets. +// +// See https://github.com/hashicorp/nomad/pull/13971 for an example of the +// sadness this causes. +func apiSocketPath(taskDir *allocdir.TaskDir) string { + return filepath.Join(taskDir.SecretsDir, "api.sock") +} diff --git a/client/allocrunner/taskrunner/api_hook_test.go b/client/allocrunner/taskrunner/api_hook_test.go new file mode 100644 index 000000000..164d4433d --- /dev/null +++ b/client/allocrunner/taskrunner/api_hook_test.go @@ -0,0 +1,169 @@ +package taskrunner + +import ( + "context" + "io/fs" + "net" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "syscall" + "testing" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/helper/users" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" +) + +type testAPIListenerRegistrar struct { + cb func(net.Listener) error +} + +func (n testAPIListenerRegistrar) Serve(_ context.Context, ln net.Listener) error { + if n.cb != nil { + return n.cb(ln) + } + return nil +} + +// TestAPIHook_SoftFail asserts that the Task API Hook soft fails and does not +// return errors. +func TestAPIHook_SoftFail(t *testing.T) { + ci.Parallel(t) + + // Use a SecretsDir that will always exceed Unix socket path length + // limits (sun_path) + dst := filepath.Join(t.TempDir(), strings.Repeat("_NOMAD_TEST_", 100)) + + ctx := context.Background() + srv := testAPIListenerRegistrar{} + logger := testlog.HCLogger(t) + h := newAPIHook(ctx, srv, logger) + + req := &interfaces.TaskPrestartRequest{ + Task: &structs.Task{}, // needs to be non-nil for Task.User lookup + TaskDir: &allocdir.TaskDir{ + SecretsDir: dst, + }, + } + resp := &interfaces.TaskPrestartResponse{} + + err := h.Prestart(ctx, req, resp) + must.NoError(t, err) + + // listener should not have been set + must.Nil(t, h.ln) + + // File should not have been created + _, err = os.Stat(dst) + must.Error(t, err) + + // Assert stop also soft-fails + stopReq := &interfaces.TaskStopRequest{ + TaskDir: req.TaskDir, + } + stopResp := &interfaces.TaskStopResponse{} + err = h.Stop(ctx, stopReq, stopResp) + must.NoError(t, err) + + // File should not have been created + _, err = os.Stat(dst) + must.Error(t, err) +} + +// TestAPIHook_Ok asserts that the Task API Hook creates and cleans up a +// socket. +func TestAPIHook_Ok(t *testing.T) { + ci.Parallel(t) + + // If this test fails it may be because TempDir() + /api.sock is longer than + // the unix socket path length limit (sun_path) in which case the test should + // use a different temporary directory on that platform. + dst := t.TempDir() + + // Write "ok" and close the connection and listener + srv := testAPIListenerRegistrar{ + cb: func(ln net.Listener) error { + conn, err := ln.Accept() + if err != nil { + return err + } + if _, err = conn.Write([]byte("ok")); err != nil { + return err + } + conn.Close() + return nil + }, + } + + ctx := context.Background() + logger := testlog.HCLogger(t) + h := newAPIHook(ctx, srv, logger) + + req := &interfaces.TaskPrestartRequest{ + Task: &structs.Task{ + User: "nobody", + }, + TaskDir: &allocdir.TaskDir{ + SecretsDir: dst, + }, + } + resp := &interfaces.TaskPrestartResponse{} + + err := h.Prestart(ctx, req, resp) + must.NoError(t, err) + + // File should have been created + sockDst := apiSocketPath(req.TaskDir) + + // Stat and chown fail on Windows, so skip these checks + if runtime.GOOS != "windows" { + stat, err := os.Stat(sockDst) + must.NoError(t, err) + must.True(t, stat.Mode()&fs.ModeSocket != 0, + must.Sprintf("expected %q to be a unix socket but got %s", sockDst, stat.Mode())) + + nobody, _ := users.Lookup("nobody") + if syscall.Getuid() == 0 && nobody != nil { + t.Logf("root and nobody exists: testing file perms") + + // We're root and nobody exists! Check perms + must.Eq(t, fs.FileMode(0o600), stat.Mode().Perm()) + + sysStat, ok := stat.Sys().(*syscall.Stat_t) + must.True(t, ok, must.Sprintf("expected stat.Sys() to be a *syscall.Stat_t on %s but found %T", + runtime.GOOS, stat.Sys())) + + nobodyUID, err := strconv.Atoi(nobody.Uid) + must.NoError(t, err) + must.Eq(t, nobodyUID, int(sysStat.Uid)) + } + } + + // Assert the listener is working + conn, err := net.Dial("unix", sockDst) + must.NoError(t, err) + buf := make([]byte, 2) + _, err = conn.Read(buf) + must.NoError(t, err) + must.Eq(t, []byte("ok"), buf) + conn.Close() + + // Assert stop cleans up + stopReq := &interfaces.TaskStopRequest{ + TaskDir: req.TaskDir, + } + stopResp := &interfaces.TaskStopResponse{} + err = h.Stop(ctx, stopReq, stopResp) + must.NoError(t, err) + + // File should be gone + _, err = net.Dial("unix", sockDst) + must.Error(t, err) +} diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 7c1b73dbd..7ea501982 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -68,6 +68,7 @@ func (tr *TaskRunner) initHooks() { newArtifactHook(tr, tr.getter, hookLogger), newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger), newDeviceHook(tr.devicemanager, hookLogger), + newAPIHook(tr.shutdownCtx, tr.clientConfig.APIListenerRegistrar, hookLogger), } // If the task has a CSI block, add the hook. @@ -431,7 +432,9 @@ func (tr *TaskRunner) stop() error { tr.logger.Trace("running stop hook", "name", name, "start", start) } - req := interfaces.TaskStopRequest{} + req := interfaces.TaskStopRequest{ + TaskDir: tr.taskDir, + } origHookState := tr.hookState(name) if origHookState != nil { diff --git a/client/config/config.go b/client/config/config.go index 2945f6daa..466230ddc 100644 --- a/client/config/config.go +++ b/client/config/config.go @@ -1,8 +1,10 @@ package config import ( + "context" "errors" "fmt" + "net" "reflect" "strconv" "strings" @@ -301,10 +303,27 @@ type Config struct { // used for template functions which require access to the Nomad API. TemplateDialer *bufconndialer.BufConnWrapper + // APIListenerRegistrar allows the client to register listeners created at + // runtime (eg the Task API) with the agent's HTTP server. Since the agent + // creates the HTTP *after* the client starts, we have to use this shim to + // pass listeners back to the agent. + // This is the same design as the bufconndialer but for the + // http.Serve(listener) API instead of the net.Dial API. + APIListenerRegistrar APIListenerRegistrar + // Artifact configuration from the agent's config file. Artifact *ArtifactConfig } +type APIListenerRegistrar interface { + // Serve the HTTP API on the provided listener. + // + // The context is because Serve may be called before the HTTP server has been + // initialized. If the context is canceled before the HTTP server is + // initialized, the context's error will be returned. + Serve(context.Context, net.Listener) error +} + // ClientTemplateConfig is configuration on the client specific to template // rendering type ClientTemplateConfig struct { diff --git a/client/config/testing.go b/client/config/testing.go index 02f87984f..adb703de2 100644 --- a/client/config/testing.go +++ b/client/config/testing.go @@ -1,7 +1,9 @@ package config import ( + "context" "io/ioutil" + "net" "os" "path/filepath" "time" @@ -74,5 +76,14 @@ func TestClientConfig(t testing.T) (*Config, func()) { // Same as default; necessary for task Event messages conf.MaxKillTimeout = 30 * time.Second + // Provide a stub APIListenerRegistrar implementation + conf.APIListenerRegistrar = NoopAPIListenerRegistrar{} + return conf, cleanup } + +type NoopAPIListenerRegistrar struct{} + +func (NoopAPIListenerRegistrar) Serve(_ context.Context, _ net.Listener) error { + return nil +} diff --git a/command/agent/agent.go b/command/agent/agent.go index ae3d3e661..d5703f721 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -115,7 +115,11 @@ type Agent struct { builtinListener net.Listener builtinDialer *bufconndialer.BufConnWrapper - InmemSink *metrics.InmemSink + // builtinServer is an HTTP server for attaching per-task listeners. Always + // requires auth. + builtinServer *builtinAPI + + inmemSink *metrics.InmemSink } // NewAgent is used to create a new agent with the given configuration @@ -124,7 +128,7 @@ func NewAgent(config *Config, logger log.InterceptLogger, logOutput io.Writer, i config: config, logOutput: logOutput, shutdownCh: make(chan struct{}), - InmemSink: inmem, + inmemSink: inmem, } // Create the loggers @@ -1020,6 +1024,11 @@ func (a *Agent) setupClient() error { a.builtinListener, a.builtinDialer = bufconndialer.New() conf.TemplateDialer = a.builtinDialer + // Initialize builtin API server here for use in the client, but it won't + // accept connections until the HTTP servers are created. + a.builtinServer = newBuiltinAPI() + conf.APIListenerRegistrar = a.builtinServer + nomadClient, err := client.NewClient( conf, a.consulCatalog, a.consulProxies, a.consulService, nil) if err != nil { @@ -1300,6 +1309,11 @@ func (a *Agent) GetConfig() *Config { return a.config } +// GetMetricsSink returns the metrics sink. +func (a *Agent) GetMetricsSink() *metrics.InmemSink { + return a.inmemSink +} + // setupConsul creates the Consul client and starts its main Run loop. func (a *Agent) setupConsul(consulConfig *config.ConsulConfig) error { apiConf, err := consulConfig.ApiConfig() diff --git a/command/agent/agent_endpoint.go b/command/agent/agent_endpoint.go index 898001ca8..21df7c3b3 100644 --- a/command/agent/agent_endpoint.go +++ b/command/agent/agent_endpoint.go @@ -440,7 +440,7 @@ func (s *HTTPServer) listServers(resp http.ResponseWriter, req *http.Request) (i return nil, structs.ErrPermissionDenied } - peers := s.agent.client.GetServers() + peers := client.GetServers() sort.Strings(peers) return peers, nil } @@ -468,9 +468,9 @@ func (s *HTTPServer) updateServers(resp http.ResponseWriter, req *http.Request) } // Set the servers list into the client - s.agent.logger.Trace("adding servers to the client's primary server list", "servers", servers, "path", "/v1/agent/servers", "method", "PUT") + s.logger.Trace("adding servers to the client's primary server list", "servers", servers, "path", "/v1/agent/servers", "method", "PUT") if _, err := client.SetServers(servers); err != nil { - s.agent.logger.Error("failed adding servers to client's server list", "servers", servers, "error", err, "path", "/v1/agent/servers", "method", "PUT") + s.logger.Error("failed adding servers to client's server list", "servers", servers, "error", err, "path", "/v1/agent/servers", "method", "PUT") //TODO is this the right error to return? return nil, CodedError(400, err.Error()) } @@ -708,7 +708,7 @@ func (s *HTTPServer) AgentHostRequest(resp http.ResponseWriter, req *http.Reques // The RPC endpoint actually forwards the request to the correct // agent, but we need to use the correct RPC interface. localClient, remoteClient, localServer := s.rpcHandlerForNode(lookupNodeID) - s.agent.logger.Debug("s.rpcHandlerForNode()", "lookupNodeID", lookupNodeID, "serverID", serverID, "nodeID", nodeID, "localClient", localClient, "remoteClient", remoteClient, "localServer", localServer) + s.logger.Debug("s.rpcHandlerForNode()", "lookupNodeID", lookupNodeID, "serverID", serverID, "nodeID", nodeID, "localClient", localClient, "remoteClient", remoteClient, "localServer", localServer) // Make the RPC call if localClient { diff --git a/command/agent/alloc_endpoint.go b/command/agent/alloc_endpoint.go index 18b4dfdc3..ee071e7db 100644 --- a/command/agent/alloc_endpoint.go +++ b/command/agent/alloc_endpoint.go @@ -222,7 +222,7 @@ func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Requ case "exec": return s.allocExec(allocID, resp, req) case "snapshot": - if s.agent.client == nil { + if s.agent.Client() == nil { return nil, clientNotRunning } return s.allocSnapshot(allocID, resp, req) diff --git a/command/agent/consul/int_test.go b/command/agent/consul/int_test.go index 1fd6a6fcd..02124f108 100644 --- a/command/agent/consul/int_test.go +++ b/command/agent/consul/int_test.go @@ -46,7 +46,7 @@ func TestConsul_Integration(t *testing.T) { // Create an embedded Consul server testconsul, err := testutil.NewTestServerConfigT(t, func(c *testutil.TestServerConfig) { - c.Peering = nil // fix for older versions of Consul (<1.13.0) that don't support peering + c.Peering = nil // fix for older versions of Consul (<1.13.0) that don't support peering // If -v wasn't specified squelch consul logging if !testing.Verbose() { c.Stdout = ioutil.Discard @@ -61,6 +61,7 @@ func TestConsul_Integration(t *testing.T) { conf := config.DefaultConfig() conf.Node = mock.Node() conf.ConsulConfig.Addr = testconsul.HTTPAddr + conf.APIListenerRegistrar = config.NoopAPIListenerRegistrar{} consulConfig, err := conf.ConsulConfig.ApiConfig() if err != nil { t.Fatalf("error generating consul config: %v", err) diff --git a/command/agent/http.go b/command/agent/http.go index da3580e77..b51d6696a 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "crypto/tls" "encoding/json" "errors" @@ -26,8 +27,10 @@ import ( "golang.org/x/time/rate" "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/client" "github.com/hashicorp/nomad/helper/noxssrw" "github.com/hashicorp/nomad/helper/tlsutil" + "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" ) @@ -74,9 +77,18 @@ var ( type handlerFn func(resp http.ResponseWriter, req *http.Request) (interface{}, error) type handlerByteFn func(resp http.ResponseWriter, req *http.Request) ([]byte, error) +type RPCer interface { + RPC(string, any, any) error + Server() *nomad.Server + Client() *client.Client + Stats() map[string]map[string]string + GetConfig() *Config + GetMetricsSink() *metrics.InmemSink +} + // HTTPServer is used to wrap an Agent and expose it over an HTTP interface type HTTPServer struct { - agent *Agent + agent RPCer mux *http.ServeMux listener net.Listener listenerCh chan struct{} @@ -170,7 +182,7 @@ func NewHTTPServers(agent *Agent, config *Config) ([]*HTTPServer, error) { srvs = append(srvs, srv) } - // This HTTP server is only create when running in client mode, otherwise + // This HTTP server is only created when running in client mode, otherwise // the builtinDialer and builtinListener will be nil. if agent.builtinDialer != nil && agent.builtinListener != nil { srv := &HTTPServer{ @@ -185,12 +197,15 @@ func NewHTTPServers(agent *Agent, config *Config) ([]*HTTPServer, error) { srv.registerHandlers(config.EnableDebug) + // builtinServer adds a wrapper to always authenticate requests httpServer := http.Server{ Addr: srv.Addr, - Handler: srv.mux, + Handler: newAuthMiddleware(srv, srv.mux), ErrorLog: newHTTPServerLogger(srv.logger), } + agent.builtinServer.SetServer(&httpServer) + go func() { defer close(srv.listenerCh) httpServer.Serve(agent.builtinListener) @@ -465,7 +480,8 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.mux.Handle("/v1/vars", wrapCORS(s.wrap(s.VariablesListRequest))) s.mux.Handle("/v1/var/", wrapCORSWithAllowedMethods(s.wrap(s.VariableSpecificRequest), "HEAD", "GET", "PUT", "DELETE")) - uiConfigEnabled := s.agent.config.UI != nil && s.agent.config.UI.Enabled + agentConfig := s.agent.GetConfig() + uiConfigEnabled := agentConfig.UI != nil && agentConfig.UI.Enabled if uiEnabled && uiConfigEnabled { s.mux.Handle("/ui/", http.StripPrefix("/ui/", s.handleUI(http.FileServer(&UIAssetWrapper{FileSystem: assetFS()})))) @@ -484,7 +500,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.mux.Handle("/", s.handleRootFallthrough()) if enableDebug { - if !s.agent.config.DevMode { + if !agentConfig.DevMode { s.logger.Warn("enable_debug is set to true. This is insecure and should not be enabled in production") } s.mux.HandleFunc("/debug/pprof/", pprof.Index) @@ -498,6 +514,54 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.registerEnterpriseHandlers() } +// builtinAPI is a wrapper around serving the HTTP API to arbitrary listeners +// such as the Task API. It is necessary because the HTTP servers are created +// *after* the client has been initialized, so this wrapper blocks Serve +// requests from task api hooks until the HTTP server is setup and ready to +// accept from new listeners. +// +// bufconndialer provides similar functionality to consul-template except it +// satisfies the Dialer API as opposed to the Serve(Listener) API. +type builtinAPI struct { + srv *http.Server + srvReadyCh chan struct{} +} + +func newBuiltinAPI() *builtinAPI { + return &builtinAPI{ + srvReadyCh: make(chan struct{}), + } +} + +// SetServer sets the API HTTP server for Serve to add listeners to. +// +// It must be called exactly once and will panic if called more than once. +func (b *builtinAPI) SetServer(srv *http.Server) { + select { + case <-b.srvReadyCh: + panic(fmt.Sprintf("SetServer called twice. first=%p second=%p", b.srv, srv)) + default: + } + b.srv = srv + close(b.srvReadyCh) +} + +// Serve the HTTP API on the listener unless the context is canceled before the +// HTTP API is ready to serve listeners. A non-nil error will always be +// returned, but http.ErrServerClosed and net.ErrClosed can likely be ignored +// as they indicate the server or listener is being shutdown. +func (b *builtinAPI) Serve(ctx context.Context, l net.Listener) error { + select { + case <-ctx.Done(): + // Caller canceled context before server was ready. + return ctx.Err() + case <-b.srvReadyCh: + // Server ready for listeners! Continue on... + } + + return b.srv.Serve(l) +} + // HTTPCodedError is used to provide the HTTP error code type HTTPCodedError interface { error @@ -591,7 +655,7 @@ func errCodeFromHandler(err error) (int, string) { // wrap is used to wrap functions to make them more convenient func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Request) (interface{}, error)) func(resp http.ResponseWriter, req *http.Request) { f := func(resp http.ResponseWriter, req *http.Request) { - setHeaders(resp, s.agent.config.HTTPAPIResponseHeaders) + setHeaders(resp, s.agent.GetConfig().HTTPAPIResponseHeaders) // Invoke the handler reqURL := req.URL.String() start := time.Now() @@ -673,7 +737,7 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque // Handler functions are responsible for setting Content-Type Header func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *http.Request) ([]byte, error)) func(resp http.ResponseWriter, req *http.Request) { f := func(resp http.ResponseWriter, req *http.Request) { - setHeaders(resp, s.agent.config.HTTPAPIResponseHeaders) + setHeaders(resp, s.agent.GetConfig().HTTPAPIResponseHeaders) // Invoke the handler reqURL := req.URL.String() start := time.Now() @@ -817,7 +881,7 @@ func (s *HTTPServer) parseRegion(req *http.Request, r *string) { if other := req.URL.Query().Get("region"); other != "" { *r = other } else if *r == "" { - *r = s.agent.config.Region + *r = s.agent.GetConfig().Region } } @@ -976,3 +1040,55 @@ func wrapCORS(f func(http.ResponseWriter, *http.Request)) http.Handler { func wrapCORSWithAllowedMethods(f func(http.ResponseWriter, *http.Request), methods ...string) http.Handler { return allowCORSWithMethods(methods...).Handler(http.HandlerFunc(f)) } + +// authMiddleware implements the http.Handler interface to enforce +// authentication for *all* requests. Even with ACLs enabled there are +// endpoints which are accessible without authenticating. This middleware is +// used for the Task API to enfoce authentication for all API access. +type authMiddleware struct { + srv *HTTPServer + wrapped http.Handler +} + +func newAuthMiddleware(srv *HTTPServer, h http.Handler) http.Handler { + return &authMiddleware{ + srv: srv, + wrapped: h, + } +} + +func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + args := structs.GenericRequest{} + reply := structs.ACLWhoAmIResponse{} + if a.srv.parse(resp, req, &args.Region, &args.QueryOptions) { + // Error parsing request, 400 + resp.WriteHeader(http.StatusBadRequest) + resp.Write([]byte(http.StatusText(http.StatusBadRequest))) + return + } + + if args.AuthToken == "" { + // 401 instead of 403 since no token was present. + resp.WriteHeader(http.StatusUnauthorized) + resp.Write([]byte(http.StatusText(http.StatusUnauthorized))) + return + } + + if err := a.srv.agent.RPC("ACL.WhoAmI", &args, &reply); err != nil { + a.srv.logger.Error("error authenticating built API request", "error", err, "url", req.URL, "method", req.Method) + resp.WriteHeader(500) + resp.Write([]byte("Server error authenticating request\n")) + return + } + + // Require an acl token or workload identity + if reply.Identity == nil || (reply.Identity.ACLToken == nil && reply.Identity.Claims == nil) { + a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL) + resp.WriteHeader(http.StatusForbidden) + resp.Write([]byte(http.StatusText(http.StatusForbidden))) + return + } + + a.srv.logger.Trace("Authenticated request", "id", reply.Identity, "method", req.Method, "url", req.URL) + a.wrapped.ServeHTTP(resp, req) +} diff --git a/command/agent/job_endpoint.go b/command/agent/job_endpoint.go index 47f5fce9c..407e782b6 100644 --- a/command/agent/job_endpoint.go +++ b/command/agent/job_endpoint.go @@ -819,7 +819,7 @@ func (s *HTTPServer) apiJobAndRequestToStructs(job *api.Job, req *http.Request, queryRegion := req.URL.Query().Get("region") requestRegion, jobRegion := regionForJob( - job, queryRegion, writeReq.Region, s.agent.config.Region, + job, queryRegion, writeReq.Region, s.agent.GetConfig().Region, ) sJob := ApiJobToStructJob(job) diff --git a/command/agent/metrics_endpoint.go b/command/agent/metrics_endpoint.go index 7233492ae..90686cb3e 100644 --- a/command/agent/metrics_endpoint.go +++ b/command/agent/metrics_endpoint.go @@ -25,14 +25,14 @@ func (s *HTTPServer) MetricsRequest(resp http.ResponseWriter, req *http.Request) // Only return Prometheus formatted metrics if the user has enabled // this functionality. - if !s.agent.config.Telemetry.PrometheusMetrics { + if !s.agent.GetConfig().Telemetry.PrometheusMetrics { return nil, CodedError(http.StatusUnsupportedMediaType, "Prometheus is not enabled") } s.prometheusHandler().ServeHTTP(resp, req) return nil, nil } - return s.agent.InmemSink.DisplayMetrics(resp, req) + return s.agent.GetMetricsSink().DisplayMetrics(resp, req) } func (s *HTTPServer) prometheusHandler() http.Handler { diff --git a/command/agent/variable_endpoint.go b/command/agent/variable_endpoint.go index 17f55ac9c..bbf8b03bc 100644 --- a/command/agent/variable_endpoint.go +++ b/command/agent/variable_endpoint.go @@ -16,7 +16,8 @@ func (s *HTTPServer) VariablesListRequest(resp http.ResponseWriter, req *http.Re args := structs.VariablesListRequest{} if s.parse(resp, req, &args.Region, &args.QueryOptions) { - return nil, nil + //TODO(schmichael) shouldn't we return something here?! + return nil, CodedError(http.StatusBadRequest, "failed to parse parameters") } var out structs.VariablesListResponse diff --git a/e2e/workload_id/input/api-auth.nomad.hcl b/e2e/workload_id/input/api-auth.nomad.hcl new file mode 100644 index 000000000..dae134697 --- /dev/null +++ b/e2e/workload_id/input/api-auth.nomad.hcl @@ -0,0 +1,99 @@ +job "api-auth" { + datacenters = ["dc1"] + type = "batch" + + constraint { + attribute = "${attr.kernel.name}" + value = "linux" + } + + group "api-auth" { + + # none task should get a 401 response + task "none" { + driver = "docker" + config { + image = "curlimages/curl:7.87.0" + args = [ + "--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock", + "-v", + "localhost/v1/agent/health", + ] + } + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + + # bad task should get a 403 response + task "bad" { + driver = "docker" + config { + image = "curlimages/curl:7.87.0" + args = [ + "--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock", + "-H", "X-Nomad-Token: 37297754-3b87-41da-9ac7-d98fd934deed", + "-v", + "localhost/v1/agent/health", + ] + } + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + + # docker-wid task should succeed due to using workload identity + task "docker-wid" { + driver = "docker" + + config { + image = "curlimages/curl:7.87.0" + args = [ + "--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock", + "-H", "Authorization: Bearer ${NOMAD_TOKEN}", + "-v", + "localhost/v1/agent/health", + ] + } + + identity { + env = true + } + + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + + # exec-wid task should succeed due to using workload identity + task "exec-wid" { + driver = "exec" + + config { + command = "curl" + args = [ + "-H", "Authorization: Bearer ${NOMAD_TOKEN}", + "--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock", + "-v", + "localhost/v1/agent/health", + ] + } + + identity { + env = true + } + + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + } +} diff --git a/e2e/workload_id/input/api-win.nomad.hcl b/e2e/workload_id/input/api-win.nomad.hcl new file mode 100644 index 000000000..7552500f3 --- /dev/null +++ b/e2e/workload_id/input/api-win.nomad.hcl @@ -0,0 +1,36 @@ +job "api-win" { + datacenters = ["dc1"] + type = "batch" + + constraint { + attribute = "${attr.kernel.name}" + value = "windows" + } + + constraint { + attribute = "${attr.cpu.arch}" + value = "amd64" + } + + group "api-win" { + + task "win" { + driver = "raw_exec" + config { + command = "powershell" + args = ["local/curl-7.87.0_4-win64-mingw/bin/curl.exe -H \"Authorization: Bearer $env:NOMAD_TOKEN\" --unix-socket $env:NOMAD_SECRETS_DIR/api.sock -v localhost:4646/v1/agent/health"] + } + artifact { + source = "https://curl.se/windows/dl-7.87.0_4/curl-7.87.0_4-win64-mingw.zip" + } + identity { + env = true + } + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + } +} diff --git a/e2e/workload_id/taskapi_test.go b/e2e/workload_id/taskapi_test.go new file mode 100644 index 000000000..3c636d367 --- /dev/null +++ b/e2e/workload_id/taskapi_test.go @@ -0,0 +1,111 @@ +package main + +import ( + "fmt" + "io" + "net/http" + "testing" + + "github.com/hashicorp/nomad/e2e/e2eutil" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/shoenig/test" + "github.com/shoenig/test/must" +) + +// TestTaskAPI runs subtets exercising the Task API related functionality. +// Bundled with Workload Identity as that's a prereq for the Task API to work. +func TestTaskAPI(t *testing.T) { + nomad := e2eutil.NomadClient(t) + + e2eutil.WaitForLeader(t, nomad) + e2eutil.WaitForNodesReady(t, nomad, 1) + + t.Run("testTaskAPI_Auth", testTaskAPIAuth) + t.Run("testTaskAPI_Windows", testTaskAPIWindows) +} + +func testTaskAPIAuth(t *testing.T) { + nomad := e2eutil.NomadClient(t) + jobID := "api-auth-" + uuid.Short() + jobIDs := []string{jobID} + t.Cleanup(e2eutil.CleanupJobsAndGC(t, &jobIDs)) + + // start job + allocs := e2eutil.RegisterAndWaitForAllocs(t, nomad, "./input/api-auth.nomad.hcl", jobID, "") + must.Len(t, 1, allocs) + allocID := allocs[0].ID + + // wait for batch alloc to complete + alloc := e2eutil.WaitForAllocStopped(t, nomad, allocID) + must.Eq(t, alloc.ClientStatus, "complete") + + assertions := []struct { + task string + suffix string + }{ + { + task: "none", + suffix: http.StatusText(http.StatusUnauthorized), + }, + { + task: "bad", + suffix: http.StatusText(http.StatusForbidden), + }, + { + task: "docker-wid", + suffix: `"ok":true}}`, + }, + { + task: "exec-wid", + suffix: `"ok":true}}`, + }, + } + + // Ensure the assertions and input file match + must.Len(t, len(assertions), alloc.Job.TaskGroups[0].Tasks, + must.Sprintf("test and jobspec mismatch")) + + for _, tc := range assertions { + logFile := fmt.Sprintf("alloc/logs/%s.stdout.0", tc.task) + fd, err := nomad.AllocFS().Cat(alloc, logFile, nil) + must.NoError(t, err) + logBytes, err := io.ReadAll(fd) + must.NoError(t, err) + logs := string(logBytes) + + ps := must.Sprintf("Task: %s Logs: <