logical/aws

This commit is contained in:
Mitchell Hashimoto 2015-03-20 17:59:48 +01:00
parent a0f59f682b
commit 62d9bec8be
12 changed files with 534 additions and 22 deletions

View File

@ -17,6 +17,14 @@ test: generate
TF_ACC= go test $(TEST) $(TESTARGS) -timeout=30s -parallel=4
@$(MAKE) vet
# testacc runs acceptance tests
testacc: generate
@if [ "$(TEST)" = "./..." ]; then \
echo "ERROR: Set TEST to a specific package"; \
exit 1; \
fi
TF_ACC=1 go test $(TEST) -v $(TESTARGS) -timeout 45m
# testrace runs the race checker
testrace: generate
TF_ACC= go test -race $(TEST) $(TESTARGS)

View File

@ -0,0 +1,31 @@
package aws
import (
"github.com/hashicorp/vault/logical/framework"
)
func Backend() *framework.Backend {
var b backend
b.Backend = &framework.Backend{
PathsRoot: []string{
"root",
"policy/*",
},
Paths: []*framework.Path{
pathRoot(),
pathPolicy(),
pathUser(&b),
},
Secrets: []*framework.Secret{
secretAccessKeys(),
},
}
return b.Backend
}
type backend struct {
*framework.Backend
}

View File

@ -0,0 +1,115 @@
package aws
import (
"encoding/base64"
"log"
"os"
"testing"
"time"
"github.com/hashicorp/aws-sdk-go/aws"
"github.com/hashicorp/aws-sdk-go/gen/ec2"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/mitchellh/mapstructure"
)
func TestBackend_basic(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Backend: Backend(),
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWritePolicy(t, "test", testPolicy),
testAccStepReadUser(t, "test"),
},
})
}
func testAccPreCheck(t *testing.T) {
if v := os.Getenv("AWS_ACCESS_KEY_ID"); v == "" {
t.Fatal("AWS_ACCESS_KEY_ID must be set for acceptance tests")
}
if v := os.Getenv("AWS_SECRET_ACCESS_KEY"); v == "" {
t.Fatal("AWS_SECRET_ACCESS_KEY must be set for acceptance tests")
}
if v := os.Getenv("AWS_DEFAULT_REGION"); v == "" {
log.Println("[INFO] Test: Using us-west-2 as test region")
os.Setenv("AWS_DEFAULT_REGION", "us-west-2")
}
}
func testAccStepConfig(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "root",
Data: map[string]interface{}{
"access_key": os.Getenv("AWS_ACCESS_KEY_ID"),
"secret_key": os.Getenv("AWS_SECRET_ACCESS_KEY"),
"region": os.Getenv("AWS_DEFAULT_REGION"),
},
}
}
func testAccStepReadUser(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: name,
Check: func(resp *logical.Response) error {
var d struct {
AccessKey string `mapstructure:"access_key"`
SecretKey string `mapstructure:"secret_key"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)
// Sleep sometime because AWS is eventually consistent
log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...")
time.Sleep(10 * time.Second)
// Build a client and verify that the credentials work
creds := aws.Creds(d.AccessKey, d.SecretKey, "")
client := ec2.New(creds, "us-east-1", nil)
log.Printf("[WARN] Verifying that the generated credentials work...")
_, err := client.DescribeInstances(&ec2.DescribeInstancesRequest{})
if err != nil {
return err
}
return nil
},
}
}
func testAccStepWritePolicy(t *testing.T, name string, policy string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "policy/" + name,
Data: map[string]interface{}{
"policy": base64.StdEncoding.EncodeToString([]byte(policy)),
},
}
}
const testPolicy = `
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "Stmt1426528957000",
"Effect": "Allow",
"Action": [
"ec2:*"
],
"Resource": [
"*"
]
}
]
}
`

View File

@ -0,0 +1,29 @@
package aws
import (
"fmt"
"github.com/hashicorp/aws-sdk-go/aws"
"github.com/hashicorp/aws-sdk-go/gen/iam"
"github.com/hashicorp/vault/logical"
)
func clientIAM(s logical.Storage) (*iam.IAM, error) {
entry, err := s.Get("root")
if err != nil {
return nil, err
}
if entry == nil {
return nil, fmt.Errorf(
"root credentials haven't been configured. Please configure\n" +
"them at the '/root' endpoint")
}
var config rootConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, fmt.Errorf("error reading root configuration: %s", err)
}
creds := aws.Creds(config.AccessKey, config.SecretKey, "")
return iam.New(creds, config.Region, nil), nil
}

View File

@ -0,0 +1,59 @@
package aws
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathPolicy() *framework.Path {
return &framework.Path{
Pattern: `policy/(?P<name>\w+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the policy",
},
"policy": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Policy document, base64 encoded.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.WriteOperation: pathPolicyWrite,
},
}
}
func pathPolicyWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Decode and compact the policy. AWS requires a JSON-compacted policy
// because it mustn't contain newlines.
var policyBuf bytes.Buffer
policyRaw, err := base64.StdEncoding.DecodeString(d.Get("policy").(string))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error decoding policy base64: %s", err)), nil
}
if err := json.Compact(&policyBuf, []byte(policyRaw)); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error compacting policy: %s", err)), nil
}
// Write the policy into storage
err = req.Storage.Put(&logical.StorageEntry{
Key: "policy/" + d.Get("name").(string),
Value: policyBuf.Bytes(),
})
if err != nil {
return nil, err
}
return nil, nil
}

