Task API via Unix Domain Socket (#15864)

This change introduces the Task API: a portable way for tasks to access Nomad's HTTP API. This particular implementation uses a Unix Domain Socket and, unlike the agent's HTTP API, always requires authentication even if ACLs are disabled.

This PR contains the core feature and tests but followup work is required for the following TODO items:

- Docs - might do in a followup since dynamic node metadata / task api / workload id all need to interlink
- Unit tests for auth middleware
- Caching for auth middleware
- Rate limiting on negative lookups for auth middleware

---------

Co-authored-by: Seth Hoenig <shoenig@duck.com>
This commit is contained in:
Michael Schurter 2023-02-06 11:31:22 -08:00 committed by GitHub
parent 21895fb6f0
commit 0a496c845e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 876 additions and 38 deletions

3
.changelog/15864.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
client: added http api access for tasks via unix socket
```

View File

@ -33,8 +33,6 @@ import (
+-----------+
*Kill
(forces terminal)
Link: http://stable.ascii-flow.appspot.com/#Draw4489375405966393064/1824429135
*/
// TaskHook is a lifecycle hook into the life cycle of a task runner.
@ -186,6 +184,9 @@ type TaskStopRequest struct {
// ExistingState is previously set hook data and should only be
// read. Stop hooks cannot alter state.
ExistingState map[string]string
// TaskDir contains the task's directory tree on the host
TaskDir *allocdir.TaskDir
}
type TaskStopResponse struct{}

View File

@ -0,0 +1,119 @@
package taskrunner
import (
"context"
"errors"
"net"
"net/http"
"os"
"path/filepath"
"sync"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/helper/users"
)
// apiHook exposes the Task API. The Task API allows tasks to access the Nomad
// HTTP API without having to discover and connect to an agent's address.
// Instead a unix socket is provided in a standard location. To prevent access
// by untrusted workloads the Task API always requires authentication even when
// ACLs are disabled.
//
// The Task API hook largely soft-fails as there are a number of ways creating
// the unix socket could fail (the most common one being path length
// restrictions), and it is assumed most tasks won't require access to the Task
// API anyway. Tasks that do require access are expected to crash and get
// rescheduled should they land on a client whose Task API hook soft-fails.
type apiHook struct {
// shutdownCtx is the context handed to srv.Serve when the listener is served.
shutdownCtx context.Context
// srv serves the HTTP API on the listener created in Prestart.
srv config.APIListenerRegistrar
// logger is the hook's own named logger (named after Name()).
logger hclog.Logger
// Lock listener as it is updated from multiple hooks.
lock sync.Mutex
// Listener is the unix domain socket of the task api for this task. It is
// set by Prestart and reset to nil by Stop.
ln net.Listener
}
// newAPIHook builds the Task API hook. The supplied logger is narrowed to a
// sub-logger named after the hook; no listener is created until Prestart runs.
func newAPIHook(shutdownCtx context.Context, srv config.APIListenerRegistrar, logger hclog.Logger) *apiHook {
	hook := new(apiHook)
	hook.shutdownCtx = shutdownCtx
	hook.srv = srv
	hook.logger = logger.Named(hook.Name())
	return hook
}
// Name returns the hook's identifier, which is also used to name its logger.
func (*apiHook) Name() string {
	const hookName = "api"
	return hookName
}
// Prestart creates the Task API unix socket for the task and starts serving
// the HTTP API on it in a background goroutine. It always returns nil: a
// failure to create the socket is logged as a warning and otherwise ignored
// so that tasks which do not need the Task API are unaffected.
func (h *apiHook) Prestart(_ context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
	h.lock.Lock()
	defer h.lock.Unlock()

	// A listener that is still set means the task is probably restarting;
	// keep using it.
	if h.ln != nil {
		return nil
	}

	sockPath := apiSocketPath(req.TaskDir)
	listener, err := users.SocketFileFor(h.logger, sockPath, req.Task.User)
	if err != nil {
		// Soft-fail and let the task fail if it requires the task api.
		h.logger.Warn("error creating task api socket", "path", sockPath, "error", err)
		return nil
	}

	go func() {
		// Prestart's own context cannot be used here because it is closed
		// once all prestart hooks have run, whereas serving should continue
		// until shutdown.
		serveErr := h.srv.Serve(h.shutdownCtx, listener)
		switch {
		case serveErr == nil:
			// Nothing to report.
		case errors.Is(serveErr, http.ErrServerClosed), errors.Is(serveErr, net.ErrClosed):
			// Server or listener was shut down; nothing to report.
		default:
			h.logger.Error("error serving task api", "error", serveErr)
		}
	}()

	h.ln = listener
	return nil
}
// Stop closes the Task API listener, if one is set, and removes the socket
// file. It always returns nil: a close failure is only logged and the file
// removal is best-effort.
func (h *apiHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
	h.lock.Lock()
	defer h.lock.Unlock()

	if h.ln != nil {
		// An already-closed listener is not worth reporting.
		if err := h.ln.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			// hclog takes alternating key/value pairs rather than printf-style
			// verbs, so the error is passed as a field.
			h.logger.Debug("error closing task listener", "error", err)
		}
		h.ln = nil
	}

	// Best-effort at cleaning things up. Alloc dir cleanup will remove it if
	// this fails for any reason.
	_ = os.RemoveAll(apiSocketPath(req.TaskDir))

	return nil
}
// apiSocketPath returns the location of the Task API socket: a file named
// api.sock inside the task's secrets directory.
//
// The path is kept as short as possible because the syscall used to create
// unix sockets imposes a low limit on the sun_path char array.
//
// See https://github.com/hashicorp/nomad/pull/13971 for an example of the
// sadness this causes.
func apiSocketPath(taskDir *allocdir.TaskDir) string {
	const socketName = "api.sock"
	return filepath.Join(taskDir.SecretsDir, socketName)
}

View File

@ -0,0 +1,169 @@
package taskrunner
import (
"context"
"io/fs"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"testing"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/users"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/shoenig/test/must"
)
// testAPIListenerRegistrar is a test double for the listener registrar. When
// cb is set, Serve delegates to it; otherwise Serve is a no-op.
type testAPIListenerRegistrar struct {
	cb func(net.Listener) error
}

// Serve hands the listener to the configured callback and returns its result,
// or returns nil when no callback was configured. The context is ignored.
func (n testAPIListenerRegistrar) Serve(_ context.Context, ln net.Listener) error {
	if n.cb == nil {
		return nil
	}
	return n.cb(ln)
}
// TestAPIHook_SoftFail asserts that the Task API hook swallows socket creation
// failures: neither Prestart nor Stop returns an error.
func TestAPIHook_SoftFail(t *testing.T) {
	ci.Parallel(t)

	// A SecretsDir this long will always exceed Unix socket path length
	// limits (sun_path).
	secretsDir := filepath.Join(t.TempDir(), strings.Repeat("_NOMAD_TEST_", 100))

	ctx := context.Background()
	registrar := testAPIListenerRegistrar{}
	hook := newAPIHook(ctx, registrar, testlog.HCLogger(t))

	taskDir := &allocdir.TaskDir{SecretsDir: secretsDir}
	prestartReq := &interfaces.TaskPrestartRequest{
		Task:    &structs.Task{}, // needs to be non-nil for Task.User lookup
		TaskDir: taskDir,
	}

	must.NoError(t, hook.Prestart(ctx, prestartReq, &interfaces.TaskPrestartResponse{}))

	// listener should not have been set
	must.Nil(t, hook.ln)

	// File should not have been created
	_, statErr := os.Stat(secretsDir)
	must.Error(t, statErr)

	// Assert stop also soft-fails
	stopReq := &interfaces.TaskStopRequest{TaskDir: taskDir}
	must.NoError(t, hook.Stop(ctx, stopReq, &interfaces.TaskStopResponse{}))

	// File should not have been created
	_, statErr = os.Stat(secretsDir)
	must.Error(t, statErr)
}
// TestAPIHook_Ok asserts that the Task API Hook creates and cleans up a
// socket.
func TestAPIHook_Ok(t *testing.T) {
	ci.Parallel(t)

	// If this test fails it may be because TempDir() + /api.sock is longer than
	// the unix socket path length limit (sun_path) in which case the test should
	// use a different temporary directory on that platform.
	dst := t.TempDir()

	// Accept a single connection, write "ok" to it and close it.
	srv := testAPIListenerRegistrar{
		cb: func(ln net.Listener) error {
			conn, err := ln.Accept()
			if err != nil {
				return err
			}
			// Deferred so the connection is released even when the write fails.
			defer conn.Close()
			if _, err = conn.Write([]byte("ok")); err != nil {
				return err
			}
			return nil
		},
	}

	ctx := context.Background()
	logger := testlog.HCLogger(t)
	h := newAPIHook(ctx, srv, logger)

	req := &interfaces.TaskPrestartRequest{
		Task: &structs.Task{
			User: "nobody",
		},
		TaskDir: &allocdir.TaskDir{
			SecretsDir: dst,
		},
	}
	resp := &interfaces.TaskPrestartResponse{}

	err := h.Prestart(ctx, req, resp)
	must.NoError(t, err)

	// File should have been created
	sockDst := apiSocketPath(req.TaskDir)

	// Stat and chown fail on Windows, so skip these checks
	if runtime.GOOS != "windows" {
		stat, err := os.Stat(sockDst)
		must.NoError(t, err)
		must.True(t, stat.Mode()&fs.ModeSocket != 0,
			must.Sprintf("expected %q to be a unix socket but got %s", sockDst, stat.Mode()))

		// Ownership and mode are only checked when running as root and the
		// "nobody" user can be looked up.
		nobody, _ := users.Lookup("nobody")
		if syscall.Getuid() == 0 && nobody != nil {
			t.Logf("root and nobody exists: testing file perms")

			// We're root and nobody exists! Check perms
			must.Eq(t, fs.FileMode(0o600), stat.Mode().Perm())

			sysStat, ok := stat.Sys().(*syscall.Stat_t)
			must.True(t, ok, must.Sprintf("expected stat.Sys() to be a *syscall.Stat_t on %s but found %T",
				runtime.GOOS, stat.Sys()))

			nobodyUID, err := strconv.Atoi(nobody.Uid)
			must.NoError(t, err)
			must.Eq(t, nobodyUID, int(sysStat.Uid))
		}
	}

	// Assert the listener is working
	conn, err := net.Dial("unix", sockDst)
	must.NoError(t, err)
	buf := make([]byte, 2)
	_, err = conn.Read(buf)
	must.NoError(t, err)
	must.Eq(t, []byte("ok"), buf)
	conn.Close()

	// Assert stop cleans up
	stopReq := &interfaces.TaskStopRequest{
		TaskDir: req.TaskDir,
	}
	stopResp := &interfaces.TaskStopResponse{}
	err = h.Stop(ctx, stopReq, stopResp)
	must.NoError(t, err)

	// File should be gone
	_, err = net.Dial("unix", sockDst)
	must.Error(t, err)
}

View File

@ -68,6 +68,7 @@ func (tr *TaskRunner) initHooks() {
newArtifactHook(tr, tr.getter, hookLogger),
newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
newDeviceHook(tr.devicemanager, hookLogger),
newAPIHook(tr.shutdownCtx, tr.clientConfig.APIListenerRegistrar, hookLogger),
}
// If the task has a CSI block, add the hook.
@ -431,7 +432,9 @@ func (tr *TaskRunner) stop() error {
tr.logger.Trace("running stop hook", "name", name, "start", start)
}
req := interfaces.TaskStopRequest{}
req := interfaces.TaskStopRequest{
TaskDir: tr.taskDir,
}
origHookState := tr.hookState(name)
if origHookState != nil {

View File

@ -1,8 +1,10 @@
package config
import (
"context"
"errors"
"fmt"
"net"
"reflect"
"strconv"
"strings"
@ -301,10 +303,27 @@ type Config struct {
// used for template functions which require access to the Nomad API.
TemplateDialer *bufconndialer.BufConnWrapper
// APIListenerRegistrar allows the client to register listeners created at
// runtime (eg the Task API) with the agent's HTTP server. Since the agent
// creates the HTTP *after* the client starts, we have to use this shim to
// pass listeners back to the agent.
// This is the same design as the bufconndialer but for the
// http.Serve(listener) API instead of the net.Dial API.
APIListenerRegistrar APIListenerRegistrar
// Artifact configuration from the agent's config file.
Artifact *ArtifactConfig
}
// APIListenerRegistrar lets the client hand listeners created at runtime (eg
// the Task API) to the agent's HTTP server.
type APIListenerRegistrar interface {
// Serve the HTTP API on the provided listener.
//
// The context exists because Serve may be called before the HTTP server has
// been initialized. If the context is canceled before the HTTP server is
// initialized, the context's error will be returned.
Serve(context.Context, net.Listener) error
}
// ClientTemplateConfig is configuration on the client specific to template
// rendering
type ClientTemplateConfig struct {

View File

@ -1,7 +1,9 @@
package config
import (
"context"
"io/ioutil"
"net"
"os"
"path/filepath"
"time"
@ -74,5 +76,14 @@ func TestClientConfig(t testing.T) (*Config, func()) {
// Same as default; necessary for task Event messages
conf.MaxKillTimeout = 30 * time.Second
// Provide a stub APIListenerRegistrar implementation
conf.APIListenerRegistrar = NoopAPIListenerRegistrar{}
return conf, cleanup
}
// NoopAPIListenerRegistrar is a stub APIListenerRegistrar for tests: it never
// serves anything on the listeners it is given.
type NoopAPIListenerRegistrar struct{}

// Serve ignores both arguments and returns nil immediately.
func (NoopAPIListenerRegistrar) Serve(context.Context, net.Listener) error {
	return nil
}

View File

@ -115,7 +115,11 @@ type Agent struct {
builtinListener net.Listener
builtinDialer *bufconndialer.BufConnWrapper
InmemSink *metrics.InmemSink
// builtinServer is an HTTP server for attaching per-task listeners. Always
// requires auth.
builtinServer *builtinAPI
inmemSink *metrics.InmemSink
}
// NewAgent is used to create a new agent with the given configuration
@ -124,7 +128,7 @@ func NewAgent(config *Config, logger log.InterceptLogger, logOutput io.Writer, i
config: config,
logOutput: logOutput,
shutdownCh: make(chan struct{}),
InmemSink: inmem,
inmemSink: inmem,
}
// Create the loggers
@ -1020,6 +1024,11 @@ func (a *Agent) setupClient() error {
a.builtinListener, a.builtinDialer = bufconndialer.New()
conf.TemplateDialer = a.builtinDialer
// Initialize builtin API server here for use in the client, but it won't
// accept connections until the HTTP servers are created.
a.builtinServer = newBuiltinAPI()
conf.APIListenerRegistrar = a.builtinServer
nomadClient, err := client.NewClient(
conf, a.consulCatalog, a.consulProxies, a.consulService, nil)
if err != nil {
@ -1300,6 +1309,11 @@ func (a *Agent) GetConfig() *Config {
return a.config
}
// GetMetricsSink exposes the agent's in-memory metrics sink.
func (a *Agent) GetMetricsSink() *metrics.InmemSink {
	sink := a.inmemSink
	return sink
}
// setupConsul creates the Consul client and starts its main Run loop.
func (a *Agent) setupConsul(consulConfig *config.ConsulConfig) error {
apiConf, err := consulConfig.ApiConfig()

View File

@ -440,7 +440,7 @@ func (s *HTTPServer) listServers(resp http.ResponseWriter, req *http.Request) (i
return nil, structs.ErrPermissionDenied
}
peers := s.agent.client.GetServers()
peers := client.GetServers()
sort.Strings(peers)
return peers, nil
}
@ -468,9 +468,9 @@ func (s *HTTPServer) updateServers(resp http.ResponseWriter, req *http.Request)
}
// Set the servers list into the client
s.agent.logger.Trace("adding servers to the client's primary server list", "servers", servers, "path", "/v1/agent/servers", "method", "PUT")
s.logger.Trace("adding servers to the client's primary server list", "servers", servers, "path", "/v1/agent/servers", "method", "PUT")
if _, err := client.SetServers(servers); err != nil {
s.agent.logger.Error("failed adding servers to client's server list", "servers", servers, "error", err, "path", "/v1/agent/servers", "method", "PUT")
s.logger.Error("failed adding servers to client's server list", "servers", servers, "error", err, "path", "/v1/agent/servers", "method", "PUT")
//TODO is this the right error to return?
return nil, CodedError(400, err.Error())
}
@ -708,7 +708,7 @@ func (s *HTTPServer) AgentHostRequest(resp http.ResponseWriter, req *http.Reques
// The RPC endpoint actually forwards the request to the correct
// agent, but we need to use the correct RPC interface.
localClient, remoteClient, localServer := s.rpcHandlerForNode(lookupNodeID)
s.agent.logger.Debug("s.rpcHandlerForNode()", "lookupNodeID", lookupNodeID, "serverID", serverID, "nodeID", nodeID, "localClient", localClient, "remoteClient", remoteClient, "localServer", localServer)
s.logger.Debug("s.rpcHandlerForNode()", "lookupNodeID", lookupNodeID, "serverID", serverID, "nodeID", nodeID, "localClient", localClient, "remoteClient", remoteClient, "localServer", localServer)
// Make the RPC call
if localClient {

View File

@ -222,7 +222,7 @@ func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Requ
case "exec":
return s.allocExec(allocID, resp, req)
case "snapshot":
if s.agent.client == nil {
if s.agent.Client() == nil {
return nil, clientNotRunning
}
return s.allocSnapshot(allocID, resp, req)

View File

@ -46,7 +46,7 @@ func TestConsul_Integration(t *testing.T) {
// Create an embedded Consul server
testconsul, err := testutil.NewTestServerConfigT(t, func(c *testutil.TestServerConfig) {
c.Peering = nil // fix for older versions of Consul (<1.13.0) that don't support peering
c.Peering = nil // fix for older versions of Consul (<1.13.0) that don't support peering
// If -v wasn't specified squelch consul logging
if !testing.Verbose() {
c.Stdout = ioutil.Discard
@ -61,6 +61,7 @@ func TestConsul_Integration(t *testing.T) {
conf := config.DefaultConfig()
conf.Node = mock.Node()
conf.ConsulConfig.Addr = testconsul.HTTPAddr
conf.APIListenerRegistrar = config.NoopAPIListenerRegistrar{}
consulConfig, err := conf.ConsulConfig.ApiConfig()
if err != nil {
t.Fatalf("error generating consul config: %v", err)

View File

@ -2,6 +2,7 @@ package agent
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
@ -26,8 +27,10 @@ import (
"golang.org/x/time/rate"
"github.com/hashicorp/nomad/acl"
"github.com/hashicorp/nomad/client"
"github.com/hashicorp/nomad/helper/noxssrw"
"github.com/hashicorp/nomad/helper/tlsutil"
"github.com/hashicorp/nomad/nomad"
"github.com/hashicorp/nomad/nomad/structs"
)
@ -74,9 +77,18 @@ var (
type handlerFn func(resp http.ResponseWriter, req *http.Request) (interface{}, error)
type handlerByteFn func(resp http.ResponseWriter, req *http.Request) ([]byte, error)
// RPCer is the interface HTTPServer uses to reach its agent.
type RPCer interface {
// RPC invokes the named RPC method with the given argument and reply values,
// e.g. RPC("ACL.WhoAmI", &args, &reply).
RPC(string, any, any) error
// Server returns the agent's Nomad server.
Server() *nomad.Server
// Client returns the agent's Nomad client; callers compare the result
// against nil before using it.
Client() *client.Client
Stats() map[string]map[string]string
// GetConfig returns the agent's configuration.
GetConfig() *Config
// GetMetricsSink returns the agent's in-memory metrics sink.
GetMetricsSink() *metrics.InmemSink
}
// HTTPServer is used to wrap an Agent and expose it over an HTTP interface
type HTTPServer struct {
agent *Agent
agent RPCer
mux *http.ServeMux
listener net.Listener
listenerCh chan struct{}
@ -170,7 +182,7 @@ func NewHTTPServers(agent *Agent, config *Config) ([]*HTTPServer, error) {
srvs = append(srvs, srv)
}
// This HTTP server is only create when running in client mode, otherwise
// This HTTP server is only created when running in client mode, otherwise
// the builtinDialer and builtinListener will be nil.
if agent.builtinDialer != nil && agent.builtinListener != nil {
srv := &HTTPServer{
@ -185,12 +197,15 @@ func NewHTTPServers(agent *Agent, config *Config) ([]*HTTPServer, error) {
srv.registerHandlers(config.EnableDebug)
// builtinServer adds a wrapper to always authenticate requests
httpServer := http.Server{
Addr: srv.Addr,
Handler: srv.mux,
Handler: newAuthMiddleware(srv, srv.mux),
ErrorLog: newHTTPServerLogger(srv.logger),
}
agent.builtinServer.SetServer(&httpServer)
go func() {
defer close(srv.listenerCh)
httpServer.Serve(agent.builtinListener)
@ -465,7 +480,8 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) {
s.mux.Handle("/v1/vars", wrapCORS(s.wrap(s.VariablesListRequest)))
s.mux.Handle("/v1/var/", wrapCORSWithAllowedMethods(s.wrap(s.VariableSpecificRequest), "HEAD", "GET", "PUT", "DELETE"))
uiConfigEnabled := s.agent.config.UI != nil && s.agent.config.UI.Enabled
agentConfig := s.agent.GetConfig()
uiConfigEnabled := agentConfig.UI != nil && agentConfig.UI.Enabled
if uiEnabled && uiConfigEnabled {
s.mux.Handle("/ui/", http.StripPrefix("/ui/", s.handleUI(http.FileServer(&UIAssetWrapper{FileSystem: assetFS()}))))
@ -484,7 +500,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) {
s.mux.Handle("/", s.handleRootFallthrough())
if enableDebug {
if !s.agent.config.DevMode {
if !agentConfig.DevMode {
s.logger.Warn("enable_debug is set to true. This is insecure and should not be enabled in production")
}
s.mux.HandleFunc("/debug/pprof/", pprof.Index)
@ -498,6 +514,54 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) {
s.registerEnterpriseHandlers()
}
// builtinAPI serves the HTTP API on arbitrary listeners such as those of the
// Task API. The HTTP servers are created *after* the client has been
// initialized, so Serve calls coming from task api hooks block here until
// SetServer has supplied the HTTP server.
//
// bufconndialer provides similar functionality to consul-template except it
// satisfies the Dialer API as opposed to the Serve(Listener) API.
type builtinAPI struct {
	// srv is the HTTP server listeners are attached to; set by SetServer.
	srv *http.Server

	// srvReadyCh is closed by SetServer once srv has been assigned.
	srvReadyCh chan struct{}
}

// newBuiltinAPI returns a builtinAPI that has no server yet; Serve blocks
// until SetServer is called or the caller's context is canceled.
func newBuiltinAPI() *builtinAPI {
	api := new(builtinAPI)
	api.srvReadyCh = make(chan struct{})
	return api
}

// SetServer sets the API HTTP server for Serve to add listeners to.
//
// It must be called exactly once and will panic if called more than once.
func (b *builtinAPI) SetServer(srv *http.Server) {
	select {
	case <-b.srvReadyCh:
		// The ready channel is already closed, so a server was set before.
		panic(fmt.Sprintf("SetServer called twice. first=%p second=%p", b.srv, srv))
	default:
		b.srv = srv
		close(b.srvReadyCh)
	}
}

// Serve the HTTP API on the listener unless the context is canceled before the
// HTTP API is ready to serve listeners. A non-nil error will always be
// returned, but http.ErrServerClosed and net.ErrClosed can likely be ignored
// as they indicate the server or listener is being shutdown.
func (b *builtinAPI) Serve(ctx context.Context, l net.Listener) error {
	select {
	case <-b.srvReadyCh:
		// Server ready for listeners!
		return b.srv.Serve(l)
	case <-ctx.Done():
		// Caller canceled context before server was ready.
		return ctx.Err()
	}
}
// HTTPCodedError is used to provide the HTTP error code
type HTTPCodedError interface {
error
@ -591,7 +655,7 @@ func errCodeFromHandler(err error) (int, string) {
// wrap is used to wrap functions to make them more convenient
func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Request) (interface{}, error)) func(resp http.ResponseWriter, req *http.Request) {
f := func(resp http.ResponseWriter, req *http.Request) {
setHeaders(resp, s.agent.config.HTTPAPIResponseHeaders)
setHeaders(resp, s.agent.GetConfig().HTTPAPIResponseHeaders)
// Invoke the handler
reqURL := req.URL.String()
start := time.Now()
@ -673,7 +737,7 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque
// Handler functions are responsible for setting Content-Type Header
func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *http.Request) ([]byte, error)) func(resp http.ResponseWriter, req *http.Request) {
f := func(resp http.ResponseWriter, req *http.Request) {
setHeaders(resp, s.agent.config.HTTPAPIResponseHeaders)
setHeaders(resp, s.agent.GetConfig().HTTPAPIResponseHeaders)
// Invoke the handler
reqURL := req.URL.String()
start := time.Now()
@ -817,7 +881,7 @@ func (s *HTTPServer) parseRegion(req *http.Request, r *string) {
if other := req.URL.Query().Get("region"); other != "" {
*r = other
} else if *r == "" {
*r = s.agent.config.Region
*r = s.agent.GetConfig().Region
}
}
@ -976,3 +1040,55 @@ func wrapCORS(f func(http.ResponseWriter, *http.Request)) http.Handler {
func wrapCORSWithAllowedMethods(f func(http.ResponseWriter, *http.Request), methods ...string) http.Handler {
return allowCORSWithMethods(methods...).Handler(http.HandlerFunc(f))
}
// authMiddleware is an http.Handler that enforces authentication for *all*
// requests before passing them to the wrapped handler. Even with ACLs enabled
// there are endpoints which are accessible without authenticating, so this
// middleware is used for the Task API to enforce authentication for all API
// access.
type authMiddleware struct {
	// srv supplies request parsing, logging and the agent used to
	// authenticate requests.
	srv *HTTPServer

	// wrapped is invoked only after a request has been authenticated.
	wrapped http.Handler
}

// newAuthMiddleware wraps h so that every request is authenticated through
// srv before h sees it.
func newAuthMiddleware(srv *HTTPServer, h http.Handler) http.Handler {
	mw := new(authMiddleware)
	mw.srv = srv
	mw.wrapped = h
	return mw
}
// ServeHTTP authenticates the request and, only on success, forwards it to
// the wrapped handler. Responses on failure:
//
//   - 400 when the request cannot be parsed
//   - 401 when no auth token was supplied
//   - 500 when the ACL.WhoAmI RPC itself fails
//   - 403 when the token resolves to neither an ACL token nor workload
//     identity claims
func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	// reject writes the status code followed by a plain-text body.
	reject := func(code int, body string) {
		resp.WriteHeader(code)
		resp.Write([]byte(body))
	}

	var (
		args  structs.GenericRequest
		reply structs.ACLWhoAmIResponse
	)

	if a.srv.parse(resp, req, &args.Region, &args.QueryOptions) {
		// Error parsing request, 400
		reject(http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
		return
	}

	if args.AuthToken == "" {
		// 401 instead of 403 since no token was present.
		reject(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
		return
	}

	if err := a.srv.agent.RPC("ACL.WhoAmI", &args, &reply); err != nil {
		a.srv.logger.Error("error authenticating built API request", "error", err, "url", req.URL, "method", req.Method)
		reject(500, "Server error authenticating request\n")
		return
	}

	// Require an acl token or workload identity
	identity := reply.Identity
	authenticated := identity != nil && (identity.ACLToken != nil || identity.Claims != nil)
	if !authenticated {
		a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL)
		reject(http.StatusForbidden, http.StatusText(http.StatusForbidden))
		return
	}

	a.srv.logger.Trace("Authenticated request", "id", identity, "method", req.Method, "url", req.URL)
	a.wrapped.ServeHTTP(resp, req)
}

View File

@ -819,7 +819,7 @@ func (s *HTTPServer) apiJobAndRequestToStructs(job *api.Job, req *http.Request,
queryRegion := req.URL.Query().Get("region")
requestRegion, jobRegion := regionForJob(
job, queryRegion, writeReq.Region, s.agent.config.Region,
job, queryRegion, writeReq.Region, s.agent.GetConfig().Region,
)
sJob := ApiJobToStructJob(job)

View File

@ -25,14 +25,14 @@ func (s *HTTPServer) MetricsRequest(resp http.ResponseWriter, req *http.Request)
// Only return Prometheus formatted metrics if the user has enabled
// this functionality.
if !s.agent.config.Telemetry.PrometheusMetrics {
if !s.agent.GetConfig().Telemetry.PrometheusMetrics {
return nil, CodedError(http.StatusUnsupportedMediaType, "Prometheus is not enabled")
}
s.prometheusHandler().ServeHTTP(resp, req)
return nil, nil
}
return s.agent.InmemSink.DisplayMetrics(resp, req)
return s.agent.GetMetricsSink().DisplayMetrics(resp, req)
}
func (s *HTTPServer) prometheusHandler() http.Handler {

View File

@ -16,7 +16,8 @@ func (s *HTTPServer) VariablesListRequest(resp http.ResponseWriter, req *http.Re
args := structs.VariablesListRequest{}
if s.parse(resp, req, &args.Region, &args.QueryOptions) {
return nil, nil
//TODO(schmichael) shouldn't we return something here?!
return nil, CodedError(http.StatusBadRequest, "failed to parse parameters")
}
var out structs.VariablesListResponse

View File

@ -0,0 +1,99 @@
# api-auth is a batch job constrained to Linux nodes. Every task runs curl
# against the Task API unix socket at ${NOMAD_SECRETS_DIR}/api.sock and requests
# /v1/agent/health, each with different credentials.
job "api-auth" {
datacenters = ["dc1"]
type = "batch"
constraint {
attribute = "${attr.kernel.name}"
value = "linux"
}
group "api-auth" {
# none task should get a 401 response
# (it sends no token at all)
task "none" {
driver = "docker"
config {
image = "curlimages/curl:7.87.0"
args = [
"--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock",
"-v",
"localhost/v1/agent/health",
]
}
resources {
cpu = 16
memory = 32
disk = 64
}
}
# bad task should get a 403 response
# (it sends a hard-coded X-Nomad-Token value)
task "bad" {
driver = "docker"
config {
image = "curlimages/curl:7.87.0"
args = [
"--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock",
"-H", "X-Nomad-Token: 37297754-3b87-41da-9ac7-d98fd934deed",
"-v",
"localhost/v1/agent/health",
]
}
resources {
cpu = 16
memory = 32
disk = 64
}
}
# docker-wid task should succeed due to using workload identity
# (${NOMAD_TOKEN} is sent as a bearer token)
task "docker-wid" {
driver = "docker"
config {
image = "curlimages/curl:7.87.0"
args = [
"--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock",
"-H", "Authorization: Bearer ${NOMAD_TOKEN}",
"-v",
"localhost/v1/agent/health",
]
}
# NOTE(review): presumably this is what populates ${NOMAD_TOKEN} with the
# task's workload identity - confirm.
identity {
env = true
}
resources {
cpu = 16
memory = 32
disk = 64
}
}
# exec-wid task should succeed due to using workload identity
# (same request as docker-wid, but run with the exec driver and whatever curl
# binary that driver resolves)
task "exec-wid" {
driver = "exec"
config {
command = "curl"
args = [
"-H", "Authorization: Bearer ${NOMAD_TOKEN}",
"--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock",
"-v",
"localhost/v1/agent/health",
]
}
identity {
env = true
}
resources {
cpu = 16
memory = 32
disk = 64
}
}
}
}

View File

@ -0,0 +1,36 @@
# api-win is a batch job constrained to Windows amd64 nodes. Its single task
# runs curl through powershell against the Task API unix socket and requests
# /v1/agent/health, sending $env:NOMAD_TOKEN as a bearer token.
job "api-win" {
datacenters = ["dc1"]
type = "batch"
constraint {
attribute = "${attr.kernel.name}"
value = "windows"
}
constraint {
attribute = "${attr.cpu.arch}"
value = "amd64"
}
group "api-win" {
task "win" {
driver = "raw_exec"
config {
command = "powershell"
# NOTE(review): the curl.exe path presumably points into the unpacked artifact
# below - confirm the artifact's extraction directory.
args = ["local/curl-7.87.0_4-win64-mingw/bin/curl.exe -H \"Authorization: Bearer $env:NOMAD_TOKEN\" --unix-socket $env:NOMAD_SECRETS_DIR/api.sock -v localhost:4646/v1/agent/health"]
}
# curl build for Windows used by the command above.
artifact {
source = "https://curl.se/windows/dl-7.87.0_4/curl-7.87.0_4-win64-mingw.zip"
}
# NOTE(review): presumably this is what populates NOMAD_TOKEN with the task's
# workload identity - confirm.
identity {
env = true
}
resources {
cpu = 16
memory = 32
disk = 64
}
}
}
}

View File

@ -0,0 +1,111 @@
package main
import (
"fmt"
"io"
"net/http"
"testing"
"github.com/hashicorp/nomad/e2e/e2eutil"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/shoenig/test"
"github.com/shoenig/test/must"
)
// TestTaskAPI runs subtests exercising the Task API related functionality.
// Bundled with Workload Identity as that's a prereq for the Task API to work.
func TestTaskAPI(t *testing.T) {
	client := e2eutil.NomadClient(t)

	// Wait for a leader and at least one ready node before running subtests.
	e2eutil.WaitForLeader(t, client)
	e2eutil.WaitForNodesReady(t, client, 1)

	subtests := []struct {
		name string
		run  func(*testing.T)
	}{
		{name: "testTaskAPI_Auth", run: testTaskAPIAuth},
		{name: "testTaskAPI_Windows", run: testTaskAPIWindows},
	}
	for _, st := range subtests {
		t.Run(st.name, st.run)
	}
}
// testTaskAPIAuth registers the api-auth job, waits for its single batch
// allocation to stop, and asserts that each task's stdout ends with the
// expected text: the 401 and 403 status texts for the unauthenticated tasks
// and an "ok":true payload for the workload identity tasks.
func testTaskAPIAuth(t *testing.T) {
	nomad := e2eutil.NomadClient(t)
	jobID := "api-auth-" + uuid.Short()
	jobIDs := []string{jobID}
	t.Cleanup(e2eutil.CleanupJobsAndGC(t, &jobIDs))

	// start job
	allocs := e2eutil.RegisterAndWaitForAllocs(t, nomad, "./input/api-auth.nomad.hcl", jobID, "")
	must.Len(t, 1, allocs)
	allocID := allocs[0].ID

	// wait for batch alloc to complete
	alloc := e2eutil.WaitForAllocStopped(t, nomad, allocID)
	// Expected value goes first so failure output labels the values correctly.
	must.Eq(t, "complete", alloc.ClientStatus)

	assertions := []struct {
		task   string
		suffix string
	}{
		{
			task:   "none",
			suffix: http.StatusText(http.StatusUnauthorized),
		},
		{
			task:   "bad",
			suffix: http.StatusText(http.StatusForbidden),
		},
		{
			task:   "docker-wid",
			suffix: `"ok":true}}`,
		},
		{
			task:   "exec-wid",
			suffix: `"ok":true}}`,
		},
	}

	// Ensure the assertions and input file match
	must.Len(t, len(assertions), alloc.Job.TaskGroups[0].Tasks,
		must.Sprintf("test and jobspec mismatch"))

	for _, tc := range assertions {
		logFile := fmt.Sprintf("alloc/logs/%s.stdout.0", tc.task)
		fd, err := nomad.AllocFS().Cat(alloc, logFile, nil)
		must.NoError(t, err)
		logBytes, err := io.ReadAll(fd)
		// Close explicitly (not deferred, since this is a loop) and before the
		// assertion so the reader is released even when the read failed. The
		// close error is irrelevant to what is being tested.
		_ = fd.Close()
		must.NoError(t, err)
		logs := string(logBytes)

		ps := must.Sprintf("Task: %s Logs: <<EOF\n%sEOF", tc.task, logs)
		must.StrHasSuffix(t, tc.suffix, logs, ps)
	}
}
// testTaskAPIWindows registers the api-win job and asserts that the task's
// stdout ends with an "ok":true payload. It is skipped when no Windows client
// nodes are found.
func testTaskAPIWindows(t *testing.T) {
	nomad := e2eutil.NomadClient(t)
	winNodes, err := e2eutil.ListWindowsClientNodes(nomad)
	must.NoError(t, err)
	if len(winNodes) == 0 {
		t.Skip("no Windows clients")
	}

	jobID := "api-win-" + uuid.Short()
	jobIDs := []string{jobID}
	t.Cleanup(e2eutil.CleanupJobsAndGC(t, &jobIDs))

	// start job
	allocs := e2eutil.RegisterAndWaitForAllocs(t, nomad, "./input/api-win.nomad.hcl", jobID, "")
	must.Len(t, 1, allocs)
	allocID := allocs[0].ID

	// wait for batch alloc to complete
	alloc := e2eutil.WaitForAllocStopped(t, nomad, allocID)
	// Expected value goes first so failure output labels the values correctly.
	test.Eq(t, "complete", alloc.ClientStatus)

	logFile := "alloc/logs/win.stdout.0"
	fd, err := nomad.AllocFS().Cat(alloc, logFile, nil)
	must.NoError(t, err)
	logBytes, err := io.ReadAll(fd)
	// Release the reader before asserting; the close error is irrelevant to
	// what is being tested.
	_ = fd.Close()
	must.NoError(t, err)
	logs := string(logBytes)

	must.StrHasSuffix(t, `"ok":true}}`, logs)
}

View File

@ -2,11 +2,13 @@ package users
import (
"fmt"
"net"
"os"
"os/user"
"strconv"
"sync"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
)
@ -36,6 +38,23 @@ func Current() (*user.User, error) {
return user.Current()
}
// UIDforUser resolves username via Lookup and returns the numeric UID parsed
// from the result. The lookup error is returned unchanged; a UID that is not
// an integer yields a wrapped parse error.
//
// Will always fail on Windows and Plan 9.
func UIDforUser(username string) (int, error) {
	account, lookupErr := Lookup(username)
	if lookupErr != nil {
		return 0, lookupErr
	}

	uid, parseErr := strconv.Atoi(account.Uid)
	if parseErr != nil {
		return 0, fmt.Errorf("error parsing uid: %w", parseErr)
	}
	return uid, nil
}
// WriteFileFor is like os.WriteFile except if possible it chowns the file to
// the specified user (possibly from Task.User) and sets the permissions to
// 0o600.
@ -45,6 +64,8 @@ func Current() (*user.User, error) {
//
// On failure a multierror with both the original and fallback errors will be
// returned.
//
// See SocketFileFor if writing a unix socket file.
func WriteFileFor(path string, contents []byte, username string) error {
// Don't even bother trying to chown to an empty username
var origErr error
@ -72,16 +93,11 @@ func WriteFileFor(path string, contents []byte, username string) error {
}
func writeFileFor(path string, contents []byte, username string) error {
user, err := Lookup(username)
uid, err := UIDforUser(username)
if err != nil {
return err
}
uid, err := strconv.Atoi(user.Uid)
if err != nil {
return fmt.Errorf("error parsing uid: %w", err)
}
if err := os.WriteFile(path, contents, 0o600); err != nil {
return err
}
@ -95,3 +111,58 @@ func writeFileFor(path string, contents []byte, username string) error {
return nil
}
// SocketFileFor creates a unix domain socket at path and returns its listener.
// When username is non-empty it tries to make the socket usable by only that
// user; if that is skipped or fails the socket is opened up to all users
// instead. Only a failure to listen is returned as an error; everything else
// is logged.
//
// See WriteFileFor if writing a regular file.
func SocketFileFor(logger hclog.Logger, path, username string) (net.Listener, error) {
	// Remove whatever currently exists at path before listening on it.
	if rmErr := os.RemoveAll(path); rmErr != nil {
		logger.Warn("error removing socket", "path", path, "error", rmErr)
	}

	listener, listenErr := net.Listen("unix", path)
	if listenErr != nil {
		return nil, listenErr
	}

	if username != "" {
		// Try to set perms on socket file to least privileges.
		ownerErr := setSocketOwner(path, username)
		if ownerErr == nil {
			return listener, nil
		}

		// This error is expected to always occur in some environments (Windows,
		// non-root agents), so don't log above Trace.
		logger.Trace("failed to set user on socket", "path", path, "user", username, "error", ownerErr)
	}

	// Least privileges were not applied above, so make sure anyone can use
	// the socket.
	if chmodErr := os.Chmod(path, 0o666); chmodErr != nil {
		logger.Warn("error setting socket permissions", "path", path, "error", chmodErr)
	}
	return listener, nil
}
// setSocketOwner hands the socket at path to username and then restricts its
// mode to 0o600. It returns the first error hit while resolving the UID,
// changing the owner, or changing the mode.
func setSocketOwner(path, username string) error {
	ownerUID, err := UIDforUser(username)
	if err != nil {
		return err
	}

	// Only the owning user is changed; -1 is passed for the group.
	if err = os.Chown(path, ownerUID, -1); err != nil {
		return err
	}

	// A failure here is the awkward (and hopefully impossible) situation
	// where the socket could be chowned but its mode could not be changed.
	return os.Chmod(path, 0o600)
}

View File

@ -11,6 +11,7 @@ import (
"syscall"
"testing"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/shoenig/test/must"
"golang.org/x/sys/unix"
)
@ -58,7 +59,7 @@ func TestWriteFileFor_Linux(t *testing.T) {
stat, err := os.Lstat(path)
must.NoError(t, err)
must.True(t, stat.Mode().IsRegular(),
must.Sprintf("expected %s to be a normal file but found %#o", path, stat.Mode()))
must.Sprintf("expected %s to be a regular file but found %#o", path, stat.Mode()))
linuxStat, ok := stat.Sys().(*syscall.Stat_t)
must.True(t, ok, must.Sprintf("expected stat.Sys() to be a *syscall.Stat_t but found %T", stat.Sys()))
@ -78,3 +79,39 @@ func TestWriteFileFor_Linux(t *testing.T) {
must.Eq(t, 0o666&(^umask), int(stat.Mode()))
}
}
// TestSocketFileFor_Linux asserts that when running as root on Linux socket
// files are created with least permissions. If running as non-root then we
// leave the socket file as world writable.
func TestSocketFileFor_Linux(t *testing.T) {
	path := filepath.Join(t.TempDir(), "api.sock")

	logger := testlog.HCLogger(t)
	ln, err := SocketFileFor(logger, path, "nobody")
	must.NoError(t, err)
	must.NotNil(t, ln)
	defer ln.Close()

	// Lstat so the socket file itself is inspected.
	stat, err := os.Lstat(path)
	must.NoError(t, err)
	// A socket must not show up as a regular file; the failure message says
	// so instead of repeating the regular-file wording used for WriteFileFor.
	must.False(t, stat.Mode().IsRegular(),
		must.Sprintf("expected %s to not be a regular file but found %#o", path, stat.Mode()))

	linuxStat, ok := stat.Sys().(*syscall.Stat_t)
	must.True(t, ok, must.Sprintf("expected stat.Sys() to be a *syscall.Stat_t but found %T", stat.Sys()))

	current, err := Current()
	must.NoError(t, err)

	if current.Username == "root" {
		t.Logf("Running as root: asserting %s is owned by nobody", path)
		nobody, err := Lookup("nobody")
		must.NoError(t, err)
		must.Eq(t, nobody.Uid, fmt.Sprintf("%d", linuxStat.Uid))
		must.Eq(t, 0o600, int(stat.Mode().Perm()))
	} else {
		// Non-root cannot chown, so the socket stays owned by the current
		// user and falls back to world-accessible permissions.
		t.Logf("Running as non-root: asserting %s is world writable", path)
		must.Eq(t, current.Uid, fmt.Sprintf("%d", linuxStat.Uid))
		must.Eq(t, 0o666, int(stat.Mode().Perm()))
	}
}

View File

@ -44,7 +44,23 @@ func TestWriteFileFor_Windows(t *testing.T) {
stat, err := os.Lstat(path)
must.NoError(t, err)
must.True(t, stat.Mode().IsRegular(),
must.Sprintf("expected %s to be a normal file but found %#o", path, stat.Mode()))
must.Sprintf("expected %s to be a regular file but found %#o", path, stat.Mode()))
// Assert Windows hits the fallback world-accessible case
must.Eq(t, 0o666, stat.Mode().Perm())
}
// TestSocketFileFor_Windows asserts that socket files cannot be chowned on
// windows.
func TestSocketFileFor_Windows(t *testing.T) {
path := filepath.Join(t.TempDir(), "api.sock")
ln, err := SocketFileFor(testlog.HCLogger(t), path, "Administrator")
must.NoError(t, err)
must.NotNil(t, ln)
defer ln.Close()
stat, err := os.Lstat(path)
must.NoError(t, err)
// Assert Windows hits the fallback world-accessible case
must.Eq(t, 0o666, stat.Mode().Perm())

View File

@ -2102,7 +2102,9 @@ func (a *ACL) GetAuthMethods(
}
// WhoAmI is a RPC for debugging authentication. This endpoint returns the same
// AuthenticatedIdentity that will be used by RPC handlers.
// AuthenticatedIdentity that will be used by RPC handlers, but unlike other
// endpoints will try to authenticate workload identities even if ACLs are
// disabled.
//
// TODO: At some point we might want to give this an equivalent HTTP endpoint
// once other Workload Identity work is solidified
@ -2118,6 +2120,15 @@ func (a *ACL) WhoAmI(args *structs.GenericRequest, reply *structs.ACLWhoAmIRespo
defer metrics.MeasureSince([]string{"nomad", "acl", "whoami"}, time.Now())
if !a.srv.config.ACLEnabled {
// Authenticate never verifies claims when ACLs are disabled, but since
// this endpoint is explicitly for resolving identities, always try to
// verify any claims.
if claims, _ := a.srv.VerifyClaim(args.AuthToken); claims != nil {
args.SetIdentity(&structs.AuthenticatedIdentity{Claims: claims})
}
}
reply.Identity = args.GetIdentity()
return nil
}

View File

@ -501,12 +501,6 @@ func (ai *AuthenticatedIdentity) GetClaims() *IdentityClaims {
return ai.Claims
}
// RequestWithIdentity is the set of methods for reading a request's auth
// token and for storing and retrieving the AuthenticatedIdentity associated
// with that request.
type RequestWithIdentity interface {
	// GetAuthToken returns the request's auth token string.
	GetAuthToken() string
	// SetIdentity records the given identity on the request.
	SetIdentity(identity *AuthenticatedIdentity)
	// GetIdentity returns the identity recorded on the request.
	// NOTE(review): presumably nil until SetIdentity is called — confirm
	// against the implementations.
	GetIdentity() *AuthenticatedIdentity
}
func (ai *AuthenticatedIdentity) String() string {
if ai == nil {
return "unauthenticated"
@ -523,6 +517,12 @@ func (ai *AuthenticatedIdentity) String() string {
return fmt.Sprintf("%s:%s", ai.TLSName, ai.RemoteIP.String())
}
// RequestWithIdentity is the set of methods for reading a request's auth
// token and for storing and retrieving the AuthenticatedIdentity associated
// with that request.
type RequestWithIdentity interface {
	// GetAuthToken returns the request's auth token string.
	GetAuthToken() string
	// SetIdentity records the given identity on the request.
	SetIdentity(identity *AuthenticatedIdentity)
	// GetIdentity returns the identity recorded on the request.
	// NOTE(review): presumably nil until SetIdentity is called — confirm
	// against the implementations.
	GetIdentity() *AuthenticatedIdentity
}
// QueryMeta allows a query response to include potentially
// useful metadata about a query
type QueryMeta struct {