open-vault/physical/dynamodb/dynamodb_test.go

319 lines
7.6 KiB
Go

package dynamodb
import (
"context"
"fmt"
"math/rand"
"net/http"
"os"
"testing"
"time"
"github.com/go-test/deep"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/logging"
"github.com/hashicorp/vault/physical"
"github.com/ory/dockertest"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
)
// TestDynamoDBBackend runs the generic physical backend test suites against a
// DynamoDB endpoint obtained from prepareDynamoDBTestContainer, then checks
// that an item written with the older dynamodbattribute.ConvertToMap
// marshalling can still be read back through the backend.
func TestDynamoDBBackend(t *testing.T) {
	cleanup, endpoint, credsProvider := prepareDynamoDBTestContainer(t)
	defer cleanup()

	creds, err := credsProvider.Get()
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	region := os.Getenv("AWS_DEFAULT_REGION")
	if region == "" {
		region = "us-east-1"
	}

	awsSession, err := session.NewSession(&aws.Config{
		Credentials: credsProvider,
		Endpoint:    aws.String(endpoint),
		Region:      aws.String(region),
	})
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	conn := dynamodb.New(awsSession)

	// Use a randomized table name so concurrent or repeated runs against the
	// same endpoint do not collide.
	var randInt = rand.New(rand.NewSource(time.Now().UnixNano())).Int()
	table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt)

	defer func() {
		// Best-effort cleanup: the result is deliberately ignored so a
		// failed delete does not mask the outcome of the test itself.
		_, _ = conn.DeleteTable(&dynamodb.DeleteTableInput{
			TableName: aws.String(table),
		})
	}()

	logger := logging.NewVaultLogger(log.Debug)

	b, err := NewDynamoDBBackend(map[string]string{
		"access_key":    creds.AccessKeyID,
		"secret_key":    creds.SecretAccessKey,
		"session_token": creds.SessionToken,
		"table":         table,
		"region":        region,
		"endpoint":      endpoint,
	}, logger)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	physical.ExerciseBackend(t, b)
	physical.ExerciseBackend_ListPrefix(t, b)

	t.Run("Marshalling upgrade", func(t *testing.T) {
		path := "test_key"

		// Manually write to DynamoDB using the old ConvertTo function
		// for marshalling data
		inputEntry := &physical.Entry{
			Key:   path,
			Value: []byte{0x0f, 0xcf, 0x4a, 0x0f, 0xba, 0x2b, 0x15, 0xf0, 0xaa, 0x75, 0x09},
		}

		record := DynamoDBRecord{
			Path:  recordPathForVaultKey(inputEntry.Key),
			Key:   recordKeyForVaultKey(inputEntry.Key),
			Value: inputEntry.Value,
		}

		item, err := dynamodbattribute.ConvertToMap(record)
		if err != nil {
			t.Fatalf("err: %s", err)
		}

		request := &dynamodb.PutItemInput{
			Item:      item,
			TableName: &table,
		}
		// Fail immediately if the legacy-format write did not succeed;
		// otherwise the read below would report a misleading mismatch.
		if _, err := conn.PutItem(request); err != nil {
			t.Fatalf("err: %s", err)
		}

		// Read back the data using the normal interface which should
		// handle the old marshalling format gracefully
		entry, err := b.Get(context.Background(), path)
		if err != nil {
			t.Fatalf("err: %s", err)
		}
		if diff := deep.Equal(inputEntry, entry); diff != nil {
			t.Fatal(diff)
		}
	})
}
// TestDynamoDBHABackend builds a DynamoDB backend on a freshly named table,
// asserts that it satisfies physical.HABackend, and runs both the generic HA
// test suite and the DynamoDB-specific lock TTL test against it.
func TestDynamoDBHABackend(t *testing.T) {
	teardown, addr, provider := prepareDynamoDBTestContainer(t)
	defer teardown()

	credValue, err := provider.Get()
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Fall back to a fixed region when the environment does not supply one.
	awsRegion := os.Getenv("AWS_DEFAULT_REGION")
	if awsRegion == "" {
		awsRegion = "us-east-1"
	}

	sessionConfig := &aws.Config{
		Credentials: provider,
		Endpoint:    aws.String(addr),
		Region:      aws.String(awsRegion),
	}
	sess, err := session.NewSession(sessionConfig)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	client := dynamodb.New(sess)

	// Randomized suffix keeps table names distinct between runs.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	tableName := fmt.Sprintf("vault-dynamodb-testacc-%d", rng.Int())

	// Drop the table when the test finishes; the result is not inspected.
	defer client.DeleteTable(&dynamodb.DeleteTableInput{
		TableName: aws.String(tableName),
	})

	backendConf := map[string]string{
		"access_key":    credValue.AccessKeyID,
		"secret_key":    credValue.SecretAccessKey,
		"session_token": credValue.SessionToken,
		"table":         tableName,
		"region":        awsRegion,
		"endpoint":      addr,
	}
	backend, err := NewDynamoDBBackend(backendConf, logging.NewVaultLogger(log.Debug))
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	haBackend, isHA := backend.(physical.HABackend)
	if !isHA {
		t.Fatalf("dynamodb does not implement HABackend")
	}

	physical.ExerciseHABackend(t, haBackend, haBackend)
	testDynamoDBLockTTL(t, haBackend)
}
// testDynamoDBLockTTL is similar to testHABackend, but uses internal
// implementation details to trigger the lock failure scenario by setting the
// lock renew period for one of the locks to a higher value than the lock TTL.
//
// It takes a first lock whose renew interval exceeds its TTL, then acquires a
// second lock on the same key and finally checks that the first lock's leader
// channel is signalled within a bounded time.
func testDynamoDBLockTTL(t *testing.T, ha physical.HABackend) {
	// Set much smaller lock times to speed up the test.
	lockTTL := time.Second * 3
	renewInterval := time.Second * 1
	watchInterval := time.Second * 1

	// Get the lock
	origLock, err := ha.LockWith("dynamodbttl", "bar")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Set the first lock renew period to double the expected TTL.
	lock, ok := origLock.(*DynamoDBLock)
	if !ok {
		t.Fatalf("expected *DynamoDBLock, got %T", origLock)
	}
	lock.renewInterval = lockTTL * 2
	lock.ttl = lockTTL
	lock.watchRetryInterval = watchInterval

	// Attempt to lock
	leaderCh, err := lock.Lock(nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if leaderCh == nil {
		t.Fatalf("failed to get leader ch")
	}

	// Check the value
	held, val, err := lock.Value()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !held {
		t.Fatalf("should be held")
	}
	if val != "bar" {
		// Report the unexpected value itself; err is nil at this point.
		t.Fatalf("bad value: %v", val)
	}

	// Second acquisition should succeed because the first lock should
	// not renew within the 3 sec TTL.
	origLock2, err := ha.LockWith("dynamodbttl", "baz")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	lock2, ok := origLock2.(*DynamoDBLock)
	if !ok {
		t.Fatalf("expected *DynamoDBLock, got %T", origLock2)
	}
	lock2.renewInterval = renewInterval
	lock2.ttl = lockTTL
	lock2.watchRetryInterval = watchInterval

	// Cancel attempt in 6 sec so as not to block unit tests forever
	stopCh := make(chan struct{})
	time.AfterFunc(lockTTL*2, func() {
		close(stopCh)
	})

	// Attempt to lock should work
	leaderCh2, err := lock2.Lock(stopCh)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if leaderCh2 == nil {
		t.Fatalf("should get leader ch")
	}

	// Check the value
	held, val, err = lock2.Value()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !held {
		t.Fatalf("should be held")
	}
	if val != "baz" {
		// Report the unexpected value itself; err is nil at this point.
		t.Fatalf("bad value: %v", val)
	}

	// The first lock should have lost the leader channel. Wait for it
	// directly with a timeout so that a channel which never fires fails the
	// test instead of blocking it forever.
	select {
	case <-leaderCh:
	case <-time.After(watchInterval * 3):
		t.Fatalf("original lock did not have its leader channel closed.")
	}

	// Cleanup
	lock2.Unlock()
}
// prepareDynamoDBTestContainer provides a DynamoDB endpoint for tests.
//
// If the AWS_DYNAMODB_ENDPOINT environment variable is set, that endpoint is
// returned together with environment-based credentials and a no-op cleanup.
// Otherwise a "cnadiminti/dynamodb-local" container is started through
// dockertest, the function waits until it answers HTTP requests, and it
// returns the container's address, static fake credentials, and a cleanup
// function that purges the container.
func prepareDynamoDBTestContainer(t *testing.T) (cleanup func(), retAddress string, creds *credentials.Credentials) {
	// If environment variable is set, assume caller wants to target a real
	// DynamoDB.
	if endpoint := os.Getenv("AWS_DYNAMODB_ENDPOINT"); endpoint != "" {
		return func() {}, endpoint, credentials.NewEnvCredentials()
	}

	pool, err := dockertest.NewPool("")
	if err != nil {
		t.Fatalf("Failed to connect to docker: %s", err)
	}

	resource, err := pool.Run("cnadiminti/dynamodb-local", "latest", []string{})
	if err != nil {
		t.Fatalf("Could not start local DynamoDB: %s", err)
	}

	retAddress = "http://localhost:" + resource.GetPort("8000/tcp")
	cleanup = func() {
		err := pool.Purge(resource)
		if err != nil {
			t.Fatalf("Failed to cleanup local DynamoDB: %s", err)
		}
	}

	// exponential backoff-retry, because the DynamoDB may not be able to accept
	// connections yet
	if err := pool.Retry(func() error {
		resp, err := http.Get(retAddress)
		if err != nil {
			return err
		}
		// Close the body on every attempt so retries do not leak
		// connections.
		defer resp.Body.Close()

		// A plain GET is expected to be answered with 400 once the
		// service is up.
		if resp.StatusCode != 400 {
			return fmt.Errorf("expected DynamoDB to return status code 400, got (%s) instead", resp.Status)
		}
		return nil
	}); err != nil {
		cleanup()
		t.Fatalf("Could not connect to docker: %s", err)
	}
	return cleanup, retAddress, credentials.NewStaticCredentials("fake", "fake", "")
}