package dynamodb import ( "context" "fmt" "math" "net/http" "os" pkgPath "path" "sort" "strconv" "strings" "sync" "time" log "github.com/hashicorp/go-hclog" "github.com/armon/go-metrics" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" "github.com/hashicorp/errwrap" cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/awsutil" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/physical" ) const ( // DefaultDynamoDBRegion is used when no region is configured // explicitly. DefaultDynamoDBRegion = "us-east-1" // DefaultDynamoDBTableName is used when no table name // is configured explicitly. DefaultDynamoDBTableName = "vault-dynamodb-backend" // DefaultDynamoDBReadCapacity is the default read capacity // that is used when none is configured explicitly. DefaultDynamoDBReadCapacity = 5 // DefaultDynamoDBWriteCapacity is the default write capacity // that is used when none is configured explicitly. DefaultDynamoDBWriteCapacity = 5 // DynamoDBEmptyPath is the string that is used instead of // empty strings when stored in DynamoDB. DynamoDBEmptyPath = " " // DynamoDBLockPrefix is the prefix used to mark DynamoDB records // as locks. This prefix causes them not to be returned by // List operations. DynamoDBLockPrefix = "_" // The lock TTL matches the default that Consul API uses, 15 seconds. DynamoDBLockTTL = 15 * time.Second // The amount of time to wait between the lock renewals DynamoDBLockRenewInterval = 5 * time.Second // DynamoDBLockRetryInterval is the amount of time to wait // if a lock fails before trying again. DynamoDBLockRetryInterval = time.Second // DynamoDBWatchRetryMax is the number of times to re-try a // failed watch before signaling that leadership is lost. DynamoDBWatchRetryMax = 5 // DynamoDBWatchRetryInterval is the amount of time to wait // if a watch fails before trying again. DynamoDBWatchRetryInterval = 5 * time.Second ) // Verify DynamoDBBackend satisfies the correct interfaces var _ physical.Backend = (*DynamoDBBackend)(nil) var _ physical.HABackend = (*DynamoDBBackend)(nil) var _ physical.Lock = (*DynamoDBLock)(nil) // DynamoDBBackend is a physical backend that stores data in // a DynamoDB table. It can be run in high-availability mode // as DynamoDB has locking capabilities. type DynamoDBBackend struct { table string client *dynamodb.DynamoDB logger log.Logger haEnabled bool permitPool *physical.PermitPool } // DynamoDBRecord is the representation of a vault entry in // DynamoDB. The vault key is split up into two components // (Path and Key) in order to allow more efficient listings. type DynamoDBRecord struct { Path string Key string Value []byte } // DynamoDBLock implements a lock using an DynamoDB client. type DynamoDBLock struct { backend *DynamoDBBackend value, key string identity string held bool lock sync.Mutex // Allow modifying the Lock durations for ease of unit testing. renewInterval time.Duration ttl time.Duration watchRetryInterval time.Duration } type DynamoDBLockRecord struct { Path string Key string Value []byte Identity []byte Expires int64 } // NewDynamoDBBackend constructs a DynamoDB backend. If the // configured DynamoDB table does not exist, it creates it. func NewDynamoDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { table := os.Getenv("AWS_DYNAMODB_TABLE") if table == "" { table = conf["table"] if table == "" { table = DefaultDynamoDBTableName } } readCapacityString := os.Getenv("AWS_DYNAMODB_READ_CAPACITY") if readCapacityString == "" { readCapacityString = conf["read_capacity"] if readCapacityString == "" { readCapacityString = "0" } } readCapacity, err := strconv.Atoi(readCapacityString) if err != nil { return nil, fmt.Errorf("invalid read capacity: %q", readCapacityString) } if readCapacity == 0 { readCapacity = DefaultDynamoDBReadCapacity } writeCapacityString := os.Getenv("AWS_DYNAMODB_WRITE_CAPACITY") if writeCapacityString == "" { writeCapacityString = conf["write_capacity"] if writeCapacityString == "" { writeCapacityString = "0" } } writeCapacity, err := strconv.Atoi(writeCapacityString) if err != nil { return nil, fmt.Errorf("invalid write capacity: %q", writeCapacityString) } if writeCapacity == 0 { writeCapacity = DefaultDynamoDBWriteCapacity } accessKey := os.Getenv("AWS_ACCESS_KEY_ID") if accessKey == "" { accessKey = conf["access_key"] } secretKey := os.Getenv("AWS_SECRET_ACCESS_KEY") if secretKey == "" { secretKey = conf["secret_key"] } sessionToken := os.Getenv("AWS_SESSION_TOKEN") if sessionToken == "" { sessionToken = conf["session_token"] } endpoint := os.Getenv("AWS_DYNAMODB_ENDPOINT") if endpoint == "" { endpoint = conf["endpoint"] } region := os.Getenv("AWS_REGION") if region == "" { region = os.Getenv("AWS_DEFAULT_REGION") if region == "" { region = conf["region"] if region == "" { region = DefaultDynamoDBRegion } } } dynamodbMaxRetryString := os.Getenv("AWS_DYNAMODB_MAX_RETRIES") if dynamodbMaxRetryString == "" { dynamodbMaxRetryString = conf["dynamodb_max_retries"] } var dynamodbMaxRetry int = aws.UseServiceDefaultRetries if dynamodbMaxRetryString != "" { var err error dynamodbMaxRetry, err = strconv.Atoi(dynamodbMaxRetryString) if err != nil { return nil, fmt.Errorf("invalid max retry: %q", dynamodbMaxRetryString) } } credsConfig := &awsutil.CredentialsConfig{ AccessKey: accessKey, SecretKey: secretKey, SessionToken: sessionToken, } creds, err := credsConfig.GenerateCredentialChain() if err != nil { return nil, err } pooledTransport := cleanhttp.DefaultPooledTransport() pooledTransport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount awsConf := aws.NewConfig(). WithCredentials(creds). WithRegion(region). WithEndpoint(endpoint). WithHTTPClient(&http.Client{ Transport: pooledTransport, }). WithMaxRetries(dynamodbMaxRetry) awsSession, err := session.NewSession(awsConf) if err != nil { return nil, errwrap.Wrapf("Could not establish AWS session: {{err}}", err) } client := dynamodb.New(awsSession) if err := ensureTableExists(client, table, readCapacity, writeCapacity); err != nil { return nil, err } haEnabled := os.Getenv("DYNAMODB_HA_ENABLED") if haEnabled == "" { haEnabled = conf["ha_enabled"] } haEnabledBool, _ := strconv.ParseBool(haEnabled) maxParStr, ok := conf["max_parallel"] var maxParInt int if ok { maxParInt, err = strconv.Atoi(maxParStr) if err != nil { return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) } if logger.IsDebug() { logger.Debug("physical/dynamodb: max_parallel set", "max_parallel", maxParInt) } } return &DynamoDBBackend{ table: table, client: client, permitPool: physical.NewPermitPool(maxParInt), haEnabled: haEnabledBool, logger: logger, }, nil } // Put is used to insert or update an entry func (d *DynamoDBBackend) Put(ctx context.Context, entry *physical.Entry) error { defer metrics.MeasureSince([]string{"dynamodb", "put"}, time.Now()) record := DynamoDBRecord{ Path: recordPathForVaultKey(entry.Key), Key: recordKeyForVaultKey(entry.Key), Value: entry.Value, } item, err := dynamodbattribute.MarshalMap(record) if err != nil { return errwrap.Wrapf("could not convert prefix record to DynamoDB item: {{err}}", err) } requests := []*dynamodb.WriteRequest{{ PutRequest: &dynamodb.PutRequest{ Item: item, }, }} for _, prefix := range physical.Prefixes(entry.Key) { record = DynamoDBRecord{ Path: recordPathForVaultKey(prefix), Key: fmt.Sprintf("%s/", recordKeyForVaultKey(prefix)), } item, err := dynamodbattribute.MarshalMap(record) if err != nil { return errwrap.Wrapf("could not convert prefix record to DynamoDB item: {{err}}", err) } requests = append(requests, &dynamodb.WriteRequest{ PutRequest: &dynamodb.PutRequest{ Item: item, }, }) } return d.batchWriteRequests(requests) } // Get is used to fetch an entry func (d *DynamoDBBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"dynamodb", "get"}, time.Now()) d.permitPool.Acquire() defer d.permitPool.Release() resp, err := d.client.GetItem(&dynamodb.GetItemInput{ TableName: aws.String(d.table), ConsistentRead: aws.Bool(true), Key: map[string]*dynamodb.AttributeValue{ "Path": {S: aws.String(recordPathForVaultKey(key))}, "Key": {S: aws.String(recordKeyForVaultKey(key))}, }, }) if err != nil { return nil, err } if resp.Item == nil { return nil, nil } record := &DynamoDBRecord{} if err := dynamodbattribute.UnmarshalMap(resp.Item, record); err != nil { return nil, err } return &physical.Entry{ Key: vaultKey(record), Value: record.Value, }, nil } // Delete is used to permanently delete an entry func (d *DynamoDBBackend) Delete(ctx context.Context, key string) error { defer metrics.MeasureSince([]string{"dynamodb", "delete"}, time.Now()) requests := []*dynamodb.WriteRequest{{ DeleteRequest: &dynamodb.DeleteRequest{ Key: map[string]*dynamodb.AttributeValue{ "Path": {S: aws.String(recordPathForVaultKey(key))}, "Key": {S: aws.String(recordKeyForVaultKey(key))}, }, }, }} // Clean up empty "folders" by looping through all levels of the path to the item being deleted looking for // children. Loop from deepest path to shallowest, and only consider items children if they are not going to be // deleted by our batch delete request. If a path has no valid children, then it should be considered an empty // "folder" and be deleted along with the original item in our batch job. Because we loop from deepest path to // shallowest, once we find a path level that contains valid children we can stop the cleanup operation. prefixes := physical.Prefixes(key) sort.Sort(sort.Reverse(sort.StringSlice(prefixes))) for index, prefix := range prefixes { // Because delete batches its requests, we need to pass keys we know are going to be deleted into // hasChildren so it can exclude those when it determines if there WILL be any children left after // the delete operations have completed. var excluded []string if index == 0 { // This is the value we know for sure is being deleted excluded = append(excluded, recordKeyForVaultKey(key)) } else { // The previous path doesn't count as a child, since if we're still looping, we've found no children excluded = append(excluded, recordKeyForVaultKey(prefixes[index-1])) } hasChildren, err := d.hasChildren(prefix, excluded) if err != nil { return err } if !hasChildren { // If there are no children other than ones we know are being deleted then cleanup empty "folder" pointers requests = append(requests, &dynamodb.WriteRequest{ DeleteRequest: &dynamodb.DeleteRequest{ Key: map[string]*dynamodb.AttributeValue{ "Path": {S: aws.String(recordPathForVaultKey(prefix))}, "Key": {S: aws.String(fmt.Sprintf("%s/", recordKeyForVaultKey(prefix)))}, }, }, }) } else { // This loop starts at the deepest path and works backwards looking for children // once a deeper level of the path has been found to have children there is no // more cleanup that needs to happen, otherwise we might remove folder pointers // to that deeper path making it "undiscoverable" with the list operation break } } return d.batchWriteRequests(requests) } // List is used to list all the keys under a given // prefix, up to the next prefix. func (d *DynamoDBBackend) List(ctx context.Context, prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"dynamodb", "list"}, time.Now()) prefix = strings.TrimSuffix(prefix, "/") keys := []string{} prefix = escapeEmptyPath(prefix) queryInput := &dynamodb.QueryInput{ TableName: aws.String(d.table), ConsistentRead: aws.Bool(true), KeyConditions: map[string]*dynamodb.Condition{ "Path": { ComparisonOperator: aws.String("EQ"), AttributeValueList: []*dynamodb.AttributeValue{{ S: aws.String(prefix), }}, }, }, } d.permitPool.Acquire() defer d.permitPool.Release() err := d.client.QueryPages(queryInput, func(out *dynamodb.QueryOutput, lastPage bool) bool { var record DynamoDBRecord for _, item := range out.Items { dynamodbattribute.UnmarshalMap(item, &record) if !strings.HasPrefix(record.Key, DynamoDBLockPrefix) { keys = append(keys, record.Key) } } return !lastPage }) if err != nil { return nil, err } return keys, nil } // hasChildren returns true if there exist items below a certain path prefix. // To do so, the method fetches such items from DynamoDB. This method is primarily // used by Delete. Because DynamoDB requests are batched this method is being called // before any deletes take place. To account for that hasChildren accepts a slice of // strings representing values we expect to find that should NOT be counted as children // because they are going to be deleted. func (d *DynamoDBBackend) hasChildren(prefix string, exclude []string) (bool, error) { prefix = strings.TrimSuffix(prefix, "/") prefix = escapeEmptyPath(prefix) queryInput := &dynamodb.QueryInput{ TableName: aws.String(d.table), ConsistentRead: aws.Bool(true), KeyConditions: map[string]*dynamodb.Condition{ "Path": { ComparisonOperator: aws.String("EQ"), AttributeValueList: []*dynamodb.AttributeValue{{ S: aws.String(prefix), }}, }, }, // Avoid fetching too many items from DynamoDB for performance reasons. // We want to know if there are any children we don't expect to see. // Answering that question requires fetching a minimum of one more item // than the number we expect. In most cases this value will be 2 Limit: aws.Int64(int64(len(exclude) + 1)), } d.permitPool.Acquire() defer d.permitPool.Release() out, err := d.client.Query(queryInput) if err != nil { return false, err } var childrenExist bool for _, item := range out.Items { for _, excluded := range exclude { // Check if we've found an item we didn't expect to. Look for "folder" pointer keys (trailing slash) // and regular value keys (no trailing slash) if *item["Key"].S != excluded && *item["Key"].S != fmt.Sprintf("%s/", excluded) { childrenExist = true break } } if childrenExist { // We only need to find ONE child we didn't expect to. break } } return childrenExist, nil } // LockWith is used for mutual exclusion based on the given key. func (d *DynamoDBBackend) LockWith(key, value string) (physical.Lock, error) { identity, err := uuid.GenerateUUID() if err != nil { return nil, err } return &DynamoDBLock{ backend: d, key: pkgPath.Join(pkgPath.Dir(key), DynamoDBLockPrefix+pkgPath.Base(key)), value: value, identity: identity, renewInterval: DynamoDBLockRenewInterval, ttl: DynamoDBLockTTL, watchRetryInterval: DynamoDBWatchRetryInterval, }, nil } func (d *DynamoDBBackend) HAEnabled() bool { return d.haEnabled } // batchWriteRequests takes a list of write requests and executes them in badges // with a maximum size of 25 (which is the limit of BatchWriteItem requests). func (d *DynamoDBBackend) batchWriteRequests(requests []*dynamodb.WriteRequest) error { for len(requests) > 0 { batchSize := int(math.Min(float64(len(requests)), 25)) batch := requests[:batchSize] requests = requests[batchSize:] d.permitPool.Acquire() _, err := d.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{ RequestItems: map[string][]*dynamodb.WriteRequest{ d.table: batch, }, }) d.permitPool.Release() if err != nil { return err } } return nil } // Lock tries to acquire the lock by repeatedly trying to create // a record in the DynamoDB table. It will block until either the // stop channel is closed or the lock could be acquired successfully. // The returned channel will be closed once the lock is deleted or // changed in the DynamoDB table. func (l *DynamoDBLock) Lock(stopCh <-chan struct{}) (doneCh <-chan struct{}, retErr error) { l.lock.Lock() defer l.lock.Unlock() if l.held { return nil, fmt.Errorf("lock already held") } done := make(chan struct{}) // close done channel even in case of error defer func() { if retErr != nil { close(done) } }() var ( stop = make(chan struct{}) success = make(chan struct{}) errors = make(chan error) leader = make(chan struct{}) ) // try to acquire the lock asynchronously go l.tryToLock(stop, success, errors) select { case <-success: l.held = true // after acquiring it successfully, we must renew the lock periodically, // and watch the lock in order to close the leader channel // once it is lost. go l.periodicallyRenewLock(leader) go l.watch(leader) case retErr = <-errors: close(stop) return nil, retErr case <-stopCh: close(stop) return nil, nil } return leader, retErr } // Unlock releases the lock by deleting the lock record from the // DynamoDB table. func (l *DynamoDBLock) Unlock() error { l.lock.Lock() defer l.lock.Unlock() if !l.held { return nil } l.held = false if err := l.backend.Delete(context.Background(), l.key); err != nil { return err } return nil } // Value checks whether or not the lock is held by any instance of DynamoDBLock, // including this one, and returns the current value. func (l *DynamoDBLock) Value() (bool, string, error) { entry, err := l.backend.Get(context.Background(), l.key) if err != nil { return false, "", err } if entry == nil { return false, "", nil } return true, string(entry.Value), nil } // tryToLock tries to create a new item in DynamoDB // every `DynamoDBLockRetryInterval`. As long as the item // cannot be created (because it already exists), it will // be retried. If the operation fails due to an error, it // is sent to the errors channel. // When the lock could be acquired successfully, the success // channel is closed. func (l *DynamoDBLock) tryToLock(stop, success chan struct{}, errors chan error) { ticker := time.NewTicker(DynamoDBLockRetryInterval) for { select { case <-stop: ticker.Stop() case <-ticker.C: err := l.writeItem() if err != nil { if err, ok := err.(awserr.Error); ok { // Don't report a condition check failure, this means that the lock // is already being held. if err.Code() != dynamodb.ErrCodeConditionalCheckFailedException { errors <- err } } else { // Its not an AWS error, and is probably not transient, bail out. errors <- err return } } else { ticker.Stop() close(success) return } } } } func (l *DynamoDBLock) periodicallyRenewLock(done chan struct{}) { ticker := time.NewTicker(l.renewInterval) for { select { case <-ticker.C: l.writeItem() case <-done: ticker.Stop() return } } } // Attempts to put/update the dynamodb item using condition expressions to // evaluate the TTL. func (l *DynamoDBLock) writeItem() error { now := time.Now() _, err := l.backend.client.UpdateItem(&dynamodb.UpdateItemInput{ TableName: aws.String(l.backend.table), Key: map[string]*dynamodb.AttributeValue{ "Path": &dynamodb.AttributeValue{S: aws.String(recordPathForVaultKey(l.key))}, "Key": &dynamodb.AttributeValue{S: aws.String(recordKeyForVaultKey(l.key))}, }, UpdateExpression: aws.String("SET #value=:value, #identity=:identity, #expires=:expires"), // If both key and path already exist, we can only write if // A. identity is equal to our identity (or the identity doesn't exist) // or // B. The ttl on the item is <= to the current time ConditionExpression: aws.String( "attribute_not_exists(#path) or " + "attribute_not_exists(#key) or " + // To work when upgrading from older versions that did not include the // Identity attribute, we first check if the attr doesn't exist, and if // it does, then we check if the identity is equal to our own. "(attribute_not_exists(#identity) or #identity = :identity) or " + "#expires <= :now", ), ExpressionAttributeNames: map[string]*string{ "#path": aws.String("Path"), "#key": aws.String("Key"), "#identity": aws.String("Identity"), "#expires": aws.String("Expires"), "#value": aws.String("Value"), }, ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ ":identity": &dynamodb.AttributeValue{B: []byte(l.identity)}, ":value": &dynamodb.AttributeValue{B: []byte(l.value)}, ":now": &dynamodb.AttributeValue{N: aws.String(strconv.FormatInt(now.UnixNano(), 10))}, ":expires": &dynamodb.AttributeValue{N: aws.String(strconv.FormatInt(now.Add(l.ttl).UnixNano(), 10))}, }, }) return err } // watch checks whether the lock has changed in the // DynamoDB table and closes the leader channel if so. // The interval is set by `DynamoDBWatchRetryInterval`. // If an error occurs during the check, watch will retry // the operation for `DynamoDBWatchRetryMax` times and // close the leader channel if it can't succeed. func (l *DynamoDBLock) watch(lost chan struct{}) { retries := DynamoDBWatchRetryMax ticker := time.NewTicker(l.watchRetryInterval) WatchLoop: for { select { case <-ticker.C: resp, err := l.backend.client.GetItem(&dynamodb.GetItemInput{ TableName: aws.String(l.backend.table), ConsistentRead: aws.Bool(true), Key: map[string]*dynamodb.AttributeValue{ "Path": {S: aws.String(recordPathForVaultKey(l.key))}, "Key": {S: aws.String(recordKeyForVaultKey(l.key))}, }, }) if err != nil { retries-- if retries == 0 { break WatchLoop } continue } if resp == nil { break WatchLoop } record := &DynamoDBLockRecord{} err = dynamodbattribute.UnmarshalMap(resp.Item, record) if err != nil || string(record.Identity) != l.identity { break WatchLoop } } } close(lost) } // ensureTableExists creates a DynamoDB table with a given // DynamoDB client. If the table already exists, it is not // being reconfigured. func ensureTableExists(client *dynamodb.DynamoDB, table string, readCapacity, writeCapacity int) error { _, err := client.DescribeTable(&dynamodb.DescribeTableInput{ TableName: aws.String(table), }) if awsError, ok := err.(awserr.Error); ok { if awsError.Code() == "ResourceNotFoundException" { _, err = client.CreateTable(&dynamodb.CreateTableInput{ TableName: aws.String(table), ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ ReadCapacityUnits: aws.Int64(int64(readCapacity)), WriteCapacityUnits: aws.Int64(int64(writeCapacity)), }, KeySchema: []*dynamodb.KeySchemaElement{{ AttributeName: aws.String("Path"), KeyType: aws.String("HASH"), }, { AttributeName: aws.String("Key"), KeyType: aws.String("RANGE"), }}, AttributeDefinitions: []*dynamodb.AttributeDefinition{{ AttributeName: aws.String("Path"), AttributeType: aws.String("S"), }, { AttributeName: aws.String("Key"), AttributeType: aws.String("S"), }}, }) if err != nil { return err } err = client.WaitUntilTableExists(&dynamodb.DescribeTableInput{ TableName: aws.String(table), }) if err != nil { return err } } } if err != nil { return err } return nil } // recordPathForVaultKey transforms a vault key into // a value suitable for the `DynamoDBRecord`'s `Path` // property. This path equals the the vault key without // its last component. func recordPathForVaultKey(key string) string { if strings.Contains(key, "/") { return pkgPath.Dir(key) } return DynamoDBEmptyPath } // recordKeyForVaultKey transforms a vault key into // a value suitable for the `DynamoDBRecord`'s `Key` // property. This path equals the the vault key's // last component. func recordKeyForVaultKey(key string) string { return pkgPath.Base(key) } // vaultKey returns the vault key for a given record // from the DynamoDB table. This is the combination of // the records Path and Key. func vaultKey(record *DynamoDBRecord) string { path := unescapeEmptyPath(record.Path) if path == "" { return record.Key } return pkgPath.Join(record.Path, record.Key) } // escapeEmptyPath is used to escape the root key's path // with a value that can be stored in DynamoDB. DynamoDB // does not allow values to be empty strings. func escapeEmptyPath(s string) string { if s == "" { return DynamoDBEmptyPath } return s } // unescapeEmptyPath is the opposite of `escapeEmptyPath`. func unescapeEmptyPath(s string) string { if s == DynamoDBEmptyPath { return "" } return s }