281 lines
8.0 KiB
Go
281 lines
8.0 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package taskrunner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
|
|
ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
|
|
"github.com/hashicorp/nomad/client/consul"
|
|
"github.com/hashicorp/nomad/nomad/structs"
|
|
)
|
|
|
|
const (
|
|
// the name of this hook, used in logs
|
|
sidsHookName = "consul_si_token"
|
|
|
|
// sidsBackoffBaseline is the baseline time for exponential backoff when
|
|
// attempting to retrieve a Consul SI token
|
|
sidsBackoffBaseline = 5 * time.Second
|
|
|
|
// sidsBackoffLimit is the limit of the exponential backoff when attempting
|
|
// to retrieve a Consul SI token
|
|
sidsBackoffLimit = 3 * time.Minute
|
|
|
|
// sidsDerivationTimeout limits the amount of time we may spend trying to
|
|
// derive a SI token. If the hook does not get a token within this amount of
|
|
// time, the result is a failure.
|
|
sidsDerivationTimeout = 5 * time.Minute
|
|
|
|
// sidsTokenFile is the name of the file holding the Consul SI token inside
|
|
// the task's secret directory
|
|
sidsTokenFile = "si_token"
|
|
|
|
// sidsTokenFilePerms is the level of file permissions granted on the file
|
|
// in the secrets directory for the task
|
|
sidsTokenFilePerms = 0440
|
|
)
|
|
|
|
type sidsHookConfig struct {
|
|
alloc *structs.Allocation
|
|
task *structs.Task
|
|
sidsClient consul.ServiceIdentityAPI
|
|
lifecycle ti.TaskLifecycle
|
|
logger hclog.Logger
|
|
}
|
|
|
|
// Service Identities hook for managing SI tokens of connect enabled tasks.
|
|
type sidsHook struct {
|
|
// alloc is the allocation
|
|
alloc *structs.Allocation
|
|
|
|
// taskName is the name of the task
|
|
task *structs.Task
|
|
|
|
// sidsClient is the Consul client [proxy] for requesting SI tokens
|
|
sidsClient consul.ServiceIdentityAPI
|
|
|
|
// lifecycle is used to signal, restart, and kill a task
|
|
lifecycle ti.TaskLifecycle
|
|
|
|
// derivationTimeout is the amount of time we may wait for Consul to successfully
|
|
// provide a SI token. Making this configurable for testing, otherwise
|
|
// default to sidsDerivationTimeout
|
|
derivationTimeout time.Duration
|
|
|
|
// logger is used to log
|
|
logger hclog.Logger
|
|
|
|
// lock variables that can be manipulated after hook creation
|
|
lock sync.Mutex
|
|
// firstRun keeps track of whether the hook is being called for the first
|
|
// time (for this task) during the lifespan of the Nomad Client process.
|
|
firstRun bool
|
|
}
|
|
|
|
func newSIDSHook(c sidsHookConfig) *sidsHook {
|
|
return &sidsHook{
|
|
alloc: c.alloc,
|
|
task: c.task,
|
|
sidsClient: c.sidsClient,
|
|
lifecycle: c.lifecycle,
|
|
derivationTimeout: sidsDerivationTimeout,
|
|
logger: c.logger.Named(sidsHookName),
|
|
firstRun: true,
|
|
}
|
|
}
|
|
|
|
func (h *sidsHook) Name() string {
|
|
return sidsHookName
|
|
}
|
|
|
|
func (h *sidsHook) Prestart(
|
|
ctx context.Context,
|
|
req *interfaces.TaskPrestartRequest,
|
|
resp *interfaces.TaskPrestartResponse) error {
|
|
|
|
h.lock.Lock()
|
|
defer h.lock.Unlock()
|
|
|
|
// do nothing if we have already done things
|
|
if h.earlyExit() {
|
|
resp.Done = true
|
|
return nil
|
|
}
|
|
|
|
// optimistically try to recover token from disk
|
|
token, err := h.recoverToken(req.TaskDir.SecretsDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// need to ask for a new SI token & persist it to disk
|
|
if token == "" {
|
|
if token, err = h.deriveSIToken(ctx); err != nil {
|
|
return err
|
|
}
|
|
if err := h.writeToken(req.TaskDir.SecretsDir, token); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
h.logger.Info("derived SI token", "task", h.task.Name, "si_task", h.task.Kind.Value())
|
|
|
|
resp.Done = true
|
|
return nil
|
|
}
|
|
|
|
// earlyExit returns true if the Prestart hook has already been executed during
|
|
// the instantiation of this task runner.
|
|
//
|
|
// assumes h is locked
|
|
func (h *sidsHook) earlyExit() bool {
|
|
if h.firstRun {
|
|
h.firstRun = false
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// writeToken writes token into the secrets directory for the task.
|
|
func (h *sidsHook) writeToken(dir string, token string) error {
|
|
tokenPath := filepath.Join(dir, sidsTokenFile)
|
|
if err := os.WriteFile(tokenPath, []byte(token), sidsTokenFilePerms); err != nil {
|
|
return fmt.Errorf("failed to write SI token: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// recoverToken returns the token saved to disk in the secrets directory for the
|
|
// task if it exists, or the empty string if the file does not exist. an error
|
|
// is returned only for some other (e.g. disk IO) error.
|
|
func (h *sidsHook) recoverToken(dir string) (string, error) {
|
|
tokenPath := filepath.Join(dir, sidsTokenFile)
|
|
token, err := os.ReadFile(tokenPath)
|
|
if err != nil {
|
|
if !os.IsNotExist(err) {
|
|
h.logger.Error("failed to recover SI token", "error", err)
|
|
return "", fmt.Errorf("failed to recover SI token: %w", err)
|
|
}
|
|
h.logger.Trace("no pre-existing SI token to recover", "task", h.task.Name)
|
|
return "", nil // token file does not exist yet
|
|
}
|
|
h.logger.Trace("recovered pre-existing SI token", "task", h.task.Name)
|
|
return string(token), nil
|
|
}
|
|
|
|
// siDerivationResult is used to pass along the result of attempting to derive
|
|
// an SI token between the goroutine doing the derivation and its caller
|
|
type siDerivationResult struct {
|
|
token string
|
|
err error
|
|
}
|
|
|
|
// deriveSIToken spawns and waits on a goroutine which will make attempts to
|
|
// derive an SI token until a token is successfully created, or ctx is signaled
|
|
// done.
|
|
func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, h.derivationTimeout)
|
|
defer cancel()
|
|
|
|
resultCh := make(chan siDerivationResult)
|
|
|
|
// keep trying to get the token in the background
|
|
go h.tryDerive(ctx, resultCh)
|
|
|
|
// wait until we get a token, or we get a signal to quit
|
|
for {
|
|
select {
|
|
case result := <-resultCh:
|
|
if result.err != nil {
|
|
h.logger.Error("failed to derive SI token", "error", result.err)
|
|
h.kill(ctx, fmt.Errorf("failed to derive SI token: %w", result.err))
|
|
return "", result.err
|
|
}
|
|
return result.token, nil
|
|
case <-ctx.Done():
|
|
return "", ctx.Err()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *sidsHook) kill(ctx context.Context, reason error) {
|
|
if err := h.lifecycle.Kill(ctx,
|
|
structs.NewTaskEvent(structs.TaskKilling).
|
|
SetFailsTask().
|
|
SetDisplayMessage(reason.Error()),
|
|
); err != nil {
|
|
h.logger.Error("failed to kill task", "kill_reason", reason, "error", err)
|
|
}
|
|
}
|
|
|
|
// tryDerive loops forever until a token is created, or ctx is done.
|
|
func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- siDerivationResult) {
|
|
for attempt := 0; backoff(ctx, attempt); attempt++ {
|
|
|
|
tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.task.Name})
|
|
|
|
switch {
|
|
case err == nil:
|
|
token, exists := tokens[h.task.Name]
|
|
if !exists {
|
|
err := errors.New("response does not include token for task")
|
|
h.logger.Error("derive SI token is missing token for task", "error", err, "task", h.task.Name)
|
|
ch <- siDerivationResult{token: "", err: err}
|
|
return
|
|
}
|
|
ch <- siDerivationResult{token: token, err: nil}
|
|
return
|
|
case structs.IsServerSide(err):
|
|
// the error is known to be a server problem, just die
|
|
h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "server_side", true)
|
|
ch <- siDerivationResult{token: "", err: err}
|
|
return
|
|
case !structs.IsRecoverable(err):
|
|
// the error is known not to be recoverable, just die
|
|
h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "recoverable", false)
|
|
ch <- siDerivationResult{token: "", err: err}
|
|
return
|
|
|
|
default:
|
|
// the error is marked recoverable, retry after some backoff
|
|
h.logger.Error("failed attempt to derive SI token", "error", err, "recoverable", true)
|
|
}
|
|
}
|
|
}
|
|
|
|
func backoff(ctx context.Context, attempt int) bool {
|
|
next := computeBackoff(attempt)
|
|
select {
|
|
case <-ctx.Done():
|
|
return false
|
|
case <-time.After(next):
|
|
return true
|
|
}
|
|
}
|
|
|
|
func computeBackoff(attempt int) time.Duration {
|
|
switch attempt {
|
|
case 0:
|
|
return 0
|
|
case 1:
|
|
// go fast on first retry, because a unit test should be fast
|
|
return 100 * time.Millisecond
|
|
default:
|
|
wait := time.Duration(attempt) * sidsBackoffBaseline
|
|
if wait > sidsBackoffLimit {
|
|
wait = sidsBackoffLimit
|
|
}
|
|
return wait
|
|
}
|
|
}
|