419 lines
9.8 KiB
Go
419 lines
9.8 KiB
Go
package dynamodb
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-test/deep"
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/helper/testhelpers/docker"
|
|
"github.com/hashicorp/vault/sdk/helper/logging"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/dynamodb"
|
|
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
|
|
)
|
|
|
|
func TestDynamoDBBackend(t *testing.T) {
|
|
cleanup, svccfg := prepareDynamoDBTestContainer(t)
|
|
defer cleanup()
|
|
|
|
creds, err := svccfg.Credentials.Get()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
region := os.Getenv("AWS_DEFAULT_REGION")
|
|
if region == "" {
|
|
region = "us-east-1"
|
|
}
|
|
|
|
awsSession, err := session.NewSession(&aws.Config{
|
|
Credentials: svccfg.Credentials,
|
|
Endpoint: aws.String(svccfg.URL().String()),
|
|
Region: aws.String(region),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
conn := dynamodb.New(awsSession)
|
|
|
|
randInt := rand.New(rand.NewSource(time.Now().UnixNano())).Int()
|
|
table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt)
|
|
|
|
defer func() {
|
|
conn.DeleteTable(&dynamodb.DeleteTableInput{
|
|
TableName: aws.String(table),
|
|
})
|
|
}()
|
|
|
|
logger := logging.NewVaultLogger(log.Debug)
|
|
|
|
b, err := NewDynamoDBBackend(map[string]string{
|
|
"access_key": creds.AccessKeyID,
|
|
"secret_key": creds.SecretAccessKey,
|
|
"session_token": creds.SessionToken,
|
|
"table": table,
|
|
"region": region,
|
|
"endpoint": svccfg.URL().String(),
|
|
}, logger)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
physical.ExerciseBackend(t, b)
|
|
physical.ExerciseBackend_ListPrefix(t, b)
|
|
|
|
t.Run("Marshalling upgrade", func(t *testing.T) {
|
|
path := "test_key"
|
|
|
|
// Manually write to DynamoDB using the old ConvertTo function
|
|
// for marshalling data
|
|
inputEntry := &physical.Entry{
|
|
Key: path,
|
|
Value: []byte{0x0f, 0xcf, 0x4a, 0x0f, 0xba, 0x2b, 0x15, 0xf0, 0xaa, 0x75, 0x09},
|
|
}
|
|
|
|
record := DynamoDBRecord{
|
|
Path: recordPathForVaultKey(inputEntry.Key),
|
|
Key: recordKeyForVaultKey(inputEntry.Key),
|
|
Value: inputEntry.Value,
|
|
}
|
|
|
|
item, err := dynamodbattribute.ConvertToMap(record)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
request := &dynamodb.PutItemInput{
|
|
Item: item,
|
|
TableName: &table,
|
|
}
|
|
conn.PutItem(request)
|
|
|
|
// Read back the data using the normal interface which should
|
|
// handle the old marshalling format gracefully
|
|
entry, err := b.Get(context.Background(), path)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
if diff := deep.Equal(inputEntry, entry); diff != nil {
|
|
t.Fatal(diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDynamoDBHABackend(t *testing.T) {
|
|
cleanup, svccfg := prepareDynamoDBTestContainer(t)
|
|
defer cleanup()
|
|
|
|
creds, err := svccfg.Credentials.Get()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
region := os.Getenv("AWS_DEFAULT_REGION")
|
|
if region == "" {
|
|
region = "us-east-1"
|
|
}
|
|
|
|
awsSession, err := session.NewSession(&aws.Config{
|
|
Credentials: svccfg.Credentials,
|
|
Endpoint: aws.String(svccfg.URL().String()),
|
|
Region: aws.String(region),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
conn := dynamodb.New(awsSession)
|
|
|
|
randInt := rand.New(rand.NewSource(time.Now().UnixNano())).Int()
|
|
table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt)
|
|
|
|
defer func() {
|
|
conn.DeleteTable(&dynamodb.DeleteTableInput{
|
|
TableName: aws.String(table),
|
|
})
|
|
}()
|
|
|
|
logger := logging.NewVaultLogger(log.Debug)
|
|
config := map[string]string{
|
|
"access_key": creds.AccessKeyID,
|
|
"secret_key": creds.SecretAccessKey,
|
|
"session_token": creds.SessionToken,
|
|
"table": table,
|
|
"region": region,
|
|
"endpoint": svccfg.URL().String(),
|
|
}
|
|
|
|
b, err := NewDynamoDBBackend(config, logger)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
b2, err := NewDynamoDBBackend(config, logger)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
physical.ExerciseHABackend(t, b.(physical.HABackend), b2.(physical.HABackend))
|
|
testDynamoDBLockTTL(t, b.(physical.HABackend))
|
|
testDynamoDBLockRenewal(t, b.(physical.HABackend))
|
|
}
|
|
|
|
// Similar to testHABackend, but using internal implementation details to
|
|
// trigger the lock failure scenario by setting the lock renew period for one
|
|
// of the locks to a higher value than the lock TTL.
|
|
func testDynamoDBLockTTL(t *testing.T, ha physical.HABackend) {
|
|
// Set much smaller lock times to speed up the test.
|
|
lockTTL := time.Second * 3
|
|
renewInterval := time.Second * 1
|
|
watchInterval := time.Second * 1
|
|
|
|
// Get the lock
|
|
origLock, err := ha.LockWith("dynamodbttl", "bar")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
// set the first lock renew period to double the expected TTL.
|
|
lock := origLock.(*DynamoDBLock)
|
|
lock.renewInterval = lockTTL * 2
|
|
lock.ttl = lockTTL
|
|
lock.watchRetryInterval = watchInterval
|
|
|
|
// Attempt to lock
|
|
leaderCh, err := lock.Lock(nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if leaderCh == nil {
|
|
t.Fatalf("failed to get leader ch")
|
|
}
|
|
|
|
// Check the value
|
|
held, val, err := lock.Value()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if !held {
|
|
t.Fatalf("should be held")
|
|
}
|
|
if val != "bar" {
|
|
t.Fatalf("bad value: %v", err)
|
|
}
|
|
|
|
// Second acquisition should succeed because the first lock should
|
|
// not renew within the 3 sec TTL.
|
|
origLock2, err := ha.LockWith("dynamodbttl", "baz")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
lock2 := origLock2.(*DynamoDBLock)
|
|
lock2.renewInterval = renewInterval
|
|
lock2.ttl = lockTTL
|
|
lock2.watchRetryInterval = watchInterval
|
|
|
|
// Cancel attempt eventually so as not to block unit tests forever
|
|
stopCh := make(chan struct{})
|
|
time.AfterFunc(lockTTL*10, func() {
|
|
close(stopCh)
|
|
})
|
|
|
|
// Attempt to lock should work
|
|
leaderCh2, err := lock2.Lock(stopCh)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if leaderCh2 == nil {
|
|
t.Fatalf("should get leader ch")
|
|
}
|
|
|
|
// Check the value
|
|
held, val, err = lock2.Value()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if !held {
|
|
t.Fatalf("should be held")
|
|
}
|
|
if val != "baz" {
|
|
t.Fatalf("bad value: %v", err)
|
|
}
|
|
|
|
// The first lock should have lost the leader channel
|
|
leaderChClosed := false
|
|
blocking := make(chan struct{})
|
|
// Attempt to read from the leader or the blocking channel, which ever one
|
|
// happens first.
|
|
go func() {
|
|
select {
|
|
case <-time.After(watchInterval * 3):
|
|
return
|
|
case <-leaderCh:
|
|
leaderChClosed = true
|
|
close(blocking)
|
|
case <-blocking:
|
|
return
|
|
}
|
|
}()
|
|
|
|
<-blocking
|
|
if !leaderChClosed {
|
|
t.Fatalf("original lock did not have its leader channel closed.")
|
|
}
|
|
|
|
// Cleanup
|
|
lock2.Unlock()
|
|
}
|
|
|
|
// Similar to testHABackend, but using internal implementation details to
|
|
// trigger a renewal before a "watch" check, which has been a source of
|
|
// race conditions.
|
|
func testDynamoDBLockRenewal(t *testing.T, ha physical.HABackend) {
|
|
renewInterval := time.Second * 1
|
|
watchInterval := time.Second * 5
|
|
|
|
// Get the lock
|
|
origLock, err := ha.LockWith("dynamodbrenewal", "bar")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
// customize the renewal and watch intervals
|
|
lock := origLock.(*DynamoDBLock)
|
|
lock.renewInterval = renewInterval
|
|
lock.watchRetryInterval = watchInterval
|
|
|
|
// Attempt to lock
|
|
leaderCh, err := lock.Lock(nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if leaderCh == nil {
|
|
t.Fatalf("failed to get leader ch")
|
|
}
|
|
|
|
// Check the value
|
|
held, val, err := lock.Value()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if !held {
|
|
t.Fatalf("should be held")
|
|
}
|
|
if val != "bar" {
|
|
t.Fatalf("bad value: %v", err)
|
|
}
|
|
|
|
// Release the lock, which will delete the stored item
|
|
if err := lock.Unlock(); err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
// Wait longer than the renewal time, but less than the watch time
|
|
time.Sleep(1500 * time.Millisecond)
|
|
|
|
// Attempt to lock with new lock
|
|
newLock, err := ha.LockWith("dynamodbrenewal", "baz")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
// Cancel attempt in 6 sec so as not to block unit tests forever
|
|
stopCh := make(chan struct{})
|
|
time.AfterFunc(6*time.Second, func() {
|
|
close(stopCh)
|
|
})
|
|
|
|
// Attempt to lock should work
|
|
leaderCh2, err := newLock.Lock(stopCh)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if leaderCh2 == nil {
|
|
t.Fatalf("should get leader ch")
|
|
}
|
|
|
|
// Check the value
|
|
held, val, err = newLock.Value()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if !held {
|
|
t.Fatalf("should be held")
|
|
}
|
|
if val != "baz" {
|
|
t.Fatalf("bad value: %v", err)
|
|
}
|
|
|
|
// Cleanup
|
|
newLock.Unlock()
|
|
}
|
|
|
|
type Config struct {
|
|
docker.ServiceURL
|
|
Credentials *credentials.Credentials
|
|
}
|
|
|
|
var _ docker.ServiceConfig = &Config{}
|
|
|
|
func prepareDynamoDBTestContainer(t *testing.T) (func(), *Config) {
|
|
// If environment variable is set, assume caller wants to target a real
|
|
// DynamoDB.
|
|
if endpoint := os.Getenv("AWS_DYNAMODB_ENDPOINT"); endpoint != "" {
|
|
s, err := docker.NewServiceURLParse(endpoint)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return func() {}, &Config{*s, credentials.NewEnvCredentials()}
|
|
}
|
|
|
|
runner, err := docker.NewServiceRunner(docker.RunOptions{
|
|
ImageRepo: "cnadiminti/dynamodb-local",
|
|
ImageTag: "latest",
|
|
ContainerName: "dynamodb",
|
|
Ports: []string{"8000/tcp"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Could not start local DynamoDB: %s", err)
|
|
}
|
|
|
|
svc, err := runner.StartService(context.Background(), connectDynamoDB)
|
|
if err != nil {
|
|
t.Fatalf("Could not start local DynamoDB: %s", err)
|
|
}
|
|
|
|
return svc.Cleanup, svc.Config.(*Config)
|
|
}
|
|
|
|
func connectDynamoDB(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
|
|
u := url.URL{
|
|
Scheme: "http",
|
|
Host: fmt.Sprintf("%s:%d", host, port),
|
|
}
|
|
resp, err := http.Get(u.String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.StatusCode != 400 {
|
|
return nil, err
|
|
}
|
|
|
|
return &Config{
|
|
ServiceURL: *docker.NewServiceURL(u),
|
|
Credentials: credentials.NewStaticCredentials("fake", "fake", ""),
|
|
}, nil
|
|
}
|