318 lines
7.6 KiB
Go
318 lines
7.6 KiB
Go
package dynamodb
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"net/http"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-test/deep"
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/helper/logging"
|
|
"github.com/hashicorp/vault/physical"
|
|
"github.com/ory/dockertest"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/dynamodb"
|
|
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
|
|
)
|
|
|
|
func TestDynamoDBBackend(t *testing.T) {
|
|
cleanup, endpoint, credsProvider := prepareDynamoDBTestContainer(t)
|
|
defer cleanup()
|
|
|
|
creds, err := credsProvider.Get()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
region := os.Getenv("AWS_DEFAULT_REGION")
|
|
if region == "" {
|
|
region = "us-east-1"
|
|
}
|
|
|
|
awsSession, err := session.NewSession(&aws.Config{
|
|
Credentials: credsProvider,
|
|
Endpoint: aws.String(endpoint),
|
|
Region: aws.String(region),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
conn := dynamodb.New(awsSession)
|
|
|
|
var randInt = rand.New(rand.NewSource(time.Now().UnixNano())).Int()
|
|
table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt)
|
|
|
|
defer func() {
|
|
conn.DeleteTable(&dynamodb.DeleteTableInput{
|
|
TableName: aws.String(table),
|
|
})
|
|
}()
|
|
|
|
logger := logging.NewVaultLogger(log.Debug)
|
|
|
|
b, err := NewDynamoDBBackend(map[string]string{
|
|
"access_key": creds.AccessKeyID,
|
|
"secret_key": creds.SecretAccessKey,
|
|
"session_token": creds.SessionToken,
|
|
"table": table,
|
|
"region": region,
|
|
"endpoint": endpoint,
|
|
}, logger)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
physical.ExerciseBackend(t, b)
|
|
physical.ExerciseBackend_ListPrefix(t, b)
|
|
|
|
t.Run("Marshalling upgrade", func(t *testing.T) {
|
|
path := "test_key"
|
|
|
|
// Manually write to DynamoDB using the old ConvertTo function
|
|
// for marshalling data
|
|
inputEntry := &physical.Entry{
|
|
Key: path,
|
|
Value: []byte{0x0f, 0xcf, 0x4a, 0x0f, 0xba, 0x2b, 0x15, 0xf0, 0xaa, 0x75, 0x09},
|
|
}
|
|
|
|
record := DynamoDBRecord{
|
|
Path: recordPathForVaultKey(inputEntry.Key),
|
|
Key: recordKeyForVaultKey(inputEntry.Key),
|
|
Value: inputEntry.Value,
|
|
}
|
|
|
|
item, err := dynamodbattribute.ConvertToMap(record)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
request := &dynamodb.PutItemInput{
|
|
Item: item,
|
|
TableName: &table,
|
|
}
|
|
conn.PutItem(request)
|
|
|
|
// Read back the data using the normal interface which should
|
|
// handle the old marshalling format gracefully
|
|
entry, err := b.Get(context.Background(), path)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
if diff := deep.Equal(inputEntry, entry); diff != nil {
|
|
t.Fatal(diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDynamoDBHABackend(t *testing.T) {
|
|
cleanup, endpoint, credsProvider := prepareDynamoDBTestContainer(t)
|
|
defer cleanup()
|
|
|
|
creds, err := credsProvider.Get()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
region := os.Getenv("AWS_DEFAULT_REGION")
|
|
if region == "" {
|
|
region = "us-east-1"
|
|
}
|
|
|
|
awsSession, err := session.NewSession(&aws.Config{
|
|
Credentials: credsProvider,
|
|
Endpoint: aws.String(endpoint),
|
|
Region: aws.String(region),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
conn := dynamodb.New(awsSession)
|
|
|
|
var randInt = rand.New(rand.NewSource(time.Now().UnixNano())).Int()
|
|
table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt)
|
|
|
|
defer func() {
|
|
conn.DeleteTable(&dynamodb.DeleteTableInput{
|
|
TableName: aws.String(table),
|
|
})
|
|
}()
|
|
|
|
logger := logging.NewVaultLogger(log.Debug)
|
|
b, err := NewDynamoDBBackend(map[string]string{
|
|
"access_key": creds.AccessKeyID,
|
|
"secret_key": creds.SecretAccessKey,
|
|
"session_token": creds.SessionToken,
|
|
"table": table,
|
|
"region": region,
|
|
"endpoint": endpoint,
|
|
}, logger)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
ha, ok := b.(physical.HABackend)
|
|
if !ok {
|
|
t.Fatalf("dynamodb does not implement HABackend")
|
|
}
|
|
physical.ExerciseHABackend(t, ha, ha)
|
|
testDynamoDBLockTTL(t, ha)
|
|
}
|
|
|
|
// Similar to testHABackend, but using internal implementation details to
|
|
// trigger the lock failure scenario by setting the lock renew period for one
|
|
// of the locks to a higher value than the lock TTL.
|
|
func testDynamoDBLockTTL(t *testing.T, ha physical.HABackend) {
|
|
// Set much smaller lock times to speed up the test.
|
|
lockTTL := time.Second * 3
|
|
renewInterval := time.Second * 1
|
|
watchInterval := time.Second * 1
|
|
|
|
// Get the lock
|
|
origLock, err := ha.LockWith("dynamodbttl", "bar")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
// set the first lock renew period to double the expected TTL.
|
|
lock := origLock.(*DynamoDBLock)
|
|
lock.renewInterval = lockTTL * 2
|
|
lock.ttl = lockTTL
|
|
lock.watchRetryInterval = watchInterval
|
|
|
|
// Attempt to lock
|
|
leaderCh, err := lock.Lock(nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if leaderCh == nil {
|
|
t.Fatalf("failed to get leader ch")
|
|
}
|
|
|
|
// Check the value
|
|
held, val, err := lock.Value()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if !held {
|
|
t.Fatalf("should be held")
|
|
}
|
|
if val != "bar" {
|
|
t.Fatalf("bad value: %v", err)
|
|
}
|
|
|
|
// Second acquisition should succeed because the first lock should
|
|
// not renew within the 3 sec TTL.
|
|
origLock2, err := ha.LockWith("dynamodbttl", "baz")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
lock2 := origLock2.(*DynamoDBLock)
|
|
lock2.renewInterval = renewInterval
|
|
lock2.ttl = lockTTL
|
|
lock2.watchRetryInterval = watchInterval
|
|
|
|
// Cancel attempt in 6 sec so as not to block unit tests forever
|
|
stopCh := make(chan struct{})
|
|
time.AfterFunc(lockTTL*2, func() {
|
|
close(stopCh)
|
|
})
|
|
|
|
// Attempt to lock should work
|
|
leaderCh2, err := lock2.Lock(stopCh)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if leaderCh2 == nil {
|
|
t.Fatalf("should get leader ch")
|
|
}
|
|
|
|
// Check the value
|
|
held, val, err = lock2.Value()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if !held {
|
|
t.Fatalf("should be held")
|
|
}
|
|
if val != "baz" {
|
|
t.Fatalf("bad value: %v", err)
|
|
}
|
|
|
|
// The first lock should have lost the leader channel
|
|
leaderChClosed := false
|
|
blocking := make(chan struct{})
|
|
// Attempt to read from the leader or the blocking channel, which ever one
|
|
// happens first.
|
|
go func() {
|
|
select {
|
|
case <-time.After(watchInterval * 3):
|
|
return
|
|
case <-leaderCh:
|
|
leaderChClosed = true
|
|
close(blocking)
|
|
case <-blocking:
|
|
return
|
|
}
|
|
}()
|
|
|
|
<-blocking
|
|
if !leaderChClosed {
|
|
t.Fatalf("original lock did not have its leader channel closed.")
|
|
}
|
|
|
|
// Cleanup
|
|
lock2.Unlock()
|
|
}
|
|
|
|
func prepareDynamoDBTestContainer(t *testing.T) (cleanup func(), retAddress string, creds *credentials.Credentials) {
|
|
// If environment variable is set, assume caller wants to target a real
|
|
// DynamoDB.
|
|
if os.Getenv("AWS_DYNAMODB_ENDPOINT") != "" {
|
|
return func() {}, os.Getenv("AWS_DYNAMODB_ENDPOINT"), credentials.NewEnvCredentials()
|
|
}
|
|
|
|
pool, err := dockertest.NewPool("")
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to docker: %s", err)
|
|
}
|
|
|
|
resource, err := pool.Run("cnadiminti/dynamodb-local", "latest", []string{})
|
|
if err != nil {
|
|
t.Fatalf("Could not start local DynamoDB: %s", err)
|
|
}
|
|
|
|
retAddress = "http://localhost:" + resource.GetPort("8000/tcp")
|
|
cleanup = func() {
|
|
err := pool.Purge(resource)
|
|
if err != nil {
|
|
t.Fatalf("Failed to cleanup local DynamoDB: %s", err)
|
|
}
|
|
}
|
|
|
|
// exponential backoff-retry, because the DynamoDB may not be able to accept
|
|
// connections yet
|
|
if err := pool.Retry(func() error {
|
|
var err error
|
|
resp, err := http.Get(retAddress)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.StatusCode != 400 {
|
|
return fmt.Errorf("expected DynamoDB to return status code 400, got (%s) instead", resp.Status)
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
t.Fatalf("Could not connect to docker: %s", err)
|
|
}
|
|
return cleanup, retAddress, credentials.NewStaticCredentials("fake", "fake", "")
|
|
}
|