303c2aee7c
* Update tooling * Run gofumpt * go mod vendor
324 lines
7.4 KiB
Go
324 lines
7.4 KiB
Go
package s3
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/armon/go-metrics"
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/s3"
|
|
"github.com/hashicorp/errwrap"
|
|
"github.com/hashicorp/go-cleanhttp"
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/sdk/helper/awsutil"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
)
|
|
|
|
// Verify S3Backend satisfies the correct interfaces
|
|
var _ physical.Backend = (*S3Backend)(nil)
|
|
|
|
// S3Backend is a physical backend that stores data
|
|
// within an S3 bucket.
|
|
type S3Backend struct {
|
|
bucket string
|
|
path string
|
|
kmsKeyId string
|
|
client *s3.S3
|
|
logger log.Logger
|
|
permitPool *physical.PermitPool
|
|
}
|
|
|
|
// NewS3Backend constructs a S3 backend using a pre-existing
|
|
// bucket. Credentials can be provided to the backend, sourced
|
|
// from the environment, AWS credential files or by IAM role.
|
|
func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
|
bucket := os.Getenv("AWS_S3_BUCKET")
|
|
if bucket == "" {
|
|
bucket = conf["bucket"]
|
|
if bucket == "" {
|
|
return nil, fmt.Errorf("'bucket' must be set")
|
|
}
|
|
}
|
|
|
|
path := conf["path"]
|
|
|
|
accessKey, ok := conf["access_key"]
|
|
if !ok {
|
|
accessKey = ""
|
|
}
|
|
secretKey, ok := conf["secret_key"]
|
|
if !ok {
|
|
secretKey = ""
|
|
}
|
|
sessionToken, ok := conf["session_token"]
|
|
if !ok {
|
|
sessionToken = ""
|
|
}
|
|
endpoint := os.Getenv("AWS_S3_ENDPOINT")
|
|
if endpoint == "" {
|
|
endpoint = conf["endpoint"]
|
|
}
|
|
region := os.Getenv("AWS_REGION")
|
|
if region == "" {
|
|
region = os.Getenv("AWS_DEFAULT_REGION")
|
|
if region == "" {
|
|
region = conf["region"]
|
|
if region == "" {
|
|
region = "us-east-1"
|
|
}
|
|
}
|
|
}
|
|
s3ForcePathStyleStr, ok := conf["s3_force_path_style"]
|
|
if !ok {
|
|
s3ForcePathStyleStr = "false"
|
|
}
|
|
s3ForcePathStyleBool, err := parseutil.ParseBool(s3ForcePathStyleStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid boolean set for s3_force_path_style: %q", s3ForcePathStyleStr)
|
|
}
|
|
disableSSLStr, ok := conf["disable_ssl"]
|
|
if !ok {
|
|
disableSSLStr = "false"
|
|
}
|
|
disableSSLBool, err := parseutil.ParseBool(disableSSLStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid boolean set for disable_ssl: %q", disableSSLStr)
|
|
}
|
|
|
|
credsConfig := &awsutil.CredentialsConfig{
|
|
AccessKey: accessKey,
|
|
SecretKey: secretKey,
|
|
SessionToken: sessionToken,
|
|
Logger: logger,
|
|
}
|
|
creds, err := credsConfig.GenerateCredentialChain()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pooledTransport := cleanhttp.DefaultPooledTransport()
|
|
pooledTransport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
|
|
|
|
sess, err := session.NewSession(&aws.Config{
|
|
Credentials: creds,
|
|
HTTPClient: &http.Client{
|
|
Transport: pooledTransport,
|
|
},
|
|
Endpoint: aws.String(endpoint),
|
|
Region: aws.String(region),
|
|
S3ForcePathStyle: aws.Bool(s3ForcePathStyleBool),
|
|
DisableSSL: aws.Bool(disableSSLBool),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s3conn := s3.New(sess)
|
|
|
|
_, err = s3conn.ListObjects(&s3.ListObjectsInput{Bucket: &bucket})
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf(fmt.Sprintf("unable to access bucket %q in region %q: {{err}}", bucket, region), err)
|
|
}
|
|
|
|
maxParStr, ok := conf["max_parallel"]
|
|
var maxParInt int
|
|
if ok {
|
|
maxParInt, err = strconv.Atoi(maxParStr)
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
|
}
|
|
}
|
|
|
|
kmsKeyId, ok := conf["kms_key_id"]
|
|
if !ok {
|
|
kmsKeyId = ""
|
|
}
|
|
|
|
s := &S3Backend{
|
|
client: s3conn,
|
|
bucket: bucket,
|
|
path: path,
|
|
kmsKeyId: kmsKeyId,
|
|
logger: logger,
|
|
permitPool: physical.NewPermitPool(maxParInt),
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// Put is used to insert or update an entry
|
|
func (s *S3Backend) Put(ctx context.Context, entry *physical.Entry) error {
|
|
defer metrics.MeasureSince([]string{"s3", "put"}, time.Now())
|
|
|
|
s.permitPool.Acquire()
|
|
defer s.permitPool.Release()
|
|
|
|
// Setup key
|
|
key := path.Join(s.path, entry.Key)
|
|
|
|
putObjectInput := &s3.PutObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
Body: bytes.NewReader(entry.Value),
|
|
}
|
|
|
|
if s.kmsKeyId != "" {
|
|
putObjectInput.ServerSideEncryption = aws.String("aws:kms")
|
|
putObjectInput.SSEKMSKeyId = aws.String(s.kmsKeyId)
|
|
}
|
|
|
|
_, err := s.client.PutObject(putObjectInput)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Get is used to fetch an entry
|
|
func (s *S3Backend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
|
defer metrics.MeasureSince([]string{"s3", "get"}, time.Now())
|
|
|
|
s.permitPool.Acquire()
|
|
defer s.permitPool.Release()
|
|
|
|
// Setup key
|
|
key = path.Join(s.path, key)
|
|
|
|
resp, err := s.client.GetObject(&s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
})
|
|
if resp != nil && resp.Body != nil {
|
|
defer resp.Body.Close()
|
|
}
|
|
if awsErr, ok := err.(awserr.RequestFailure); ok {
|
|
// Return nil on 404s, error on anything else
|
|
if awsErr.StatusCode() == 404 {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp == nil {
|
|
return nil, fmt.Errorf("got nil response from S3 but no error")
|
|
}
|
|
|
|
data := bytes.NewBuffer(nil)
|
|
if resp.ContentLength != nil {
|
|
data = bytes.NewBuffer(make([]byte, 0, *resp.ContentLength))
|
|
}
|
|
_, err = io.Copy(data, resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Strip path prefix
|
|
if s.path != "" {
|
|
key = strings.TrimPrefix(key, s.path+"/")
|
|
}
|
|
|
|
ent := &physical.Entry{
|
|
Key: key,
|
|
Value: data.Bytes(),
|
|
}
|
|
|
|
return ent, nil
|
|
}
|
|
|
|
// Delete is used to permanently delete an entry
|
|
func (s *S3Backend) Delete(ctx context.Context, key string) error {
|
|
defer metrics.MeasureSince([]string{"s3", "delete"}, time.Now())
|
|
|
|
s.permitPool.Acquire()
|
|
defer s.permitPool.Release()
|
|
|
|
// Setup key
|
|
key = path.Join(s.path, key)
|
|
|
|
_, err := s.client.DeleteObject(&s3.DeleteObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(key),
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// List is used to list all the keys under a given
|
|
// prefix, up to the next prefix.
|
|
func (s *S3Backend) List(ctx context.Context, prefix string) ([]string, error) {
|
|
defer metrics.MeasureSince([]string{"s3", "list"}, time.Now())
|
|
|
|
s.permitPool.Acquire()
|
|
defer s.permitPool.Release()
|
|
|
|
// Setup prefix
|
|
prefix = path.Join(s.path, prefix)
|
|
|
|
// Validate prefix (if present) is ending with a "/"
|
|
if prefix != "" && !strings.HasSuffix(prefix, "/") {
|
|
prefix += "/"
|
|
}
|
|
|
|
params := &s3.ListObjectsV2Input{
|
|
Bucket: aws.String(s.bucket),
|
|
Prefix: aws.String(prefix),
|
|
Delimiter: aws.String("/"),
|
|
}
|
|
|
|
keys := []string{}
|
|
|
|
err := s.client.ListObjectsV2Pages(params,
|
|
func(page *s3.ListObjectsV2Output, lastPage bool) bool {
|
|
if page != nil {
|
|
// Add truncated 'folder' paths
|
|
for _, commonPrefix := range page.CommonPrefixes {
|
|
// Avoid panic
|
|
if commonPrefix == nil {
|
|
continue
|
|
}
|
|
|
|
commonPrefix := strings.TrimPrefix(*commonPrefix.Prefix, prefix)
|
|
keys = append(keys, commonPrefix)
|
|
}
|
|
// Add objects only from the current 'folder'
|
|
for _, key := range page.Contents {
|
|
// Avoid panic
|
|
if key == nil {
|
|
continue
|
|
}
|
|
|
|
key := strings.TrimPrefix(*key.Key, prefix)
|
|
keys = append(keys, key)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
|
|
return keys, nil
|
|
}
|