open-vault/command/agent/auth/auth.go

219 lines
5.2 KiB
Go

package auth
import (
"context"
"math/rand"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/jsonutil"
)
type AuthMethod interface {
Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error)
NewCreds() chan struct{}
CredSuccess()
Shutdown()
}
type AuthConfig struct {
Logger hclog.Logger
MountPath string
WrapTTL time.Duration
Config map[string]interface{}
}
// AuthHandler is responsible for keeping a token alive and renewed and passing
// new tokens to the sink server
type AuthHandler struct {
DoneCh chan struct{}
OutputCh chan string
logger hclog.Logger
client *api.Client
random *rand.Rand
wrapTTL time.Duration
}
type AuthHandlerConfig struct {
Logger hclog.Logger
Client *api.Client
WrapTTL time.Duration
}
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
ah := &AuthHandler{
DoneCh: make(chan struct{}),
// This is buffered so that if we try to output after the sink server
// has been shut down, during agent shutdown, we won't block
OutputCh: make(chan string, 1),
logger: conf.Logger,
client: conf.Client,
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
wrapTTL: conf.WrapTTL,
}
return ah
}
func backoffOrQuit(ctx context.Context, backoff time.Duration) {
select {
case <-time.After(backoff):
case <-ctx.Done():
}
}
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
if am == nil {
panic("nil auth method")
}
ah.logger.Info("starting auth handler")
defer func() {
am.Shutdown()
close(ah.OutputCh)
close(ah.DoneCh)
ah.logger.Info("auth handler stopped")
}()
credCh := am.NewCreds()
if credCh == nil {
credCh = make(chan struct{})
}
var renewer *api.Renewer
for {
select {
case <-ctx.Done():
return
default:
}
// Create a fresh backoff value
backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second))
ah.logger.Info("authenticating")
path, data, err := am.Authenticate(ctx, ah.client)
if err != nil {
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
clientToUse := ah.client
if ah.wrapTTL > 0 {
wrapClient, err := ah.client.Clone()
if err != nil {
ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
wrapClient.SetWrappingLookupFunc(func(string, string) string {
return ah.wrapTTL.String()
})
clientToUse = wrapClient
}
secret, err := clientToUse.Logical().Write(path, data)
// Check errors/sanity
if err != nil {
ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
switch {
case ah.wrapTTL > 0:
if secret.WrapInfo == nil {
ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
if secret.WrapInfo.Token == "" {
ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
if err != nil {
ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
ah.OutputCh <- string(wrappedResp)
am.CredSuccess()
select {
case <-ctx.Done():
ah.logger.Info("shutdown triggered")
continue
case <-credCh:
ah.logger.Info("auth method found new credentials, re-authenticating")
continue
}
default:
if secret.Auth == nil {
ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
if secret.Auth.ClientToken == "" {
ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
ah.logger.Info("authentication successful, sending token to sinks")
ah.OutputCh <- secret.Auth.ClientToken
am.CredSuccess()
}
if renewer != nil {
renewer.Stop()
}
renewer, err = ah.client.NewRenewer(&api.RenewerInput{
Secret: secret,
})
if err != nil {
ah.logger.Error("error creating renewer, backing off and retrying", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
// Start the renewal process
ah.logger.Info("starting renewal process")
go renewer.Renew()
RenewerLoop:
for {
select {
case <-ctx.Done():
ah.logger.Info("shutdown triggered, stopping renewer")
renewer.Stop()
break RenewerLoop
case err := <-renewer.DoneCh():
ah.logger.Info("renewer done channel triggered")
if err != nil {
ah.logger.Error("error renewing token", "error", err)
}
break RenewerLoop
case <-renewer.RenewCh():
ah.logger.Info("renewed auth token")
case <-credCh:
ah.logger.Info("auth method found new credentials, re-authenticating")
break RenewerLoop
}
}
}
}