diff --git a/changelog/10850.txt b/changelog/10850.txt new file mode 100644 index 000000000..211b9a10d --- /dev/null +++ b/changelog/10850.txt @@ -0,0 +1,3 @@ +```release-note:improvement +agent: change auto-auth to preload an existing token on start +``` diff --git a/command/agent/auth/auth.go b/command/agent/auth/auth.go index 9425dcd96..e8aa9cfdf 100644 --- a/command/agent/auth/auth.go +++ b/command/agent/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "context" + "encoding/json" "errors" "math/rand" "net/http" @@ -42,6 +43,7 @@ type AuthConfig struct { type AuthHandler struct { OutputCh chan string TemplateTokenCh chan string + token string logger hclog.Logger client *api.Client random *rand.Rand @@ -54,6 +56,7 @@ type AuthHandlerConfig struct { Logger hclog.Logger Client *api.Client WrapTTL time.Duration + Token string EnableReauthOnNewCredentials bool EnableTemplateTokenCh bool } @@ -64,6 +67,7 @@ func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler { // has been shut down, during agent shutdown, we won't block OutputCh: make(chan string, 1), TemplateTokenCh: make(chan string, 1), + token: conf.Token, logger: conf.Logger, client: conf.Client, random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), @@ -116,6 +120,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { } var watcher *api.LifetimeWatcher + first := true for { select { @@ -128,16 +133,11 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { // Create a fresh backoff value backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second)) - ah.logger.Info("authenticating") - - path, header, data, err := am.Authenticate(ctx, ah.client) - if err != nil { - ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) - continue - } - var clientToUse *api.Client + var err error + var path string + var data map[string]interface{} + var header http.Header switch am.(type) { case AuthMethodWithClient: @@ -151,6 +151,38 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { clientToUse = ah.client } + var secret *api.Secret = new(api.Secret) + if first && ah.token != "" { + ah.logger.Debug("using preloaded token") + + first = false + ah.logger.Debug("lookup-self with preloaded token") + clientToUse.SetToken(ah.token) + + secret, err = clientToUse.Logical().Read("auth/token/lookup-self") + if err != nil { + ah.logger.Error("could not look up token", "err", err, "backoff", backoff.Seconds()) + backoffOrQuit(ctx, backoff) + continue + } + + duration, _ := secret.Data["ttl"].(json.Number).Int64() + secret.Auth = &api.SecretAuth{ + ClientToken: secret.Data["id"].(string), + LeaseDuration: int(duration), + Renewable: secret.Data["renewable"].(bool), + } + } else { + ah.logger.Info("authenticating") + + path, header, data, err = am.Authenticate(ctx, ah.client) + if err != nil { + ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds()) + backoffOrQuit(ctx, backoff) + continue + } + } + if ah.wrapTTL > 0 { wrapClient, err := clientToUse.Clone() if err != nil { @@ -169,12 +201,16 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { } } - secret, err := clientToUse.Logical().Write(path, data) - // Check errors/sanity - if err != nil { - ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) - continue + // This should only happen if there's no preloaded token (regular auto-auth login) + // or if a preloaded token has expired and is now switching to auto-auth. + if secret.Auth == nil { + secret, err = clientToUse.Logical().Write(path, data) + // Check errors/sanity + if err != nil { + ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds()) + backoffOrQuit(ctx, backoff) + continue + } } switch { diff --git a/command/agent/auto_auth_preload_token_end_to_end_test.go b/command/agent/auto_auth_preload_token_end_to_end_test.go new file mode 100644 index 000000000..9e059049c --- /dev/null +++ b/command/agent/auto_auth_preload_token_end_to_end_test.go @@ -0,0 +1,238 @@ +package agent + +import ( + "context" + "io/ioutil" + "os" + "testing" + "time" + + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + credAppRole "github.com/hashicorp/vault/builtin/credential/approle" + "github.com/hashicorp/vault/command/agent/auth" + agentAppRole "github.com/hashicorp/vault/command/agent/auth/approle" + "github.com/hashicorp/vault/command/agent/sink" + "github.com/hashicorp/vault/command/agent/sink/file" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault" +) + +func TestTokenPreload_UsingAutoAuth(t *testing.T) { + logger := logging.NewVaultLogger(hclog.Trace) + coreConfig := &vault.CoreConfig{ + Logger: logger, + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + CredentialBackends: map[string]logical.Factory{ + "approle": credAppRole.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + vault.TestWaitActive(t, cluster.Cores[0].Core) + client := cluster.Cores[0].Client + + // Setup Vault + if err := client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{ + Type: "approle", + }); err != nil { + t.Fatal(err) + } + + // Setup Approle + _, err := client.Logical().Write("auth/approle/role/test1", map[string]interface{}{ + "bind_secret_id": "true", + "token_ttl": "3s", + "token_max_ttl": "10s", + "policies": []string{"test-autoauth"}, + }) + if err != nil { + t.Fatal(err) + } + + resp, err := client.Logical().Write("auth/approle/role/test1/secret-id", nil) + if err != nil { + t.Fatal(err) + } + secretID1 := resp.Data["secret_id"].(string) + + resp, err = client.Logical().Read("auth/approle/role/test1/role-id") + if err != nil { + t.Fatal(err) + } + roleID1 := resp.Data["role_id"].(string) + + rolef, err := ioutil.TempFile("", "auth.role-id.test.") + if err != nil { + t.Fatal(err) + } + role := rolef.Name() + rolef.Close() // WriteFile doesn't need it open + defer os.Remove(role) + t.Logf("input role_id_file_path: %s", role) + + secretf, err := ioutil.TempFile("", "auth.secret-id.test.") + if err != nil { + t.Fatal(err) + } + secret := secretf.Name() + secretf.Close() + defer os.Remove(secret) + t.Logf("input secret_id_file_path: %s", secret) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + + conf := map[string]interface{}{ + "role_id_file_path": role, + "secret_id_file_path": secret, + } + + if err := ioutil.WriteFile(role, []byte(roleID1), 0600); err != nil { + t.Fatal(err) + } else { + logger.Trace("wrote test role 1", "path", role) + } + + if err := ioutil.WriteFile(secret, []byte(secretID1), 0600); err != nil { + t.Fatal(err) + } else { + logger.Trace("wrote test secret 1", "path", secret) + } + + // Setup Preload Token + tokenRespRaw, err := client.Logical().Write("auth/token/create", map[string]interface{}{ + "ttl": "10s", + "explicit-max-ttl": "15s", + "policies": []string{""}, + }) + if err != nil { + t.Fatal(err) + } + + if tokenRespRaw.Auth == nil || tokenRespRaw.Auth.ClientToken == "" { + t.Fatal("expected token but got none") + } + token := tokenRespRaw.Auth.ClientToken + + am, err := agentAppRole.NewApproleAuthMethod(&auth.AuthConfig{ + Logger: logger.Named("auth.approle"), + MountPath: "auth/approle", + Config: conf, + }) + if err != nil { + t.Fatal(err) + } + + ahConfig := &auth.AuthHandlerConfig{ + Logger: logger.Named("auth.handler"), + Client: client, + Token: token, + } + + ah := auth.NewAuthHandler(ahConfig) + + tmpFile, err := ioutil.TempFile("", "auth.tokensink.test.") + if err != nil { + t.Fatal(err) + } + tokenSinkFileName := tmpFile.Name() + tmpFile.Close() + os.Remove(tokenSinkFileName) + t.Logf("output: %s", tokenSinkFileName) + + config := &sink.SinkConfig{ + Logger: logger.Named("sink.file"), + Config: map[string]interface{}{ + "path": tokenSinkFileName, + }, + WrapTTL: 10 * time.Second, + } + + fs, err := file.NewFileSink(config) + if err != nil { + t.Fatal(err) + } + config.Sink = fs + + ss := sink.NewSinkServer(&sink.SinkServerConfig{ + Logger: logger.Named("sink.server"), + Client: client, + }) + + errCh := make(chan error) + go func() { + errCh <- ah.Run(ctx, am) + }() + defer func() { + select { + case <-ctx.Done(): + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } + }() + + go func() { + errCh <- ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config}) + }() + defer func() { + select { + case <-ctx.Done(): + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } + }() + + // This has to be after the other defers so it happens first. It allows + // successful test runs to immediately cancel all of the runner goroutines + // and unblock any of the blocking defer calls by the runner's DoneCh that + // comes before this and avoid successful tests from taking the entire + // timeout duration. + defer cancel() + + if stat, err := os.Lstat(tokenSinkFileName); err == nil { + t.Fatalf("expected err but got %s", stat) + } else if !os.IsNotExist(err) { + t.Fatal("expected notexist err") + } + + // Wait 2 seconds for the env variables to be detected and an auth to be generated. + time.Sleep(time.Second * 2) + + authToken, err := readToken(tokenSinkFileName) + if err != nil { + t.Fatal(err) + } + + if authToken.Token == "" { + t.Fatal("expected token but didn't receive it") + } + + wrappedToken := map[string]interface{}{ + "token": authToken.Token, + } + unwrapResp, err := client.Logical().Write("sys/wrapping/unwrap", wrappedToken) + if err != nil { + t.Fatalf("error unwrapping token: %s", err) + } + + sinkToken, ok := unwrapResp.Data["token"].(string) + if !ok { + t.Fatal("expected token but didn't receive it") + } + + if sinkToken != token { + t.Fatalf("auth token and preload token should be the same: expected: %s, actual: %s", token, sinkToken) + } +}