// open-vault/command/agent/cert_with_name_end_to_end_test.go

package agent
import (
"context"
"encoding/pem"
"io/ioutil"
"os"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
vaultcert "github.com/hashicorp/vault/builtin/credential/cert"
"github.com/hashicorp/vault/command/agent/auth"
agentcert "github.com/hashicorp/vault/command/agent/auth/cert"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/helper/dhutil"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
)
// TestCertWithNameEndToEnd exercises the agent's cert auth flow with a named
// certificate role in both wrapping modes: once with the file sink wrapping
// the token and once with the auth handler wrapping the login response.
// Each mode runs as its own subtest so a failure in one does not prevent the
// other from running, and the failing mode is named in the test output.
func TestCertWithNameEndToEnd(t *testing.T) {
	t.Run("sink_wrapping", func(t *testing.T) {
		testCertWithNameEndToEnd(t, false)
	})
	t.Run("auth_handler_wrapping", func(t *testing.T) {
		testCertWithNameEndToEnd(t, true)
	})
}
// testCertWithNameEndToEnd starts an in-process Vault test cluster, enables
// the cert auth method with a "test" role trusting the cluster CA, and runs
// the agent auth handler (cert auth with "name": "test") feeding a file sink.
// It then checks that the token written by the sink is DH-encrypted with the
// AAD "foobar", response-wrapped with a 10s TTL, and unwraps to a non-empty
// client token.
//
// When ahWrapping is true the auth handler wraps the login response itself
// (expected creation path "auth/cert/login"); otherwise the sink wraps the
// token (expected creation path "sys/wrapping/wrap").
func testCertWithNameEndToEnd(t *testing.T, ahWrapping bool) {
	logger := logging.NewVaultLogger(hclog.Trace)
	coreConfig := &vault.CoreConfig{
		Logger: logger,
		CredentialBackends: map[string]logical.Factory{
			"cert": vaultcert.Factory,
		},
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	vault.TestWaitActive(t, cluster.Cores[0].Core)
	client := cluster.Cores[0].Client

	// Setup Vault: enable cert auth and register the cluster CA certificate
	// under the "test" role that the agent will log in with by name.
	err := client.Sys().EnableAuthWithOptions("cert", &api.EnableAuthOptions{
		Type: "cert",
	})
	if err != nil {
		t.Fatal(err)
	}

	certificatePEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cluster.CACert.Raw})
	_, err = client.Logical().Write("auth/cert/certs/test", map[string]interface{}{
		"name":        "test",
		"certificate": string(certificatePEM),
		"policies":    "default",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Generate encryption params
	pub, pri, err := dhutil.GeneratePublicPrivateKey()
	if err != nil {
		t.Fatal(err)
	}

	// Reserve a unique path for the sink output, then delete the file so the
	// only thing that can create it is the sink.
	ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
	if err != nil {
		t.Fatal(err)
	}
	out := ouf.Name()
	ouf.Close()
	os.Remove(out)
	// Best-effort cleanup if the test fails before checkToken consumes the
	// file. Registered before the DoneCh waits below, so it runs after them.
	defer os.Remove(out)
	t.Logf("output: %s", out)

	dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.")
	if err != nil {
		t.Fatal(err)
	}
	dhpath := dhpathf.Name()
	dhpathf.Close()
	os.Remove(dhpath)
	// Best-effort cleanup of the DH public key file written below; runs after
	// the sink server has stopped reading it.
	defer os.Remove(dhpath)

	// Write DH public key to file
	mPubKey, err := jsonutil.EncodeJSON(&dhutil.PublicKeyInfo{
		Curve25519PublicKey: pub,
	})
	if err != nil {
		t.Fatal(err)
	}
	if err := ioutil.WriteFile(dhpath, mPubKey, 0600); err != nil {
		t.Fatal(err)
	}
	logger.Trace("wrote dh param file", "path", dhpath)

	// Hard upper bound on the run: cancel the context after 30s so the
	// deferred DoneCh waits below cannot block forever.
	ctx, cancelFunc := context.WithCancel(context.Background())
	timer := time.AfterFunc(30*time.Second, cancelFunc)
	defer timer.Stop()

	am, err := agentcert.NewCertAuthMethod(&auth.AuthConfig{
		Logger:    logger.Named("auth.cert"),
		MountPath: "auth/cert",
		Config: map[string]interface{}{
			"name": "test",
		},
	})
	if err != nil {
		t.Fatal(err)
	}

	ahConfig := &auth.AuthHandlerConfig{
		Logger:                       logger.Named("auth.handler"),
		Client:                       client,
		EnableReauthOnNewCredentials: true,
	}
	if ahWrapping {
		ahConfig.WrapTTL = 10 * time.Second
	}
	ah := auth.NewAuthHandler(ahConfig)
	go ah.Run(ctx, am)
	defer func() {
		<-ah.DoneCh
	}()

	config := &sink.SinkConfig{
		Logger: logger.Named("sink.file"),
		AAD:    "foobar",
		DHType: "curve25519",
		DHPath: dhpath,
		Config: map[string]interface{}{
			"path": out,
		},
	}
	// Exactly one of the auth handler or the sink performs the wrapping.
	if !ahWrapping {
		config.WrapTTL = 10 * time.Second
	}
	fs, err := file.NewFileSink(config)
	if err != nil {
		t.Fatal(err)
	}
	config.Sink = fs

	ss := sink.NewSinkServer(&sink.SinkServerConfig{
		Logger: logger.Named("sink.server"),
		Client: client,
	})
	go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
	defer func() {
		<-ss.DoneCh
	}()

	// This has to be after the other defers so it happens first
	defer cancelFunc()

	cloned, err := client.Clone()
	if err != nil {
		t.Fatal(err)
	}

	// checkToken polls for the sink's output file, decrypts the DH envelope,
	// validates the wrap info, unwraps it, and returns the client token. It
	// fails the test if no valid token is found within 5 seconds.
	checkToken := func() string {
		const pollInterval = 250 * time.Millisecond
		timeout := time.Now().Add(5 * time.Second)
		for {
			if time.Now().After(timeout) {
				t.Fatal("did not find a written token after timeout")
			}
			val, err := ioutil.ReadFile(out)
			if err != nil {
				// Not written yet; poll again shortly.
				time.Sleep(pollInterval)
				continue
			}
			if len(val) == 0 {
				t.Fatal("written token was empty")
			}

			// First decrypt it
			resp := new(dhutil.Envelope)
			if err := jsonutil.DecodeJSON(val, resp); err != nil {
				// Not a decodable envelope (NOTE(review): presumably a write
				// still in progress — confirm against the file sink). Leave the
				// file in place and retry after the poll interval rather than
				// deleting it and spinning.
				time.Sleep(pollInterval)
				continue
			}
			// Consume the file so a later poll cannot re-read this token.
			os.Remove(out)

			aesKey, err := dhutil.GenerateSharedKey(pri, resp.Curve25519PublicKey)
			if err != nil {
				t.Fatal(err)
			}
			if len(aesKey) == 0 {
				t.Fatal("got empty aes key")
			}

			val, err = dhutil.DecryptAES(aesKey, resp.EncryptedPayload, resp.Nonce, []byte("foobar"))
			if err != nil {
				t.Fatalf("error: %v\nresp: %v", err, string(val))
			}

			// Now unwrap it
			wrapInfo := new(api.SecretWrapInfo)
			if err := jsonutil.DecodeJSON(val, wrapInfo); err != nil {
				t.Fatal(err)
			}
			switch {
			case wrapInfo.TTL != 10:
				t.Fatalf("bad wrap info: %v", wrapInfo.TTL)
			case !ahWrapping && wrapInfo.CreationPath != "sys/wrapping/wrap":
				t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
			case ahWrapping && wrapInfo.CreationPath != "auth/cert/login":
				t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
			case wrapInfo.Token == "":
				t.Fatal("wrap token is empty")
			}

			cloned.SetToken(wrapInfo.Token)
			secret, err := cloned.Logical().Unwrap("")
			if err != nil {
				t.Fatal(err)
			}
			// Unwrap can report "nothing to unwrap" as (nil, nil); fail clearly
			// instead of dereferencing a nil secret below.
			if secret == nil {
				t.Fatal("unwrap returned a nil secret")
			}

			// A wrapped login response carries the token in Auth; a token
			// wrapped via sys/wrapping/wrap carries it in Data["token"].
			if ahWrapping {
				switch {
				case secret.Auth == nil:
					t.Fatal("unwrap secret auth is nil")
				case secret.Auth.ClientToken == "":
					t.Fatal("unwrap token is nil")
				}
				return secret.Auth.ClientToken
			}

			switch {
			case secret.Data == nil:
				t.Fatal("unwrap secret data is nil")
			case secret.Data["token"] == nil:
				t.Fatal("unwrap token is nil")
			}
			token, ok := secret.Data["token"].(string)
			if !ok {
				t.Fatalf("unwrap token has unexpected type %T", secret.Data["token"])
			}
			return token
		}
	}

	checkToken()
}