Add exit-after-auth functionality to agent (#5013)

This allows it to authenticate once, then exit once all sinks have
reported success. Useful for things like an init container vs. a
sidecard container.

Also adds command-level testing of it.
This commit is contained in:
Jeff Mitchell 2018-07-30 10:37:04 -04:00 committed by GitHub
parent 0ad44a7ac5
commit a6d0ae5890
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 297 additions and 79 deletions

View file

@ -169,7 +169,9 @@ func (c *AgentCommand) Run(args []string) int {
return 1 return 1
} }
c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level) if c.logger == nil {
c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level)
}
// Validation // Validation
if len(c.flagConfigs) != 1 { if len(c.flagConfigs) != 1 {
@ -313,8 +315,9 @@ func (c *AgentCommand) Run(args []string) int {
} }
ss := sink.NewSinkServer(&sink.SinkServerConfig{ ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: c.logger.Named("sink.server"), Logger: c.logger.Named("sink.server"),
Client: client, Client: client,
ExitAfterAuth: config.ExitAfterAuth,
}) })
ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{ ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{
@ -342,6 +345,9 @@ func (c *AgentCommand) Run(args []string) int {
}() }()
select { select {
case <-ss.DoneCh:
// This will happen if we exit-on-auth
c.logger.Info("sinks finished, exiting")
case <-c.ShutdownCh: case <-c.ShutdownCh:
c.UI.Output("==> Vault agent shutdown triggered") c.UI.Output("==> Vault agent shutdown triggered")
cancelFunc() cancelFunc()

View file

@ -19,8 +19,9 @@ import (
// Config is the configuration for the vault server. // Config is the configuration for the vault server.
type Config struct { type Config struct {
AutoAuth *AutoAuth `hcl:"auto_auth"` AutoAuth *AutoAuth `hcl:"auto_auth"`
PidFile string `hcl:"pid_file"` ExitAfterAuth bool `hcl:"exit_after_auth"`
PidFile string `hcl:"pid_file"`
} }
type AutoAuth struct { type AutoAuth struct {

View file

@ -2,10 +2,7 @@ package agent
import ( import (
"context" "context"
"crypto/ecdsa"
"crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
@ -24,50 +21,8 @@ import (
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
) )
func getTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) {
t.Helper()
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: "https://team-vault.auth0.com/",
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"},
}
privateCl := struct {
User string `json:"https://vault/user"`
Groups []string `json:"https://vault/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}
var key *ecdsa.PrivateKey
block, _ := pem.Decode([]byte(ecdsaPrivKey))
if block != nil {
var err error
key, err = x509.ParseECPrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
t.Fatal(err)
}
raw, err := jwt.Signed(sig).Claims(cl).Claims(privateCl).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return raw, key
}
func TestJWTEndToEnd(t *testing.T) { func TestJWTEndToEnd(t *testing.T) {
testJWTEndToEnd(t, false) testJWTEndToEnd(t, false)
testJWTEndToEnd(t, true) testJWTEndToEnd(t, true)
@ -100,7 +55,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{ _, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
"bound_issuer": "https://team-vault.auth0.com/", "bound_issuer": "https://team-vault.auth0.com/",
"jwt_validation_pubkeys": ecdsaPubKey, "jwt_validation_pubkeys": TestECDSAPubKey,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -248,7 +203,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
} }
// Get a token // Get a token
jwtToken, _ := getTestJWT(t) jwtToken, _ := GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil { if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err) t.Fatal(err)
} else { } else {
@ -355,7 +310,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
// Get another token to test the backend pushing the need to authenticate // Get another token to test the backend pushing the need to authenticate
// to the handler // to the handler
jwtToken, _ = getTestJWT(t) jwtToken, _ = GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil { if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -394,16 +349,3 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
} }
} }
} }
const (
ecdsaPrivKey string = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKfldwWLPYsHjRL9EVTsjSbzTtcGRu6icohNfIqcb6A+oAoGCCqGSM49
AwEHoUQDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbSq+7+1q9BFxAkzjgKnlkXk5qx
hzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END EC PRIVATE KEY-----`
ecdsaPubKey string = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbS
q+7+1q9BFxAkzjgKnlkXk5qxhzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END PUBLIC KEY-----`
)

View file

@ -6,6 +6,7 @@ import (
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"os" "os"
"sync/atomic"
"time" "time"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
@ -34,25 +35,30 @@ type SinkConfig struct {
} }
type SinkServerConfig struct { type SinkServerConfig struct {
Logger hclog.Logger Logger hclog.Logger
Client *api.Client Client *api.Client
Context context.Context Context context.Context
ExitAfterAuth bool
} }
// SinkServer is responsible for pushing tokens to sinks // SinkServer is responsible for pushing tokens to sinks
type SinkServer struct { type SinkServer struct {
DoneCh chan struct{} DoneCh chan struct{}
logger hclog.Logger logger hclog.Logger
client *api.Client client *api.Client
random *rand.Rand random *rand.Rand
exitAfterAuth bool
remaining *int32
} }
func NewSinkServer(conf *SinkServerConfig) *SinkServer { func NewSinkServer(conf *SinkServerConfig) *SinkServer {
ss := &SinkServer{ ss := &SinkServer{
DoneCh: make(chan struct{}), DoneCh: make(chan struct{}),
logger: conf.Logger, logger: conf.Logger,
client: conf.Client, client: conf.Client,
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
exitAfterAuth: conf.ExitAfterAuth,
remaining: new(int32),
} }
return ss return ss
@ -86,6 +92,7 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
for { for {
select { select {
case <-sinkCh: case <-sinkCh:
atomic.AddInt32(ss.remaining, -1)
default: default:
break drainLoop break drainLoop
} }
@ -116,11 +123,13 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
return currSink.WriteToken(currToken) return currSink.WriteToken(currToken)
} }
} }
atomic.AddInt32(ss.remaining, 1)
sinkCh <- sinkFunc(s, token) sinkCh <- sinkFunc(s, token)
} }
} }
case sinkFunc := <-sinkCh: case sinkFunc := <-sinkCh:
atomic.AddInt32(ss.remaining, -1)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
@ -134,8 +143,13 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
case <-ctx.Done(): case <-ctx.Done():
return return
case <-time.After(backoff): case <-time.After(backoff):
atomic.AddInt32(ss.remaining, 1)
sinkCh <- sinkFunc sinkCh <- sinkFunc
} }
} else {
if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
return
}
} }
} }
} }

65
command/agent/testing.go Normal file
View file

@ -0,0 +1,65 @@
package agent
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"testing"
"time"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
func GetTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) {
t.Helper()
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: "https://team-vault.auth0.com/",
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"},
}
privateCl := struct {
User string `json:"https://vault/user"`
Groups []string `json:"https://vault/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}
var key *ecdsa.PrivateKey
block, _ := pem.Decode([]byte(TestECDSAPrivKey))
if block != nil {
var err error
key, err = x509.ParseECPrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
t.Fatal(err)
}
raw, err := jwt.Signed(sig).Claims(cl).Claims(privateCl).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return raw, key
}
const (
TestECDSAPrivKey string = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKfldwWLPYsHjRL9EVTsjSbzTtcGRu6icohNfIqcb6A+oAoGCCqGSM49
AwEHoUQDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbSq+7+1q9BFxAkzjgKnlkXk5qx
hzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END EC PRIVATE KEY-----`
TestECDSAPubKey string = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbS
q+7+1q9BFxAkzjgKnlkXk5qxhzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END PUBLIC KEY-----`
)

186
command/agent_test.go Normal file
View file

@ -0,0 +1,186 @@
package command
import (
"fmt"
"io/ioutil"
"os"
"testing"
hclog "github.com/hashicorp/go-hclog"
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent"
"github.com/hashicorp/vault/helper/logging"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)
func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCommand) {
tb.Helper()
ui := cli.NewMockUi()
return ui, &AgentCommand{
BaseCommand: &BaseCommand{
UI: ui,
},
ShutdownCh: MakeShutdownCh(),
logger: logger,
}
}
func TestExitAfterAuth(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"jwt": vaultjwt.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
// Setup Vault
err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
Type: "jwt",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
"bound_issuer": "https://team-vault.auth0.com/",
"jwt_validation_pubkeys": agent.TestECDSAPubKey,
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
"bound_audiences": "https://vault.plugin.auth.jwt.test",
"user_claim": "https://vault/user",
"groups_claim": "https://vault/groups",
"policies": "test",
"period": "3s",
})
if err != nil {
t.Fatal(err)
}
inf, err := ioutil.TempFile("", "auth.jwt.test.")
if err != nil {
t.Fatal(err)
}
in := inf.Name()
inf.Close()
os.Remove(in)
t.Logf("input: %s", in)
sink1f, err := ioutil.TempFile("", "sink1.jwt.test.")
if err != nil {
t.Fatal(err)
}
sink1 := sink1f.Name()
sink1f.Close()
os.Remove(sink1)
t.Logf("sink1: %s", sink1)
sink2f, err := ioutil.TempFile("", "sink2.jwt.test.")
if err != nil {
t.Fatal(err)
}
sink2 := sink2f.Name()
sink2f.Close()
os.Remove(sink2)
t.Logf("sink2: %s", sink2)
conff, err := ioutil.TempFile("", "conf.jwt.test.")
if err != nil {
t.Fatal(err)
}
conf := conff.Name()
conff.Close()
os.Remove(conf)
t.Logf("config: %s", conf)
jwtToken, _ := agent.GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test jwt", "path", in)
}
config := `
exit_after_auth = true
auto_auth {
method {
type = "jwt"
config = {
role = "test"
path = "%s"
}
}
sink {
type = "file"
config = {
path = "%s"
}
}
sink "file" {
config = {
path = "%s"
}
}
}
`
config = fmt.Sprintf(config, in, sink1, sink2)
if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test config", "path", conf)
}
// If this hangs forever until the test times out, exit-after-auth isn't
// working
ui, cmd := testAgentCommand(t, logger)
cmd.client = client
code := cmd.Run([]string{"-config", conf})
if code != 0 {
t.Errorf("expected %d to be %d", code, 0)
t.Logf("output from agent:\n%s", ui.OutputWriter.String())
t.Logf("error from agent:\n%s", ui.ErrorWriter.String())
}
sink1Bytes, err := ioutil.ReadFile(sink1)
if err != nil {
t.Fatal(err)
}
if len(sink1Bytes) == 0 {
t.Fatal("got no output from sink 1")
}
sink2Bytes, err := ioutil.ReadFile(sink2)
if err != nil {
t.Fatal(err)
}
if len(sink2Bytes) == 0 {
t.Fatal("got no output from sink 2")
}
if string(sink1Bytes) != string(sink2Bytes) {
t.Fatal("sink 1/2 values don't match")
}
}

View file

@ -26,10 +26,14 @@ Auto-Auth functionality takes place within an `auto_auth` configuration stanza.
## Configuration ## Configuration
There is one currently-available general configuration option: These are the currently-available general configuration option:
- `pid_file` `(string: "")` - Path to the file in which the agent's Process ID - `pid_file` `(string: "")` - Path to the file in which the agent's Process ID
(PID) should be stored. (PID) should be stored
- `exit_after_auth` `(bool: false)` - If set to `true`, the agent will exit
with code `0` after a single successful auth, where success means that a
token was retrieved and all sinks successfully wrote it
## Example Configuration ## Example Configuration