4ef8d66318
The existing code would retain the previous backoff value even after the system had recovered. This PR fixes that issue and improves the structure of the backoff code.
366 lines
9.4 KiB
Go
366 lines
9.4 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"math/rand"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
|
)
|
|
|
|
// Bounds used by the exponential backoff applied between failed
// authentication attempts.
const (
	// initialBackoff is the delay a new or freshly reset backoff starts from.
	initialBackoff = time.Second

	// defaultMaxBackoff is the backoff ceiling substituted when the
	// configured maximum is not a positive duration.
	defaultMaxBackoff = 5 * time.Minute
)
|
|
|
|
// AuthMethod is the interface that auto-auth methods implement for the agent
|
|
// to use.
|
|
type AuthMethod interface {
|
|
// Authenticate returns a mount path, header, request body, and error.
|
|
// The header may be nil if no special header is needed.
|
|
Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error)
|
|
NewCreds() chan struct{}
|
|
CredSuccess()
|
|
Shutdown()
|
|
}
|
|
|
|
// AuthMethodWithClient is an extended interface that can return an API client
|
|
// for use during the authentication call.
|
|
type AuthMethodWithClient interface {
|
|
AuthMethod
|
|
AuthClient(client *api.Client) (*api.Client, error)
|
|
}
|
|
|
|
type AuthConfig struct {
|
|
Logger hclog.Logger
|
|
MountPath string
|
|
WrapTTL time.Duration
|
|
Config map[string]interface{}
|
|
}
|
|
|
|
// AuthHandler is responsible for keeping a token alive and renewed and passing
|
|
// new tokens to the sink server
|
|
type AuthHandler struct {
|
|
OutputCh chan string
|
|
TemplateTokenCh chan string
|
|
token string
|
|
logger hclog.Logger
|
|
client *api.Client
|
|
random *rand.Rand
|
|
wrapTTL time.Duration
|
|
maxBackoff time.Duration
|
|
enableReauthOnNewCredentials bool
|
|
enableTemplateTokenCh bool
|
|
}
|
|
|
|
type AuthHandlerConfig struct {
|
|
Logger hclog.Logger
|
|
Client *api.Client
|
|
WrapTTL time.Duration
|
|
MaxBackoff time.Duration
|
|
Token string
|
|
EnableReauthOnNewCredentials bool
|
|
EnableTemplateTokenCh bool
|
|
}
|
|
|
|
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
|
|
ah := &AuthHandler{
|
|
// This is buffered so that if we try to output after the sink server
|
|
// has been shut down, during agent shutdown, we won't block
|
|
OutputCh: make(chan string, 1),
|
|
TemplateTokenCh: make(chan string, 1),
|
|
token: conf.Token,
|
|
logger: conf.Logger,
|
|
client: conf.Client,
|
|
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
|
|
wrapTTL: conf.WrapTTL,
|
|
maxBackoff: conf.MaxBackoff,
|
|
enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials,
|
|
enableTemplateTokenCh: conf.EnableTemplateTokenCh,
|
|
}
|
|
|
|
return ah
|
|
}
|
|
|
|
func backoffOrQuit(ctx context.Context, backoff *agentBackoff) {
|
|
select {
|
|
case <-time.After(backoff.current):
|
|
case <-ctx.Done():
|
|
}
|
|
|
|
// Increase exponential backoff for the next time if we don't
|
|
// successfully auth/renew/etc.
|
|
backoff.next()
|
|
}
|
|
|
|
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|
if am == nil {
|
|
return errors.New("auth handler: nil auth method")
|
|
}
|
|
|
|
backoff := newAgentBackoff(ah.maxBackoff)
|
|
|
|
ah.logger.Info("starting auth handler")
|
|
defer func() {
|
|
am.Shutdown()
|
|
close(ah.OutputCh)
|
|
close(ah.TemplateTokenCh)
|
|
ah.logger.Info("auth handler stopped")
|
|
}()
|
|
|
|
credCh := am.NewCreds()
|
|
if !ah.enableReauthOnNewCredentials {
|
|
realCredCh := credCh
|
|
credCh = nil
|
|
if realCredCh != nil {
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-realCredCh:
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
if credCh == nil {
|
|
credCh = make(chan struct{})
|
|
}
|
|
|
|
var watcher *api.LifetimeWatcher
|
|
first := true
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
|
|
default:
|
|
}
|
|
|
|
var clientToUse *api.Client
|
|
var err error
|
|
var path string
|
|
var data map[string]interface{}
|
|
var header http.Header
|
|
|
|
switch am.(type) {
|
|
case AuthMethodWithClient:
|
|
clientToUse, err = am.(AuthMethodWithClient).AuthClient(ah.client)
|
|
if err != nil {
|
|
ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
default:
|
|
clientToUse = ah.client
|
|
}
|
|
|
|
var secret *api.Secret = new(api.Secret)
|
|
if first && ah.token != "" {
|
|
ah.logger.Debug("using preloaded token")
|
|
|
|
first = false
|
|
ah.logger.Debug("lookup-self with preloaded token")
|
|
clientToUse.SetToken(ah.token)
|
|
|
|
secret, err = clientToUse.Logical().Read("auth/token/lookup-self")
|
|
if err != nil {
|
|
ah.logger.Error("could not look up token", "err", err, "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
|
|
duration, _ := secret.Data["ttl"].(json.Number).Int64()
|
|
secret.Auth = &api.SecretAuth{
|
|
ClientToken: secret.Data["id"].(string),
|
|
LeaseDuration: int(duration),
|
|
Renewable: secret.Data["renewable"].(bool),
|
|
}
|
|
} else {
|
|
ah.logger.Info("authenticating")
|
|
|
|
path, header, data, err = am.Authenticate(ctx, ah.client)
|
|
if err != nil {
|
|
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
}
|
|
|
|
if ah.wrapTTL > 0 {
|
|
wrapClient, err := clientToUse.Clone()
|
|
if err != nil {
|
|
ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
wrapClient.SetWrappingLookupFunc(func(string, string) string {
|
|
return ah.wrapTTL.String()
|
|
})
|
|
clientToUse = wrapClient
|
|
}
|
|
for key, values := range header {
|
|
for _, value := range values {
|
|
clientToUse.AddHeader(key, value)
|
|
}
|
|
}
|
|
|
|
// This should only happen if there's no preloaded token (regular auto-auth login)
|
|
// or if a preloaded token has expired and is now switching to auto-auth.
|
|
if secret.Auth == nil {
|
|
secret, err = clientToUse.Logical().Write(path, data)
|
|
// Check errors/sanity
|
|
if err != nil {
|
|
ah.logger.Error("error authenticating", "error", err, "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
}
|
|
|
|
switch {
|
|
case ah.wrapTTL > 0:
|
|
if secret.WrapInfo == nil {
|
|
ah.logger.Error("authentication returned nil wrap info", "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
if secret.WrapInfo.Token == "" {
|
|
ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
|
|
if err != nil {
|
|
ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
|
|
ah.OutputCh <- string(wrappedResp)
|
|
if ah.enableTemplateTokenCh {
|
|
ah.TemplateTokenCh <- string(wrappedResp)
|
|
}
|
|
|
|
am.CredSuccess()
|
|
backoff.reset()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
ah.logger.Info("shutdown triggered")
|
|
continue
|
|
|
|
case <-credCh:
|
|
ah.logger.Info("auth method found new credentials, re-authenticating")
|
|
continue
|
|
}
|
|
|
|
default:
|
|
if secret == nil || secret.Auth == nil {
|
|
ah.logger.Error("authentication returned nil auth info", "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
if secret.Auth.ClientToken == "" {
|
|
ah.logger.Error("authentication returned empty client token", "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
ah.logger.Info("authentication successful, sending token to sinks")
|
|
ah.OutputCh <- secret.Auth.ClientToken
|
|
if ah.enableTemplateTokenCh {
|
|
ah.TemplateTokenCh <- secret.Auth.ClientToken
|
|
}
|
|
|
|
am.CredSuccess()
|
|
backoff.reset()
|
|
}
|
|
|
|
if watcher != nil {
|
|
watcher.Stop()
|
|
}
|
|
|
|
watcher, err = clientToUse.NewLifetimeWatcher(&api.LifetimeWatcherInput{
|
|
Secret: secret,
|
|
})
|
|
if err != nil {
|
|
ah.logger.Error("error creating lifetime watcher, backing off and retrying", "error", err, "backoff", backoff)
|
|
backoffOrQuit(ctx, backoff)
|
|
continue
|
|
}
|
|
|
|
// Start the renewal process
|
|
ah.logger.Info("starting renewal process")
|
|
go watcher.Renew()
|
|
|
|
LifetimeWatcherLoop:
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
ah.logger.Info("shutdown triggered, stopping lifetime watcher")
|
|
watcher.Stop()
|
|
break LifetimeWatcherLoop
|
|
|
|
case err := <-watcher.DoneCh():
|
|
ah.logger.Info("lifetime watcher done channel triggered")
|
|
if err != nil {
|
|
ah.logger.Error("error renewing token", "error", err)
|
|
}
|
|
break LifetimeWatcherLoop
|
|
|
|
case <-watcher.RenewCh():
|
|
ah.logger.Info("renewed auth token")
|
|
|
|
case <-credCh:
|
|
ah.logger.Info("auth method found new credentials, re-authenticating")
|
|
break LifetimeWatcherLoop
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// agentBackoff tracks exponential backoff state.
|
|
type agentBackoff struct {
|
|
max time.Duration
|
|
current time.Duration
|
|
}
|
|
|
|
func newAgentBackoff(max time.Duration) *agentBackoff {
|
|
if max <= 0 {
|
|
max = defaultMaxBackoff
|
|
}
|
|
|
|
return &agentBackoff{
|
|
max: max,
|
|
current: initialBackoff,
|
|
}
|
|
}
|
|
|
|
// next determines the next backoff duration that is roughly twice
|
|
// the current value, capped to a max value, with a measure of randomness.
|
|
func (b *agentBackoff) next() {
|
|
maxBackoff := 2 * b.current
|
|
|
|
if maxBackoff > b.max {
|
|
maxBackoff = b.max
|
|
}
|
|
|
|
// Trim a random amount (0-25%) off the doubled duration
|
|
trim := rand.Int63n(int64(maxBackoff) / 4)
|
|
b.current = maxBackoff - time.Duration(trim)
|
|
}
|
|
|
|
func (b *agentBackoff) reset() {
|
|
b.current = initialBackoff
|
|
}
|
|
|
|
func (b agentBackoff) String() string {
|
|
return b.current.Truncate(10 * time.Millisecond).String()
|
|
}
|