View File

@ -0,0 +1,56 @@
package aws
import (
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathRoot() *framework.Path {
return &framework.Path{
Pattern: "root",
Fields: map[string]*framework.FieldSchema{
"access_key": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Access key with permission to create new keys.",
},
"secret_key": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Secret key with permission to create new keys.",
},
"region": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Region for API calls.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.WriteOperation: pathRootWrite,
},
}
}
func pathRootWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := logical.StorageEntryJSON("root", rootConfig{
AccessKey: data.Get("access_key").(string),
SecretKey: data.Get("secret_key").(string),
Region: data.Get("region").(string),
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
type rootConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
Region string `json:"region"`
}

View File

@ -0,0 +1,85 @@
package aws
import (
"fmt"
"math/rand"
"time"
"github.com/hashicorp/aws-sdk-go/aws"
"github.com/hashicorp/aws-sdk-go/gen/iam"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathUser(b *backend) *framework.Path {
return &framework.Path{
Pattern: `(?P<name>\w+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the policy",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathUserRead,
},
}
}
func (b *backend) pathUserRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
client, err := clientIAM(req.Storage)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Read the policy
policy, err := req.Storage.Get("policy/" + d.Get("name").(string))
if err != nil {
return nil, fmt.Errorf("error retrieving policy: %s", err)
}
if policy == nil {
return logical.ErrorResponse(fmt.Sprintf(
"Policy '%s' not found", d.Get("name").(string))), nil
}
// Generate a random username. We don't put the policy names in the
// username because the AWS console makes it pretty easy to see that.
username := fmt.Sprintf("vault-%d-%d", time.Now().Unix(), rand.Int31n(10000))
_, err = client.CreateUser(&iam.CreateUserRequest{
UserName: aws.String(username),
})
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error creating IAM user: %s", err)), nil
}
// Add the user to all the groups
err = client.PutUserPolicy(&iam.PutUserPolicyRequest{
UserName: aws.String(username),
PolicyName: aws.String(d.Get("name").(string)),
PolicyDocument: aws.String(string(policy.Value)),
})
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error adding user to group: %s", err)), nil
}
// Create the keys
keyResp, err := client.CreateAccessKey(&iam.CreateAccessKeyRequest{
UserName: aws.String(username),
})
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error creating access keys: %s", err)), nil
}
// Return the info!
return b.Secret(SecretAccessKeyType).Response(map[string]interface{}{
"access_key": *keyResp.AccessKey.AccessKeyID,
"secret_key": *keyResp.AccessKey.SecretAccessKey,
}, map[string]interface{}{
"username": username,
}), nil
}

View File

@ -0,0 +1,110 @@
package aws
import (
"fmt"
"github.com/hashicorp/aws-sdk-go/aws"
"github.com/hashicorp/aws-sdk-go/gen/iam"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
const SecretAccessKeyType = "access_keys"
func secretAccessKeys() *framework.Secret {
return &framework.Secret{
Type: SecretAccessKeyType,
Fields: map[string]*framework.FieldSchema{
"access_key": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Access Key",
},
"secret_key": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Secret Key",
},
},
Revoke: secretAccessKeysRevoke,
}
}
func secretAccessKeysRevoke(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
// Get the client
client, err := clientIAM(req.Storage)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Get information about this user
groupsResp, err := client.ListGroupsForUser(&iam.ListGroupsForUserRequest{
UserName: aws.String(username),
MaxItems: aws.Integer(1000),
})
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
groups := groupsResp.Groups
policiesResp, err := client.ListUserPolicies(&iam.ListUserPoliciesRequest{
UserName: aws.String(username),
MaxItems: aws.Integer(1000),
})
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
policies := policiesResp.PolicyNames
// Revoke it!
err = client.DeleteAccessKey(&iam.DeleteAccessKeyRequest{
AccessKeyID: aws.String(d.Get("access_key").(string)),
UserName: aws.String(username),
})
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Delete any policies
for _, p := range policies {
err = client.DeleteUserPolicy(&iam.DeleteUserPolicyRequest{
UserName: aws.String(username),
PolicyName: aws.String(p),
})
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}
// Remove the user from all their groups
for _, g := range groups {
err = client.RemoveUserFromGroup(&iam.RemoveUserFromGroupRequest{
GroupName: g.GroupName,
UserName: aws.String(username),
})
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
}
// Delete the user
err = client.DeleteUser(&iam.DeleteUserRequest{
UserName: aws.String(username),
})
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
return nil, nil
}

View File

@ -190,13 +190,8 @@ func (b *Backend) handleRevokeRenew(
return nil, logical.ErrUnsupportedOperation
}
var data map[string]interface{}
if raw, ok := req.Data["previous_data"]; ok {
data = raw.(map[string]interface{})
}
return fn(req, &FieldData{
Raw: data,
Raw: req.Data,
Schema: secret.Fields,
})
}

View File

@ -13,18 +13,10 @@ type Response struct {
Data map[string]interface{}
}
/*
// Validate is used to sanity check a lease
func (l *Lease) Validate() error {
if l.Duration <= 0 {
return fmt.Errorf("lease duration must be greater than zero")
}
if l.GracePeriod < 0 {
return fmt.Errorf("grace period cannot be less than zero")
}
return nil
// IsError returns true if this response seems to indicate an error.
func (r *Response) IsError() bool {
return r != nil && len(r.Data) == 1 && r.Data["error"] != nil
}
*/
// HelpResponse is used to format a help response
func HelpResponse(text string, seeAlso []string) *Response {

View File

@ -1,5 +1,10 @@
package logical
import (
"bytes"
"encoding/json"
)
// Storage is the way that logical backends are able read/write data.
type Storage interface {
List(prefix string) ([]string, error)
@ -13,3 +18,21 @@ type StorageEntry struct {
Key string
Value []byte
}
func (e *StorageEntry) DecodeJSON(out interface{}) error {
return json.Unmarshal(e.Value, out)
}
// StorageEntryJSON creates a StorageEntry with a JSON-encoded value.
func StorageEntryJSON(k string, v interface{}) (*StorageEntry, error) {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
if err := enc.Encode(v); err != nil {
return nil, err
}
return &StorageEntry{
Key: k,
Value: buf.Bytes(),
}, nil
}

View File

@ -115,9 +115,9 @@ func Test(t TestT, c TestCase) {
}
// Unseal the core
if sealed, err := core.Unseal(init.SecretShares[0]); err != nil {
if unsealed, err := core.Unseal(init.SecretShares[0]); err != nil {
t.Fatal("error unsealing core: ", err)
} else if sealed {
} else if !unsealed {
t.Fatal("vault shouldn't be sealed")
}
@ -162,6 +162,9 @@ func Test(t TestT, c TestCase) {
resp.Data,
))
}
if err == nil && resp.IsError() {
err = fmt.Errorf("Erroneous response:\n\n%#v", resp)
}
if err == nil && s.Check != nil {
// Call the test method
err = s.Check(resp)
@ -176,7 +179,10 @@ func Test(t TestT, c TestCase) {
var failedRevokes []*logical.Secret
for _, req := range revoke {
log.Printf("[WARN] Revoking secret: %#v", req.Secret)
_, err = core.HandleRequest(req)
resp, err := core.HandleRequest(req)
if err == nil && resp.IsError() {
err = fmt.Errorf("Erroneous response:\n\n%#v", resp)
}
if err != nil {
failedRevokes = append(failedRevokes, req.Secret)
t.Error(fmt.Sprintf("[ERR] Revoke error: %s", err))
@ -185,8 +191,11 @@ func Test(t TestT, c TestCase) {
// Perform any rollbacks. This should no-op if there aren't any.
log.Printf("[WARN] Requesting RollbackOperation")
_, err = core.HandleRequest(logical.RollbackRequest(prefix + "/"))
if err != nil {
resp, err := core.HandleRequest(logical.RollbackRequest(prefix + "/"))
if err == nil && resp.IsError() {
err = fmt.Errorf("Erroneous response:\n\n%#v", resp)
}
if err != nil && err != logical.ErrUnsupportedOperation {
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
}