52581cd472
Adds debug and warn logging around AWS credential chain generation, specifically to help users debug auto-unseal problems on AWS by logging which role is used when a web identity token is present. Also adds a deferred call to flush the log output, so that logs are still emitted if initialization fails.
327 lines
7.4 KiB
Go
327 lines
7.4 KiB
Go
package s3
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/armon/go-metrics"
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/s3"
|
|
"github.com/hashicorp/errwrap"
|
|
"github.com/hashicorp/go-cleanhttp"
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/sdk/helper/awsutil"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
)
|
|
|
|
// Verify S3Backend satisfies the correct interfaces
|
|
var _ physical.Backend = (*S3Backend)(nil)
|
|
|
|
// S3Backend is a physical backend that stores data
|
|
// within an S3 bucket.
|
|
type S3Backend struct {
|
|
bucket string
|
|
path string
|
|
kmsKeyId string
|
|
client *s3.S3
|
|
logger log.Logger
|
|
permitPool *physical.PermitPool
|
|
}
|
|
|
|
// NewS3Backend constructs a S3 backend using a pre-existing
|
|
// bucket. Credentials can be provided to the backend, sourced
|
|
// from the environment, AWS credential files or by IAM role.
|
|
func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
|
bucket := os.Getenv("AWS_S3_BUCKET")
|
|
if bucket == "" {
|
|
bucket = conf["bucket"]
|
|
if bucket == "" {
|
|
return nil, fmt.Errorf("'bucket' must be set")
|
|
}
|
|
}
|
|
|
|
path := conf["path"]
|
|
|
|
accessKey, ok := conf["access_key"]
|
|
if !ok {
|
|
accessKey = ""
|
|
}
|
|
secretKey, ok := conf["secret_key"]
|
|
if !ok {
|
|
secretKey = ""
|
|
}
|
|
sessionToken, ok := conf["session_token"]
|
|
if !ok {
|
|
sessionToken = ""
|
|
}
|
|
endpoint := os.Getenv("AWS_S3_ENDPOINT")
|
|
if endpoint == "" {
|
|
endpoint = conf["endpoint"]
|
|
}
|
|
region := os.Getenv("AWS_REGION")
|
|
if region == "" {
|
|
region = os.Getenv("AWS_DEFAULT_REGION")
|
|
if region == "" {
|
|
region = conf["region"]
|
|
if region == "" {
|
|
region = "us-east-1"
|
|
}
|
|
}
|
|
}
|
|
s3ForcePathStyleStr, ok := conf["s3_force_path_style"]
|
|
if !ok {
|
|
s3ForcePathStyleStr = "false"
|
|
}
|
|
s3ForcePathStyleBool, err := parseutil.ParseBool(s3ForcePathStyleStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid boolean set for s3_force_path_style: %q", s3ForcePathStyleStr)
|
|
}
|
|
disableSSLStr, ok := conf["disable_ssl"]
|
|
if !ok {
|
|
disableSSLStr = "false"
|
|
}
|
|
disableSSLBool, err := parseutil.ParseBool(disableSSLStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid boolean set for disable_ssl: %q", disableSSLStr)
|
|
}
|
|
|
|
credsConfig := &awsutil.CredentialsConfig{
|
|
AccessKey: accessKey,
|
|
SecretKey: secretKey,
|
|
SessionToken: sessionToken,
|
|
Logger: logger,
|
|
}
|
|
creds, err := credsConfig.GenerateCredentialChain()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pooledTransport := cleanhttp.DefaultPooledTransport()
|
|
pooledTransport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
|
|
|
|
sess, err := session.NewSession(&aws.Config{
|
|
Credentials: creds,
|
|
HTTPClient: &http.Client{
|
|
Transport: pooledTransport,
|
|
},
|
|
Endpoint: aws.String(endpoint),
|
|
Region: aws.String(region),
|
|
S3ForcePathStyle: aws.Bool(s3ForcePathStyleBool),
|
|
DisableSSL: aws.Bool(disableSSLBool),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s3conn := s3.New(sess)
|
|
|
|
_, err = s3conn.ListObjects(&s3.ListObjectsInput{Bucket: &bucket})
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf(fmt.Sprintf("unable to access bucket %q in region %q: {{err}}", bucket, region), err)
|
|
}
|
|
|
|
maxParStr, ok := conf["max_parallel"]
|
|
var maxParInt int
|
|
if ok {
|
|
maxParInt, err = strconv.Atoi(maxParStr)
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
|
}
|
|
}
|
|
|
|
kmsKeyId, ok := conf["kms_key_id"]
|
|
if !ok {
|
|
kmsKeyId = ""
|
|
}
|
|
|
|
s := &S3Backend{
|
|
client: s3conn,
|
|
bucket: bucket,
|
|
path: path,
|
|
kmsKeyId: kmsKeyId,
|
|
logger: logger,
|
|
permitPool: physical.NewPermitPool(maxParInt),
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// Put is used to insert or update an entry
|
|
func (s *S3Backend) Put(ctx context.Context, entry *physical.Entry) error {
|
|
defer metrics.MeasureSince([]string{"s3", "put"}, time.Now())
|
|
|
|
s.permitPool.Acquire()
|
|
defer s.permitPool.Release()
|
|
|
|
// Setup key
|
|
key := path.Join(s.path, entry.Key)
|
|
|
|
putObjectInput := &s3.PutObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
Body: bytes.NewReader(entry.Value),
|
|
}
|
|
|
|
if s.kmsKeyId != "" {
|
|
putObjectInput.ServerSideEncryption = aws.String("aws:kms")
|
|
putObjectInput.SSEKMSKeyId = aws.String(s.kmsKeyId)
|
|
}
|
|
|
|
_, err := s.client.PutObject(putObjectInput)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Get is used to fetch an entry
|
|
func (s *S3Backend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
|
defer metrics.MeasureSince([]string{"s3", "get"}, time.Now())
|
|
|
|
s.permitPool.Acquire()
|
|
defer s.permitPool.Release()
|
|
|
|
// Setup key
|
|
key = path.Join(s.path, key)
|
|
|
|
resp, err := s.client.GetObject(&s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
})
|
|
if resp != nil && resp.Body != nil {
|
|
defer resp.Body.Close()
|
|
}
|
|
if awsErr, ok := err.(awserr.RequestFailure); ok {
|
|
// Return nil on 404s, error on anything else
|
|
if awsErr.StatusCode() == 404 {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp == nil {
|
|
return nil, fmt.Errorf("got nil response from S3 but no error")
|
|
}
|
|
|
|
data := bytes.NewBuffer(nil)
|
|
if resp.ContentLength != nil {
|
|
data = bytes.NewBuffer(make([]byte, 0, *resp.ContentLength))
|
|
}
|
|
_, err = io.Copy(data, resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Strip path prefix
|
|
if s.path != "" {
|
|
key = strings.TrimPrefix(key, s.path+"/")
|
|
}
|
|
|
|
ent := &physical.Entry{
|
|
Key: key,
|
|
Value: data.Bytes(),
|
|
}
|
|
|
|
return ent, nil
|
|
}
|
|
|
|
// Delete is used to permanently delete an entry
|
|
func (s *S3Backend) Delete(ctx context.Context, key string) error {
|
|
defer metrics.MeasureSince([]string{"s3", "delete"}, time.Now())
|
|
|
|
s.permitPool.Acquire()
|
|
defer s.permitPool.Release()
|
|
|
|
// Setup key
|
|
key = path.Join(s.path, key)
|
|
|
|
_, err := s.client.DeleteObject(&s3.DeleteObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// List is used to list all the keys under a given
|
|
// prefix, up to the next prefix.
|
|
func (s *S3Backend) List(ctx context.Context, prefix string) ([]string, error) {
|
|
defer metrics.MeasureSince([]string{"s3", "list"}, time.Now())
|
|
|
|
s.permitPool.Acquire()
|
|
defer s.permitPool.Release()
|
|
|
|
// Setup prefix
|
|
prefix = path.Join(s.path, prefix)
|
|
|
|
// Validate prefix (if present) is ending with a "/"
|
|
if prefix != "" && !strings.HasSuffix(prefix, "/") {
|
|
prefix += "/"
|
|
}
|
|
|
|
params := &s3.ListObjectsV2Input{
|
|
Bucket: aws.String(s.bucket),
|
|
Prefix: aws.String(prefix),
|
|
Delimiter: aws.String("/"),
|
|
}
|
|
|
|
keys := []string{}
|
|
|
|
err := s.client.ListObjectsV2Pages(params,
|
|
func(page *s3.ListObjectsV2Output, lastPage bool) bool {
|
|
if page != nil {
|
|
// Add truncated 'folder' paths
|
|
for _, commonPrefix := range page.CommonPrefixes {
|
|
// Avoid panic
|
|
if commonPrefix == nil {
|
|
continue
|
|
}
|
|
|
|
commonPrefix := strings.TrimPrefix(*commonPrefix.Prefix, prefix)
|
|
keys = append(keys, commonPrefix)
|
|
}
|
|
// Add objects only from the current 'folder'
|
|
for _, key := range page.Contents {
|
|
// Avoid panic
|
|
if key == nil {
|
|
continue
|
|
}
|
|
|
|
key := strings.TrimPrefix(*key.Key, prefix)
|
|
keys = append(keys, key)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
|
|
return keys, nil
|
|
}
|