diff --git a/command/agent.go b/command/agent.go index 688e4a93f..57e0551d7 100644 --- a/command/agent.go +++ b/command/agent.go @@ -169,7 +169,9 @@ func (c *AgentCommand) Run(args []string) int { return 1 } - c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level) + if c.logger == nil { + c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level) + } // Validation if len(c.flagConfigs) != 1 { @@ -313,8 +315,9 @@ func (c *AgentCommand) Run(args []string) int { } ss := sink.NewSinkServer(&sink.SinkServerConfig{ - Logger: c.logger.Named("sink.server"), - Client: client, + Logger: c.logger.Named("sink.server"), + Client: client, + ExitAfterAuth: config.ExitAfterAuth, }) ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{ @@ -342,6 +345,9 @@ func (c *AgentCommand) Run(args []string) int { }() select { + case <-ss.DoneCh: + // This will happen if we exit-on-auth + c.logger.Info("sinks finished, exiting") case <-c.ShutdownCh: c.UI.Output("==> Vault agent shutdown triggered") cancelFunc() diff --git a/command/agent/config/config.go b/command/agent/config/config.go index c809b0be1..f4e6c9f13 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -19,8 +19,9 @@ import ( // Config is the configuration for the vault server. type Config struct { - AutoAuth *AutoAuth `hcl:"auto_auth"` - PidFile string `hcl:"pid_file"` + AutoAuth *AutoAuth `hcl:"auto_auth"` + ExitAfterAuth bool `hcl:"exit_after_auth"` + PidFile string `hcl:"pid_file"` } type AutoAuth struct { diff --git a/command/agent/jwt_end_to_end_test.go b/command/agent/jwt_end_to_end_test.go index 2fa6c438b..da96e58c2 100644 --- a/command/agent/jwt_end_to_end_test.go +++ b/command/agent/jwt_end_to_end_test.go @@ -2,10 +2,7 @@ package agent import ( "context" - "crypto/ecdsa" - "crypto/x509" "encoding/json" - "encoding/pem" "io/ioutil" "os" "testing" @@ -24,50 +21,8 @@ import ( vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" - jose "gopkg.in/square/go-jose.v2" - "gopkg.in/square/go-jose.v2/jwt" ) -func getTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) { - t.Helper() - cl := jwt.Claims{ - Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", - Issuer: "https://team-vault.auth0.com/", - NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)), - Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"}, - } - - privateCl := struct { - User string `json:"https://vault/user"` - Groups []string `json:"https://vault/groups"` - }{ - "jeff", - []string{"foo", "bar"}, - } - - var key *ecdsa.PrivateKey - block, _ := pem.Decode([]byte(ecdsaPrivKey)) - if block != nil { - var err error - key, err = x509.ParseECPrivateKey(block.Bytes) - if err != nil { - t.Fatal(err) - } - } - - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, (&jose.SignerOptions{}).WithType("JWT")) - if err != nil { - t.Fatal(err) - } - - raw, err := jwt.Signed(sig).Claims(cl).Claims(privateCl).CompactSerialize() - if err != nil { - t.Fatal(err) - } - - return raw, key -} - func TestJWTEndToEnd(t *testing.T) { testJWTEndToEnd(t, false) testJWTEndToEnd(t, true) @@ -100,7 +55,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { _, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{ "bound_issuer": "https://team-vault.auth0.com/", - "jwt_validation_pubkeys": ecdsaPubKey, + "jwt_validation_pubkeys": TestECDSAPubKey, }) if err != nil { t.Fatal(err) @@ -248,7 +203,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { } // Get a token - jwtToken, _ := getTestJWT(t) + jwtToken, _ := GetTestJWT(t) if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil { t.Fatal(err) } else { @@ -355,7 +310,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { // Get another token to test the backend pushing the need to authenticate // to the handler - jwtToken, _ = getTestJWT(t) + jwtToken, _ = GetTestJWT(t) if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil { t.Fatal(err) } @@ -394,16 +349,3 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) { } } } - -const ( - ecdsaPrivKey string = `-----BEGIN EC PRIVATE KEY----- -MHcCAQEEIKfldwWLPYsHjRL9EVTsjSbzTtcGRu6icohNfIqcb6A+oAoGCCqGSM49 -AwEHoUQDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbSq+7+1q9BFxAkzjgKnlkXk5qx -hzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ== ------END EC PRIVATE KEY-----` - - ecdsaPubKey string = `-----BEGIN PUBLIC KEY----- -MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbS -q+7+1q9BFxAkzjgKnlkXk5qxhzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ== ------END PUBLIC KEY-----` -) diff --git a/command/agent/sink/sink.go b/command/agent/sink/sink.go index 98d2238ed..9fe99ec4b 100644 --- a/command/agent/sink/sink.go +++ b/command/agent/sink/sink.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "math/rand" "os" + "sync/atomic" "time" "github.com/hashicorp/errwrap" @@ -34,25 +35,30 @@ type SinkConfig struct { } type SinkServerConfig struct { - Logger hclog.Logger - Client *api.Client - Context context.Context + Logger hclog.Logger + Client *api.Client + Context context.Context + ExitAfterAuth bool } // SinkServer is responsible for pushing tokens to sinks type SinkServer struct { - DoneCh chan struct{} - logger hclog.Logger - client *api.Client - random *rand.Rand + DoneCh chan struct{} + logger hclog.Logger + client *api.Client + random *rand.Rand + exitAfterAuth bool + remaining *int32 } func NewSinkServer(conf *SinkServerConfig) *SinkServer { ss := &SinkServer{ - DoneCh: make(chan struct{}), - logger: conf.Logger, - client: conf.Client, - random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), + DoneCh: make(chan struct{}), + logger: conf.Logger, + client: conf.Client, + random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), + exitAfterAuth: conf.ExitAfterAuth, + remaining: new(int32), } return ss @@ -86,6 +92,7 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si for { select { case <-sinkCh: + atomic.AddInt32(ss.remaining, -1) default: break drainLoop } @@ -116,11 +123,13 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si return currSink.WriteToken(currToken) } } + atomic.AddInt32(ss.remaining, 1) sinkCh <- sinkFunc(s, token) } } case sinkFunc := <-sinkCh: + atomic.AddInt32(ss.remaining, -1) select { case <-ctx.Done(): return @@ -134,8 +143,13 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si case <-ctx.Done(): return case <-time.After(backoff): + atomic.AddInt32(ss.remaining, 1) sinkCh <- sinkFunc } + } else { + if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth { + return + } } } } diff --git a/command/agent/testing.go b/command/agent/testing.go new file mode 100644 index 000000000..fad5963aa --- /dev/null +++ b/command/agent/testing.go @@ -0,0 +1,65 @@ +package agent + +import ( + "crypto/ecdsa" + "crypto/x509" + "encoding/pem" + "testing" + "time" + + jose "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +func GetTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) { + t.Helper() + cl := jwt.Claims{ + Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", + Issuer: "https://team-vault.auth0.com/", + NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)), + Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"}, + } + + privateCl := struct { + User string `json:"https://vault/user"` + Groups []string `json:"https://vault/groups"` + }{ + "jeff", + []string{"foo", "bar"}, + } + + var key *ecdsa.PrivateKey + block, _ := pem.Decode([]byte(TestECDSAPrivKey)) + if block != nil { + var err error + key, err = x509.ParseECPrivateKey(block.Bytes) + if err != nil { + t.Fatal(err) + } + } + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, (&jose.SignerOptions{}).WithType("JWT")) + if err != nil { + t.Fatal(err) + } + + raw, err := jwt.Signed(sig).Claims(cl).Claims(privateCl).CompactSerialize() + if err != nil { + t.Fatal(err) + } + + return raw, key +} + +const ( + TestECDSAPrivKey string = `-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIKfldwWLPYsHjRL9EVTsjSbzTtcGRu6icohNfIqcb6A+oAoGCCqGSM49 +AwEHoUQDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbSq+7+1q9BFxAkzjgKnlkXk5qx +hzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ== +-----END EC PRIVATE KEY-----` + + TestECDSAPubKey string = `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbS +q+7+1q9BFxAkzjgKnlkXk5qxhzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ== +-----END PUBLIC KEY-----` +) diff --git a/command/agent_test.go b/command/agent_test.go new file mode 100644 index 000000000..d7281fc38 --- /dev/null +++ b/command/agent_test.go @@ -0,0 +1,186 @@ +package command + +import ( + "fmt" + "io/ioutil" + "os" + "testing" + + hclog "github.com/hashicorp/go-hclog" + vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agent" + "github.com/hashicorp/vault/helper/logging" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" + "github.com/mitchellh/cli" +) + +func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &AgentCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + ShutdownCh: MakeShutdownCh(), + logger: logger, + } +} + +func TestExitAfterAuth(t *testing.T) { + logger := logging.NewVaultLogger(hclog.Trace) + coreConfig := &vault.CoreConfig{ + Logger: logger, + CredentialBackends: map[string]logical.Factory{ + "jwt": vaultjwt.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + vault.TestWaitActive(t, cluster.Cores[0].Core) + client := cluster.Cores[0].Client + + // Setup Vault + err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{ + Type: "jwt", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{ + "bound_issuer": "https://team-vault.auth0.com/", + "jwt_validation_pubkeys": agent.TestECDSAPubKey, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{ + "bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", + "bound_audiences": "https://vault.plugin.auth.jwt.test", + "user_claim": "https://vault/user", + "groups_claim": "https://vault/groups", + "policies": "test", + "period": "3s", + }) + if err != nil { + t.Fatal(err) + } + + inf, err := ioutil.TempFile("", "auth.jwt.test.") + if err != nil { + t.Fatal(err) + } + in := inf.Name() + inf.Close() + os.Remove(in) + t.Logf("input: %s", in) + + sink1f, err := ioutil.TempFile("", "sink1.jwt.test.") + if err != nil { + t.Fatal(err) + } + sink1 := sink1f.Name() + sink1f.Close() + os.Remove(sink1) + t.Logf("sink1: %s", sink1) + + sink2f, err := ioutil.TempFile("", "sink2.jwt.test.") + if err != nil { + t.Fatal(err) + } + sink2 := sink2f.Name() + sink2f.Close() + os.Remove(sink2) + t.Logf("sink2: %s", sink2) + + conff, err := ioutil.TempFile("", "conf.jwt.test.") + if err != nil { + t.Fatal(err) + } + conf := conff.Name() + conff.Close() + os.Remove(conf) + t.Logf("config: %s", conf) + + jwtToken, _ := agent.GetTestJWT(t) + if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil { + t.Fatal(err) + } else { + logger.Trace("wrote test jwt", "path", in) + } + + config := ` +exit_after_auth = true + +auto_auth { + method { + type = "jwt" + config = { + role = "test" + path = "%s" + } + } + + sink { + type = "file" + config = { + path = "%s" + } + } + + sink "file" { + config = { + path = "%s" + } + } +} +` + + config = fmt.Sprintf(config, in, sink1, sink2) + if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil { + t.Fatal(err) + } else { + logger.Trace("wrote test config", "path", conf) + } + + // If this hangs forever until the test times out, exit-after-auth isn't + // working + ui, cmd := testAgentCommand(t, logger) + cmd.client = client + + code := cmd.Run([]string{"-config", conf}) + if code != 0 { + t.Errorf("expected %d to be %d", code, 0) + t.Logf("output from agent:\n%s", ui.OutputWriter.String()) + t.Logf("error from agent:\n%s", ui.ErrorWriter.String()) + } + + sink1Bytes, err := ioutil.ReadFile(sink1) + if err != nil { + t.Fatal(err) + } + if len(sink1Bytes) == 0 { + t.Fatal("got no output from sink 1") + } + + sink2Bytes, err := ioutil.ReadFile(sink2) + if err != nil { + t.Fatal(err) + } + if len(sink2Bytes) == 0 { + t.Fatal("got no output from sink 2") + } + + if string(sink1Bytes) != string(sink2Bytes) { + t.Fatal("sink 1/2 values don't match") + } +} diff --git a/website/source/docs/agent/index.html.md b/website/source/docs/agent/index.html.md index 3a11c0ce9..87a296647 100644 --- a/website/source/docs/agent/index.html.md +++ b/website/source/docs/agent/index.html.md @@ -26,10 +26,14 @@ Auto-Auth functionality takes place within an `auto_auth` configuration stanza. ## Configuration -There is one currently-available general configuration option: +These are the currently-available general configuration option: - `pid_file` `(string: "")` - Path to the file in which the agent's Process ID - (PID) should be stored. + (PID) should be stored + +- `exit_after_auth` `(bool: false)` - If set to `true`, the agent will exit + with code `0` after a single successful auth, where success means that a + token was retrieved and all sinks successfully wrote it ## Example Configuration