Merge remote-tracking branch 'origin/master' into pr-1425

This commit is contained in:
Jeff Mitchell 2016-06-08 12:10:29 -04:00
commit da6371ffc3
204 changed files with 5492 additions and 7511 deletions

View file

@ -1,4 +1,4 @@
## 0.5.3 (Unreleased)
## 0.6.0 (Unreleased)
SECURITY:
@ -38,13 +38,23 @@ DEPRECATIONS/BREAKING CHANGES:
FEATURES:
* **AWS EC2 Auth Backend**: Provides a recure introduction mechanism for AWS EC2
instances allowing automated retrieval of Vault tokens. Unlike most Vault
authentication backends, this backend does not require first-deploying,
* **AWS EC2 Auth Backend**: Provides a secure introduction mechanism for AWS
EC2 instances allowing automated retrieval of Vault tokens. Unlike most
Vault authentication backends, this backend does not require first deploying
or provisioning security-sensitive credentials (tokens, username/password,
client certificates,etc). Instead, it treats AWS as a Trusted Third Party
client certificates, etc). Instead, it treats AWS as a Trusted Third Party
and uses the cryptographically signed dynamic metadata information that
uniquely represents each EC2 instance.
uniquely represents each EC2 instance. [Vault
Enterprise](https://www.hashicorp.com/vault.html) customers have access to a
turnkey client that speaks the backend API and makes access to a Vault token
easy.
* **Response Wrapping**: Nearly any response within Vault can now be wrapped
inside a single-use, time-limited token's cubbyhole, taking the [Cubbyhole
Authentication
Principles](https://www.hashicorp.com/blog/vault-cubbyhole-principles.html)
mechanism to its logical conclusion. Retrieving the original response is as
simple as a single API command or the new `vault unwrap` command. This makes
secret distribution easier and more secure, including secure introduction.
* **Azure Physical Backend**: You can now use Azure blob object storage as
your Vault physical data store [GH-1266]
* **Consul Backend Health Checks**: The Consul backend will automatically
@ -58,12 +68,18 @@ FEATURES:
system- or mount-set values. This is useful, for instance, when the max TTL
of the system or the `auth/token` mount must be set high to accommodate
certain needs but you want more granular restrictions on tokens being issued
directly from `auth/token`. [GH-1399]
directly from the Token authentication backend at `auth/token`. [GH-1399]
* **RabbitMQ Secret Backend**: Vault can now generate credentials for
RabbitMQ. Vhosts and tags can be defined within roles. [GH-788]
IMPROVEMENTS:
* audit: Add the DisplayName value to the copy of the Request object embedded
in the associated Response, to match the original Request object [GH-1387]
* audit: Enable auditing of the `seal` and `step-down` commands [GH-1435]
* backends: Remove most `root`/`sudo` paths in favor of normal ACL mechanisms.
A particular exception are any current MFA paths. A few paths in `token` and
`sys` also require `root` or `sudo`. [GH-1478]
* command/auth: Restore the previous authenticated token if the `auth` command
fails to authenticate the provided token [GH-1233]
* command/write: `-format` and `-field` can now be used with the `write`
@ -78,13 +94,21 @@ IMPROVEMENTS:
backend [GH-1404]
* credential/ldap: If `groupdn` is not configured, skip searching LDAP and
only return policies for local groups, plus a warning [GH-1283]
* credential/ldap: `vault list` support for users and groups [GH-1270]
* credential/ldap: Support for the `memberOf` attribute for group membership
searching [GH-1245]
* credential/userpass: Add list support for users [GH-911]
* credential/userpass: Remove user configuration paths from requiring sudo, in
favor of normal ACL mechanisms [GH-1312]
* credential/token: Sanitize policies and add `default` policies in appropriate
places [GH-1235]
* secret/aws: Use chain credentials to allow environment/EC2 instance/shared
providers [GH-307]
* secret/aws: Support for STS AssumeRole functionality [GH-1318]
* secret/pki: Added `exclude_cn_from_sans` field to prevent adding the CN to
DNS or Email Subject Alternate Names [GH-1220]
* secret/consul: Reading consul access configuration supported. The response
will contain non-sensitive information only [GH-1445]
* sys/capabilities: Enforce ACL checks for requests that query the capabilities
of a token on a given path [GH-1221]
@ -97,6 +121,11 @@ BUG FIXES:
* command/various: Tell the JSON decoder to not convert all numbers to floats;
fixes some various places where numbers were showing up in scientific
notation
* command/server: Prioritized `devRootTokenID` and `devListenAddress` flags
over their respective env vars [GH-1480]
* command/ssh: Provided option to disable host key checking. The automated
variant of `vault ssh` command uses `sshpass` which was failing to handle
host key checking presented by the `ssh` binary. [GH-1473]
* core: Properly persist mount-tuned TTLs for auth backends [GH-1371]
* core: Don't accidentally crosswire SIGINT to the reload handler [GH-1372]
* credential/github: Make organization comparison case-insensitive during
@ -114,9 +143,51 @@ BUG FIXES:
* credential/various: Fix renewal conditions when `default` policy is not
contained in the backend config [GH-1256]
* physical/s3: Don't panic in certain error cases from bad S3 responses [GH-1353]
* secret/consul: Use non-pooled Consul API client to avoid leaving files open
[GH-1428]
* secret/pki: Don't check whether a certificate is destined to be a CA
certificate if sign-verbatim endpoint is used [GH-1250]
## 0.5.3 (May 27th, 2016)
SECURITY:
* Consul ACL Token Revocation: An issue was reported to us indicating that
generated Consul ACL tokens were not being properly revoked. Upon
investigation, we found that this behavior was reproducible in a specific
scenario: when a generated lease for a Consul ACL token had been renewed
prior to revocation. In this case, the generated token was not being
properly persisted internally through the renewal function, leading to an
error during revocation due to the missing token. Unfortunately, this was
coded as a user error rather than an internal error, and the revocation
logic was expecting internal errors if revocation failed. As a result, the
revocation logic believed the revocation to have succeeded when it in fact
failed, causing the lease to be dropped while the token was still valid
within Consul. In this release, the Consul backend properly persists the
token through renewals, and the revocation logic has been changed to
consider any error type to have been a failure to revoke, causing the lease
to persist and attempt to be revoked later.
We have written an example shell script that searches through Consul's ACL
tokens and looks for those generated by Vault, which can be used as a template
for a revocation script as deemed necessary for any particular security
response. The script is available at
https://gist.github.com/jefferai/6233c2963f9407a858d84f9c27d725c0
Please note that any outstanding leases for Consul tokens produced prior to
0.5.3 that have been renewed will continue to exhibit this behavior. As a
result, we recommend either revoking all tokens produced by the backend and
issuing new ones, or if needed, a more advanced variant of the provided example
could use the timestamp embedded in each generated token's name to decide which
tokens are too old and should be deleted. This could then be run periodically
up until the maximum lease time for any outstanding pre-0.5.3 tokens has
expired.
This is a security-only release. There are no other code changes since 0.5.2.
The binaries have one additional change: they are built against Go 1.6.1 rather
than Go 1.6, as Go 1.6.1 contains two security fixes to the Go programming
language itself.
## 0.5.2 (March 16th, 2016)
FEATURES:

View file

@ -6,7 +6,8 @@ Vault [![Build Status](https://travis-ci.org/hashicorp/vault.svg)](https://travi
- Website: https://www.vaultproject.io
- IRC: `#vault-tool` on Freenode
- Mailing list: [Google Groups](https://groups.google.com/group/vault-tool)
- Announcement list: [Google Groups](https://groups.google.com/group/hashicorp-announce)
- Discussion list: [Google Groups](https://groups.google.com/group/vault-tool)
![Vault](https://raw.githubusercontent.com/hashicorp/vault/master/website/source/assets/images/logo-big.png?token=AAAFE8XmW6YF5TNuk3cosDGBK-sUGPEjks5VSAa2wA%3D%3D)

View file

@ -23,11 +23,19 @@ const EnvVaultClientCert = "VAULT_CLIENT_CERT"
const EnvVaultClientKey = "VAULT_CLIENT_KEY"
const EnvVaultInsecure = "VAULT_SKIP_VERIFY"
const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME"
const EnvVaultWrapTTL = "VAULT_WRAP_TTL"
var (
errRedirect = errors.New("redirect")
)
// WrappingLookupFunc is a function that, given an HTTP verb and a path,
// returns an optional string duration to be used for response wrapping (e.g.
// "15s", or simply "15"). The path will not begin with "/v1/" or "v1/" or "/",
// however, end-of-path forward slashes are not trimmed, so must match your
// called path precisely.
type WrappingLookupFunc func(operation, path string) string
// Config is used to configure the creation of the client.
type Config struct {
// Address is the address of the Vault server. This should be a complete
@ -155,9 +163,10 @@ func (c *Config) ReadEnvironment() error {
// Client is the client to the Vault API. Create a client with
// NewClient.
type Client struct {
addr *url.URL
config *Config
token string
addr *url.URL
config *Config
token string
wrappingLookupFunc WrappingLookupFunc
}
// NewClient returns a new client for the given configuration.
@ -166,7 +175,6 @@ type Client struct {
// automatically added to the client. Otherwise, you must manually call
// `SetToken()`.
func NewClient(c *Config) (*Client, error) {
u, err := url.Parse(c.Address)
if err != nil {
return nil, err
@ -200,6 +208,12 @@ func NewClient(c *Config) (*Client, error) {
return client, nil
}
// SetWrappingLookupFunc sets a lookup function that returns desired wrap TTLs
// for a given operation and path
func (c *Client) SetWrappingLookupFunc(lookupFunc WrappingLookupFunc) {
c.wrappingLookupFunc = lookupFunc
}
// Token returns the access token being used by this client. It will
// return the empty string if there is no token set.
func (c *Client) Token() string {
@ -232,6 +246,19 @@ func (c *Client) NewRequest(method, path string) *Request {
Params: make(map[string][]string),
}
if c.wrappingLookupFunc != nil {
var lookupPath string
switch {
case strings.HasPrefix(path, "/v1/"):
lookupPath = strings.TrimPrefix(path, "/v1/")
case strings.HasPrefix(path, "v1/"):
lookupPath = strings.TrimPrefix(path, "v1/")
default:
lookupPath = path
}
req.WrapTTL = c.wrappingLookupFunc(method, lookupPath)
}
return req
}

View file

@ -1,5 +1,15 @@
package api
import (
"bytes"
"encoding/json"
"fmt"
)
const (
wrappedResponseLocation = "cubbyhole/response"
)
// Logical is used to perform logical backend operations on Vault.
type Logical struct {
c *Client
@ -80,3 +90,34 @@ func (c *Logical) Delete(path string) (*Secret, error) {
return nil, nil
}
func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
origToken := c.c.Token()
defer c.c.SetToken(origToken)
c.c.SetToken(wrappingToken)
secret, err := c.Read(wrappedResponseLocation)
if err != nil {
return nil, fmt.Errorf("error reading %s: %s", wrappedResponseLocation, err)
}
if secret == nil {
return nil, fmt.Errorf("no value found at %s", wrappedResponseLocation)
}
if secret.Data == nil {
return nil, fmt.Errorf("\"data\" not found in wrapping response")
}
if _, ok := secret.Data["response"]; !ok {
return nil, fmt.Errorf("\"response\" not found in wrapping response \"data\" map")
}
wrappedSecret := new(Secret)
buf := bytes.NewBufferString(secret.Data["response"].(string))
dec := json.NewDecoder(buf)
dec.UseNumber()
if err := dec.Decode(wrappedSecret); err != nil {
return nil, fmt.Errorf("error unmarshaling wrapped secret: %s", err)
}
return wrappedSecret, nil
}

View file

@ -15,6 +15,7 @@ type Request struct {
URL *url.URL
Params url.Values
ClientToken string
WrapTTL string
Obj interface{}
Body io.Reader
BodySize int64
@ -62,5 +63,9 @@ func (r *Request) ToHTTP() (*http.Request, error) {
req.Header.Set("X-Vault-Token", r.ClientToken)
}
if len(r.WrapTTL) != 0 {
req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
}
return req, nil
}

View file

@ -3,6 +3,7 @@ package api
import (
"encoding/json"
"io"
"time"
)
// Secret is the structure returned for every secret within Vault.
@ -23,6 +24,18 @@ type Secret struct {
// Auth, if non-nil, means that there was authentication information
// attached to this response.
Auth *SecretAuth `json:"auth,omitempty"`
// WrapInfo, if non-nil, means that the initial response was wrapped in the
// cubbyhole of the given token (which has a TTL of the given number of
// seconds)
WrapInfo *SecretWrapInfo `json:"wrap_info,omitempty"`
}
// SecretWrapInfo contains wrapping information if we have it.
type SecretWrapInfo struct {
Token string `json:"token"`
TTL int `json:"ttl"`
CreationTime time.Time `json:"creation_time"`
}
// SecretAuth is the structure containing auth information if we have it.

View file

@ -4,6 +4,7 @@ import (
"reflect"
"strings"
"testing"
"time"
)
func TestParseSecret(t *testing.T) {
@ -17,9 +18,16 @@ func TestParseSecret(t *testing.T) {
},
"warnings": [
"a warning!"
]
],
"wrap_info": {
"token": "token",
"ttl": 60,
"creation_time": "2016-06-07T15:52:10-04:00"
}
}`)
rawTime, _ := time.Parse(time.RFC3339, "2016-06-07T15:52:10-04:00")
secret, err := ParseSecret(strings.NewReader(raw))
if err != nil {
t.Fatalf("err: %s", err)
@ -35,8 +43,13 @@ func TestParseSecret(t *testing.T) {
Warnings: []string{
"a warning!",
},
WrapInfo: &SecretWrapInfo{
Token: "token",
TTL: 60,
CreationTime: rawTime,
},
}
if !reflect.DeepEqual(secret, expected) {
t.Fatalf("bad: %#v %#v", secret, expected)
t.Fatalf("bad:\ngot\n%#v\nexpected\n%#v\n", secret, expected)
}
}

View file

@ -46,6 +46,7 @@ func (f *FormatJSON) FormatRequest(
Path: req.Path,
Data: req.Data,
RemoteAddr: getRemoteAddr(req),
WrapTTL: int(req.WrapTTL / time.Second),
},
})
}
@ -86,6 +87,15 @@ func (f *FormatJSON) FormatResponse(
}
}
var respWrapInfo *JSONWrapInfo
if resp.WrapInfo != nil {
respWrapInfo = &JSONWrapInfo{
TTL: int(resp.WrapInfo.TTL / time.Second),
Token: resp.WrapInfo.Token,
CreationTime: resp.WrapInfo.CreationTime,
}
}
// Encode!
enc := json.NewEncoder(w)
return enc.Encode(&JSONResponseEntry{
@ -104,6 +114,7 @@ func (f *FormatJSON) FormatResponse(
Path: req.Path,
Data: req.Data,
RemoteAddr: getRemoteAddr(req),
WrapTTL: int(req.WrapTTL / time.Second),
},
Response: JSONResponse{
@ -111,6 +122,7 @@ func (f *FormatJSON) FormatResponse(
Secret: respSecret,
Data: resp.Data,
Redirect: resp.Redirect,
WrapInfo: respWrapInfo,
},
})
}
@ -140,6 +152,7 @@ type JSONRequest struct {
Path string `json:"path"`
Data map[string]interface{} `json:"data"`
RemoteAddr string `json:"remote_address"`
WrapTTL int `json:"wrap_ttl"`
}
type JSONResponse struct {
@ -147,6 +160,7 @@ type JSONResponse struct {
Secret *JSONSecret `json:"secret,emitempty"`
Data map[string]interface{} `json:"data"`
Redirect string `json:"redirect"`
WrapInfo *JSONWrapInfo `json:"wrap_info,omitempty"`
}
type JSONAuth struct {
@ -161,6 +175,12 @@ type JSONSecret struct {
LeaseID string `json:"lease_id"`
}
type JSONWrapInfo struct {
TTL int `json:"ttl"`
Token string `json:"token"`
CreationTime time.Time `json:"creation_time"`
}
// getRemoteAddr safely gets the remote address avoiding a nil pointer
func getRemoteAddr(req *logical.Request) string {
if req != nil && req.Connection != nil {

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"strings"
"testing"
"time"
"errors"
@ -26,6 +27,7 @@ func TestFormatJSON_formatRequest(t *testing.T) {
Connection: &logical.Connection{
RemoteAddr: "127.0.0.1",
},
WrapTTL: 60 * time.Second,
},
errors.New("this is an error"),
testFormatJSONReqBasicStr,
@ -64,5 +66,5 @@ func TestFormatJSON_formatRequest(t *testing.T) {
}
}
const testFormatJSONReqBasicStr = `{"time":"2015-08-05T13:45:46Z","type":"request","auth":{"display_name":"","policies":["root"],"metadata":null},"request":{"operation":"update","path":"/foo","data":null,"remote_address":"127.0.0.1"},"error":"this is an error"}
const testFormatJSONReqBasicStr = `{"time":"2015-08-05T13:45:46Z","type":"request","auth":{"display_name":"","policies":["root"],"metadata":null},"request":{"operation":"update","path":"/foo","data":null,"wrap_ttl":60,"remote_address":"127.0.0.1"},"error":"this is an error"}
`

View file

@ -67,12 +67,25 @@ func Hash(salter *salt.Salt, raw interface{}) error {
}
}
if s.WrapInfo != nil {
if err := Hash(salter, s.WrapInfo); err != nil {
return err
}
}
data, err := HashStructure(s.Data, fn)
if err != nil {
return err
}
s.Data = data.(map[string]interface{})
case *logical.WrapInfo:
if s == nil {
return nil
}
s.Token = fn(s.Token)
}
return nil

View file

@ -44,6 +44,7 @@ func TestCopy_request(t *testing.T) {
Data: map[string]interface{}{
"foo": "bar",
},
WrapTTL: 60 * time.Second,
}
arg := expected
@ -66,6 +67,11 @@ func TestCopy_response(t *testing.T) {
Data: map[string]interface{}{
"foo": "bar",
},
WrapInfo: &logical.WrapInfo{
TTL: 60,
Token: "foo",
CreationTime: time.Now(),
},
}
arg := expected
@ -131,11 +137,21 @@ func TestHash(t *testing.T) {
Data: map[string]interface{}{
"foo": "bar",
},
WrapInfo: &logical.WrapInfo{
TTL: 60,
Token: "bar",
CreationTime: now,
},
},
&logical.Response{
Data: map[string]interface{}{
"foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
},
WrapInfo: &logical.WrapInfo{
TTL: 60,
Token: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
CreationTime: now,
},
},
},
{

View file

@ -124,7 +124,7 @@ func (b *backend) pathLoginRenew(
return nil, err
}
if !policyutil.EquivalentPolicies(mapPolicies, req.Auth.Policies) {
return logical.ErrorResponse("policies do not match"), nil
return nil, fmt.Errorf("policies do not match")
}
return framework.LeaseExtend(0, 0, b.System())(req, d)

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"sync"
@ -51,7 +51,7 @@ type backend struct {
EC2ClientsMap map[string]*ec2.EC2
}
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
func Backend(conf *logical.BackendConfig) (*backend, error) {
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
@ -96,7 +96,7 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
},
}
return b.Backend, nil
return b, nil
}
// periodicFunc performs the tasks that the backend wishes to do periodically.
@ -108,13 +108,12 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
// Tidying of blacklist and whitelist are by default enabled. This can be
// changed using `config/tidy/roletags` and `config/tidy/identities` endpoints.
func (b *backend) periodicFunc(req *logical.Request) error {
// Run the tidy operations for the first time. Then run it when current
// time matches the nextTidyTime.
if b.nextTidyTime.IsZero() || !time.Now().UTC().Before(b.nextTidyTime) {
// safety_buffer defaults to 180 days for roletag blacklist
safety_buffer := 15552000
tidyBlacklistConfigEntry, err := b.configTidyRoleTags(req.Storage)
tidyBlacklistConfigEntry, err := b.lockedConfigTidyRoleTags(req.Storage)
if err != nil {
return err
}
@ -135,7 +134,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
// reset the safety_buffer to 72h
safety_buffer = 259200
tidyWhitelistConfigEntry, err := b.configTidyIdentities(req.Storage)
tidyWhitelistConfigEntry, err := b.lockedConfigTidyIdentities(req.Storage)
if err != nil {
return err
}
@ -161,7 +160,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
}
const backendHelp = `
AWS auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
created nonce to authenticates the EC2 instance with Vault.
Authentication is backed by a preconfigured role in the backend. The role

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"encoding/base64"
@ -6,75 +6,23 @@ import (
"os"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
logicaltest "github.com/hashicorp/vault/logical/testing"
)
func createBackend(conf *logical.BackendConfig) (*backend, error) {
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
if err != nil {
return nil, err
}
b := &backend{
// Setting the periodic func to be run once in an hour.
// If there is a real need, this can be made configurable.
tidyCooldownPeriod: time.Hour,
Salt: salt,
EC2ClientsMap: make(map[string]*ec2.EC2),
}
b.Backend = &framework.Backend{
PeriodicFunc: b.periodicFunc,
AuthRenew: b.pathLoginRenew,
Help: backendHelp,
PathsSpecial: &logical.Paths{
Unauthenticated: []string{
"login",
},
},
Paths: []*framework.Path{
pathLogin(b),
pathListRole(b),
pathListRoles(b),
pathRole(b),
pathRoleTag(b),
pathConfigClient(b),
pathConfigCertificate(b),
pathConfigTidyRoletagBlacklist(b),
pathConfigTidyIdentityWhitelist(b),
pathListCertificates(b),
pathListRoletagBlacklist(b),
pathRoletagBlacklist(b),
pathTidyRoletagBlacklist(b),
pathListIdentityWhitelist(b),
pathIdentityWhitelist(b),
pathTidyIdentityWhitelist(b),
},
}
return b, nil
}
func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
// create a backend
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -98,7 +46,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}
// read the created role entry
roleEntry, err := b.awsRole(storage, "abcd-123")
roleEntry, err := b.lockedAWSRole(storage, "abcd-123")
if err != nil {
t.Fatal(err)
}
@ -165,7 +113,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}
// get the entry of the newly created role entry
roleEntry2, err := b.awsRole(storage, "ami-6789")
roleEntry2, err := b.lockedAWSRole(storage, "ami-6789")
if err != nil {
t.Fatal(err)
}
@ -293,11 +241,11 @@ func TestBackend_ConfigTidyIdentities(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -347,11 +295,11 @@ func TestBackend_ConfigTidyRoleTags(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -401,11 +349,11 @@ func TestBackend_TidyIdentities(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -426,11 +374,11 @@ func TestBackend_TidyRoleTags(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -451,11 +399,11 @@ func TestBackend_ConfigClient(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -588,11 +536,11 @@ func TestBackend_pathConfigCertificate(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -739,11 +687,11 @@ func TestBackend_pathRole(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -865,11 +813,11 @@ func TestBackend_parseAndVerifyRoleTagValue(t *testing.T) {
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -942,11 +890,11 @@ func TestBackend_PathRoleTag(t *testing.T) {
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -1006,11 +954,11 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
storage := &logical.InmemStorage{}
config := logical.TestBackendConfig()
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
@ -1098,7 +1046,7 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
}
// try to read the deleted entry
tagEntry, err := b.blacklistRoleTagEntry(storage, tag)
tagEntry, err := b.lockedBlacklistRoleTagEntry(storage, tag)
if err != nil {
t.Fatal(err)
}
@ -1135,11 +1083,11 @@ func TestBackendAcc_LoginAndWhitelistIdentity(t *testing.T) {
storage := &logical.InmemStorage{}
config := logical.TestBackendConfig()
config.StorageView = storage
b, err := createBackend(config)
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Backend.Setup(config)
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"fmt"
@ -24,7 +24,7 @@ func (b *backend) getClientConfig(s logical.Storage, region string) (*aws.Config
}
// Read the configured secret key and access key
config, err := b.clientConfigEntryInternal(s)
config, err := b.nonLockedClientConfigEntry(s)
if err != nil {
return nil, err
}

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"crypto/x509"
@ -96,7 +96,7 @@ func (b *backend) pathConfigCertificateExistenceCheck(req *logical.Request, data
return false, fmt.Errorf("missing cert_name")
}
entry, err := b.awsPublicCertificateEntry(req.Storage, certName)
entry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
if err != nil {
return false, err
}
@ -161,7 +161,7 @@ func (b *backend) awsPublicCertificates(s logical.Storage) ([]*x509.Certificate,
// Iterate through each certificate, parse and append it to a slice.
for _, cert := range registeredCerts {
certEntry, err := b.awsPublicCertificateEntryInternal(s, cert)
certEntry, err := b.nonLockedAWSPublicCertificateEntry(s, cert)
if err != nil {
return nil, err
}
@ -180,15 +180,15 @@ func (b *backend) awsPublicCertificates(s logical.Storage) ([]*x509.Certificate,
// awsPublicCertificate is used to get the configured AWS Public Key that is used
// to verify the PKCS#7 signature of the instance identity document.
func (b *backend) awsPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
func (b *backend) lockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.awsPublicCertificateEntryInternal(s, certName)
return b.nonLockedAWSPublicCertificateEntry(s, certName)
}
// Internal version of the above that does no locking
func (b *backend) awsPublicCertificateEntryInternal(s logical.Storage, certName string) (*awsPublicCert, error) {
func (b *backend) nonLockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
entry, err := s.Get("config/certificate/" + certName)
if err != nil {
return nil, err
@ -227,7 +227,7 @@ func (b *backend) pathConfigCertificateRead(
return logical.ErrorResponse("missing cert_name"), nil
}
certificateEntry, err := b.awsPublicCertificateEntry(req.Storage, certName)
certificateEntry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
if err != nil {
return nil, err
}
@ -253,7 +253,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(
defer b.configMutex.Unlock()
// Check if there is already a certificate entry registered.
certEntry, err := b.awsPublicCertificateEntryInternal(req.Storage, certName)
certEntry, err := b.nonLockedAWSPublicCertificateEntry(req.Storage, certName)
if err != nil {
return nil, err
}

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"github.com/fatih/structs"
@ -48,23 +48,23 @@ func pathConfigClient(b *backend) *framework.Path {
func (b *backend) pathConfigClientExistenceCheck(
req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.clientConfigEntry(req.Storage)
entry, err := b.lockedClientConfigEntry(req.Storage)
if err != nil {
return false, err
}
return entry != nil, nil
}
// Fetch the client configuration required to access the AWS API.
func (b *backend) clientConfigEntry(s logical.Storage) (*clientConfig, error) {
// Fetch the client configuration required to access the AWS API, after acquiring an exclusive lock.
func (b *backend) lockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.clientConfigEntryInternal(s)
return b.nonLockedClientConfigEntry(s)
}
// Internal version that does no locking
func (b *backend) clientConfigEntryInternal(s logical.Storage) (*clientConfig, error) {
// Fetch the client configuration required to access the AWS API.
func (b *backend) nonLockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
entry, err := s.Get("config/client")
if err != nil {
return nil, err
@ -82,7 +82,7 @@ func (b *backend) clientConfigEntryInternal(s logical.Storage) (*clientConfig, e
func (b *backend) pathConfigClientRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.clientConfigEntry(req.Storage)
clientConfig, err := b.lockedClientConfigEntry(req.Storage)
if err != nil {
return nil, err
}
@ -118,7 +118,7 @@ func (b *backend) pathConfigClientCreateUpdate(
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.clientConfigEntryInternal(req.Storage)
configEntry, err := b.nonLockedClientConfigEntry(req.Storage)
if err != nil {
return nil, err
}
@ -193,7 +193,7 @@ Configure the client credentials that are used to query instance details from AW
`
const pathConfigClientHelpDesc = `
AWS auth backend makes DescribeInstances API call to retrieve information regarding
the instance that performs login. The aws_secret_key and aws_access_key registered with Vault should have the
permissions to make the API call.
aws-ec2 auth backend makes DescribeInstances API call to retrieve information regarding
the instance that performs login. The aws_secret_key and aws_access_key registered with
Vault should have the permissions to make the API call.
`

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"fmt"
@ -44,21 +44,21 @@ expiration, before it is removed from the backend storage.`,
}
func (b *backend) pathConfigTidyIdentityWhitelistExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.configTidyIdentities(req.Storage)
entry, err := b.lockedConfigTidyIdentities(req.Storage)
if err != nil {
return false, err
}
return entry != nil, nil
}
func (b *backend) configTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
func (b *backend) lockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.configTidyIdentitiesInternal(s)
return b.nonLockedConfigTidyIdentities(s)
}
func (b *backend) configTidyIdentitiesInternal(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
func (b *backend) nonLockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
entry, err := s.Get(identityWhitelistConfigPath)
if err != nil {
return nil, err
@ -78,7 +78,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(req *logical.Reque
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.configTidyIdentitiesInternal(req.Storage)
configEntry, err := b.nonLockedConfigTidyIdentities(req.Storage)
if err != nil {
return nil, err
}
@ -113,7 +113,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(req *logical.Reque
}
func (b *backend) pathConfigTidyIdentityWhitelistRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.configTidyIdentities(req.Storage)
clientConfig, err := b.lockedConfigTidyIdentities(req.Storage)
if err != nil {
return nil, err
}

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"fmt"
@ -46,21 +46,21 @@ Defaults to 4320h (180 days).`,
}
func (b *backend) pathConfigTidyRoletagBlacklistExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.configTidyRoleTags(req.Storage)
entry, err := b.lockedConfigTidyRoleTags(req.Storage)
if err != nil {
return false, err
}
return entry != nil, nil
}
func (b *backend) configTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
func (b *backend) lockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.configTidyRoleTagsInternal(s)
return b.nonLockedConfigTidyRoleTags(s)
}
func (b *backend) configTidyRoleTagsInternal(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
func (b *backend) nonLockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
entry, err := s.Get(roletagBlacklistConfigPath)
if err != nil {
return nil, err
@ -81,7 +81,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(req *logical.Reques
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.configTidyRoleTagsInternal(req.Storage)
configEntry, err := b.nonLockedConfigTidyRoleTags(req.Storage)
if err != nil {
return nil, err
}
@ -114,7 +114,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(req *logical.Reques
}
func (b *backend) pathConfigTidyRoletagBlacklistRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.configTidyRoleTags(req.Storage)
clientConfig, err := b.lockedConfigTidyRoleTags(req.Storage)
if err != nil {
return nil, err
}

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"time"

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"encoding/json"
@ -236,7 +236,7 @@ func (b *backend) pathLoginUpdate(
}
// Get the entry for the role used by the instance.
roleEntry, err := b.awsRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
if err != nil {
return nil, err
}
@ -442,7 +442,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, identityDoc *identityDoc
}
// Check if the role tag is blacklisted.
blacklistEntry, err := b.blacklistRoleTagEntry(s, rTagValue)
blacklistEntry, err := b.lockedBlacklistRoleTagEntry(s, rTagValue)
if err != nil {
return nil, err
}
@ -478,7 +478,7 @@ func (b *backend) pathLoginRenew(
// Cross check that the instance is still in 'running' state
_, err := b.validateInstance(req.Storage, instanceID, region)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to verify instance ID: %s", err)), nil
return nil, fmt.Errorf("failed to verify instance ID: %s", err)
}
storedIdentity, err := whitelistIdentityEntry(req.Storage, instanceID)
@ -487,12 +487,12 @@ func (b *backend) pathLoginRenew(
}
// Ensure that role entry is not deleted.
roleEntry, err := b.awsRole(req.Storage, storedIdentity.Role)
roleEntry, err := b.lockedAWSRole(req.Storage, storedIdentity.Role)
if err != nil {
return nil, err
}
if roleEntry == nil {
return logical.ErrorResponse("role entry not found"), nil
return nil, fmt.Errorf("role entry not found")
}
// If the login was made using the role tag, then max_ttl from tag

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"fmt"
@ -54,7 +54,7 @@ using the AMI ID specified by this parameter.`,
"disallow_reauthentication": &framework.FieldSchema{
Type: framework.TypeBool,
Default: false,
Description: "If set, only allows a single token to be granted per instance ID. In order to perform a fresh login, the entry in whitelist for the instance ID needs to be cleared using 'auth/aws/identity-whitelist/<instance_id>' endpoint.",
Description: "If set, only allows a single token to be granted per instance ID. In order to perform a fresh login, the entry in whitelist for the instance ID needs to be cleared using 'auth/aws-ec2/identity-whitelist/<instance_id>' endpoint.",
},
},
@ -101,7 +101,7 @@ func pathListRoles(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) pathRoleExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.awsRole(req.Storage, strings.ToLower(data.Get("role").(string)))
entry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return false, err
}
@ -109,14 +109,14 @@ func (b *backend) pathRoleExistenceCheck(req *logical.Request, data *framework.F
}
// awsRole is used to get the information registered for the given AMI ID.
func (b *backend) awsRole(s logical.Storage, role string) (*awsRoleEntry, error) {
func (b *backend) lockedAWSRole(s logical.Storage, role string) (*awsRoleEntry, error) {
b.roleMutex.RLock()
defer b.roleMutex.RUnlock()
return b.awsRoleInternal(s, role)
return b.nonLockedAWSRole(s, role)
}
func (b *backend) awsRoleInternal(s logical.Storage, role string) (*awsRoleEntry, error) {
func (b *backend) nonLockedAWSRole(s logical.Storage, role string) (*awsRoleEntry, error) {
entry, err := s.Get("role/" + strings.ToLower(role))
if err != nil {
return nil, err
@ -162,7 +162,7 @@ func (b *backend) pathRoleList(
// pathRoleRead is used to view the information registered for a given AMI ID.
func (b *backend) pathRoleRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
roleEntry, err := b.awsRole(req.Storage, strings.ToLower(data.Get("role").(string)))
roleEntry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return nil, err
}
@ -196,7 +196,7 @@ func (b *backend) pathRoleCreateUpdate(
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
roleEntry, err := b.awsRoleInternal(req.Storage, roleName)
roleEntry, err := b.nonLockedAWSRole(req.Storage, roleName)
if err != nil {
return nil, err
}

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"crypto/hmac"
@ -54,7 +54,7 @@ If set, the created tag can only be used by the instance with the given ID.`,
"disallow_reauthentication": &framework.FieldSchema{
Type: framework.TypeBool,
Default: false,
Description: "If set, only allows a single token to be granted per instance ID. In order to perform a fresh login, the entry in whitelist for the instance ID needs to be cleared using the 'auth/aws/identity-whitelist/<instance_id>' endpoint.",
Description: "If set, only allows a single token to be granted per instance ID. In order to perform a fresh login, the entry in whitelist for the instance ID needs to be cleared using the 'auth/aws-ec2/identity-whitelist/<instance_id>' endpoint.",
},
},
@ -78,7 +78,7 @@ func (b *backend) pathRoleTagUpdate(
}
// Fetch the role entry
roleEntry, err := b.awsRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
if err != nil {
return nil, err
}
@ -346,7 +346,7 @@ func (b *backend) parseAndVerifyRoleTagValue(s logical.Storage, tag string) (*ro
return nil, fmt.Errorf("missing role name")
}
roleEntry, err := b.awsRole(s, rTag.Role)
roleEntry, err := b.lockedAWSRole(s, rTag.Role)
if err != nil {
return nil, err
}

View file

@ -1,4 +1,4 @@
package aws
package awsec2
import (
"encoding/base64"
@ -72,14 +72,14 @@ func (b *backend) pathRoletagBlacklistsList(
// Fetch an entry from the role tag blacklist for a given tag.
// This method takes a role tag in its original form and not a base64 encoded form.
func (b *backend) blacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
func (b *backend) lockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
b.blacklistMutex.RLock()
defer b.blacklistMutex.RUnlock()
return b.blacklistRoleTagEntryInternal(s, tag)
return b.nonLockedBlacklistRoleTagEntry(s, tag)
}
func (b *backend) blacklistRoleTagEntryInternal(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
func (b *backend) nonLockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
entry, err := s.Get("blacklist/roletag/" + base64.StdEncoding.EncodeToString([]byte(tag)))
if err != nil {
return nil, err
@ -119,7 +119,7 @@ func (b *backend) pathRoletagBlacklistRead(
return logical.ErrorResponse("missing role_tag"), nil
}
entry, err := b.blacklistRoleTagEntry(req.Storage, tag)
entry, err := b.lockedBlacklistRoleTagEntry(req.Storage, tag)
if err != nil {
return nil, err
}
@ -166,7 +166,7 @@ func (b *backend) pathRoletagBlacklistUpdate(
}
// Get the entry for the role mentioned in the role tag.
roleEntry, err := b.awsRole(req.Storage, rTag.Role)
roleEntry, err := b.lockedAWSRole(req.Storage, rTag.Role)
if err != nil {
return nil, err
}
@ -178,7 +178,7 @@ func (b *backend) pathRoletagBlacklistUpdate(
defer b.blacklistMutex.Unlock()
// Check if the role tag is already blacklisted. If yes, update it.
blEntry, err := b.blacklistRoleTagEntryInternal(req.Storage, tag)
blEntry, err := b.nonLockedBlacklistRoleTagEntry(req.Storage, tag)
if err != nil {
return nil, err
}

View file

@ -101,15 +101,6 @@ func connectionState(t *testing.T, serverCAPath, serverCertPath, serverKeyPath,
return connState
}
func failOnError(t *testing.T, resp *logical.Response, err error) {
if resp != nil && resp.IsError() {
t.Fatalf("error returned in response: %s", resp.Data["error"])
}
if err != nil {
t.Fatal(err)
}
}
func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
@ -140,7 +131,9 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
}
resp, err := b.HandleRequest(certReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Connection state is presenting the client Non-CA cert and its key.
// This is exactly what is registered at the backend.
@ -155,7 +148,9 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
}
// Login should succeed.
resp, err = b.HandleRequest(loginReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Register a CRL containing the issued client certificate used above.
issuedCRL, err := ioutil.ReadFile(testIssuedCertCRL)
@ -172,7 +167,9 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
Data: crlData,
}
resp, err = b.HandleRequest(crlReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Attempt login with the same connection state but with the CRL registered
resp, err = b.HandleRequest(loginReq)
@ -214,7 +211,9 @@ func TestBackend_CRLs(t *testing.T) {
}
resp, err := b.HandleRequest(certReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Connection state is presenting the client CA cert and its key.
// This is exactly what is registered at the backend.
@ -228,7 +227,9 @@ func TestBackend_CRLs(t *testing.T) {
},
}
resp, err = b.HandleRequest(loginReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Now, without changing the registered client CA cert, present from
// the client side, a cert issued using the registered CA.
@ -237,7 +238,9 @@ func TestBackend_CRLs(t *testing.T) {
// Attempt login with the updated connection
resp, err = b.HandleRequest(loginReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Register a CRL containing the issued client certificate used above.
issuedCRL, err := ioutil.ReadFile(testIssuedCertCRL)
@ -255,7 +258,9 @@ func TestBackend_CRLs(t *testing.T) {
Data: crlData,
}
resp, err = b.HandleRequest(crlReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Attempt login with the revoked certificate.
resp, err = b.HandleRequest(loginReq)
@ -273,7 +278,9 @@ func TestBackend_CRLs(t *testing.T) {
}
certData["certificate"] = clientCA2
resp, err = b.HandleRequest(certReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Test login using a different client CA cert pair.
connState = connectionState(t, serverCAPath, serverCertPath, serverKeyPath, testRootCACertPath2, testRootCAKeyPath2)
@ -281,7 +288,9 @@ func TestBackend_CRLs(t *testing.T) {
// Attempt login with the updated connection
resp, err = b.HandleRequest(loginReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Register a CRL containing the root CA certificate used above.
rootCRL, err := ioutil.ReadFile(testRootCertCRL)
@ -290,7 +299,9 @@ func TestBackend_CRLs(t *testing.T) {
}
crlData["crl"] = rootCRL
resp, err = b.HandleRequest(crlReq)
failOnError(t, resp, err)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Attempt login with the same connection state but with the CRL registered
resp, err = b.HandleRequest(loginReq)
@ -754,13 +765,7 @@ func Test_Renew(t *testing.T) {
}
resp, err = b.pathLoginRenew(req, nil)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response from renew")
}
if !resp.IsError() {
if err == nil {
t.Fatal("expected error")
}

View file

@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"strings"
"github.com/hashicorp/vault/helper/certutil"
@ -106,7 +107,7 @@ func (b *backend) pathLoginRenew(
clientCerts := req.Connection.ConnState.PeerCertificates
if len(clientCerts) == 0 {
return logical.ErrorResponse("no client certificate found"), nil
return nil, fmt.Errorf("no client certificate found")
}
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
@ -114,7 +115,7 @@ func (b *backend) pathLoginRenew(
// Certificate should not only match a registered certificate policy.
// Also, the identity of the certificate presented should match the identity of the certificate used during login
if req.Auth.InternalData["subject_key_id"] != skid && req.Auth.InternalData["authority_key_id"] != akid {
return logical.ErrorResponse("client identity during renewal not matching client identity used during login"), nil
return nil, fmt.Errorf("client identity during renewal not matching client identity used during login")
}
}
@ -129,7 +130,7 @@ func (b *backend) pathLoginRenew(
}
if !policyutil.EquivalentPolicies(cert.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil
return nil, fmt.Errorf("policies have changed, not renewing")
}
return framework.LeaseExtend(cert.TTL, 0, b.System())(req, d)

View file

@ -84,7 +84,7 @@ func (b *backend) pathLoginRenew(
verifyResp = verifyResponse
}
if !policyutil.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies do not match"), nil
return nil, fmt.Errorf("policies do not match")
}
config, err := b.Config(req.Storage)

View file

@ -3,6 +3,8 @@ package ldap
import (
"fmt"
"strings"
"github.com/go-ldap/ldap"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/logical"
@ -19,13 +21,7 @@ func Backend() *framework.Backend {
Help: backendHelp,
PathsSpecial: &logical.Paths{
Root: append([]string{
"config",
"groups/*",
"users/*",
},
mfa.MFARootPaths()...,
),
Root: mfa.MFARootPaths(),
Unauthenticated: []string{
"login/*",
@ -35,7 +31,9 @@ func Backend() *framework.Backend {
Paths: append([]*framework.Path{
pathConfig(&b),
pathGroups(&b),
pathGroupsList(&b),
pathUsers(&b),
pathUsersList(&b),
},
mfa.MFAPaths(b.Backend, pathLogin(&b))...,
),
@ -101,90 +99,50 @@ func (b *backend) Login(req *logical.Request, username string, password string)
if c == nil {
return nil, logical.ErrorResponse("invalid connection returned from LDAP dial"), nil
}
binddn := ""
if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
if err = c.Bind(cfg.BindDN, cfg.BindPassword); err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind (service) failed: %v", err)), nil
}
sresult, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username)),
})
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP search for binddn failed: %v", err)), nil
}
if len(sresult.Entries) != 1 {
return nil, logical.ErrorResponse("LDAP search for binddn 0 or not uniq"), nil
}
binddn = sresult.Entries[0].DN
} else {
if cfg.UPNDomain != "" {
binddn = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
} else {
binddn = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
}
bindDN, err := getBindDN(cfg, c, username)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil
}
if err = c.Bind(binddn, password); err != nil {
if err = c.Bind(bindDN, password); err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind failed: %v", err)), nil
}
userdn := ""
if cfg.UPNDomain != "" {
// Find the distinguished name for the user if userPrincipalName used for login
sresult, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(binddn)),
})
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP search failed: %v", err)), nil
}
for _, e := range sresult.Entries {
userdn = e.DN
}
} else {
userdn = binddn
userDN, err := getUserDN(cfg, c, bindDN)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil
}
var allgroups []string
var policies []string
resp := &logical.Response{
ldapGroups, err := getLdapGroups(cfg, c, userDN, username)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil
}
ldapResponse := &logical.Response{
Data: map[string]interface{}{},
}
if len(ldapGroups) == 0 {
errString := fmt.Sprintf(
"no LDAP groups found in userDN '%s' or groupDN '%s';only policies from locally-defined groups available",
cfg.UserDN,
cfg.GroupDN)
ldapResponse.AddWarning(errString)
}
// Fetch custom (local) groups the user has been added to
var allGroups []string
// Import the custom added groups from ldap backend
user, err := b.User(req.Storage, username)
if err == nil && user != nil {
allgroups = append(allgroups, user.Groups...)
allGroups = append(allGroups, user.Groups...)
}
// add the LDAP groups
allGroups = append(allGroups, ldapGroups...)
if cfg.GroupDN != "" {
// Enumerate all groups the user is member of. The search filter should
// work with both openldap and MS AD standard schemas.
sresult, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.GroupDN,
Scope: 2, // subtree
Filter: fmt.Sprintf("(|(memberUid=%s)(member=%s)(uniqueMember=%s))", ldap.EscapeFilter(username), ldap.EscapeFilter(userdn), ldap.EscapeFilter(userdn)),
})
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP search failed: %v", err)), nil
}
for _, e := range sresult.Entries {
dn, err := ldap.ParseDN(e.DN)
if err != nil || len(dn.RDNs) == 0 || len(dn.RDNs[0].Attributes) == 0 {
continue
}
gname := dn.RDNs[0].Attributes[0].Value
allgroups = append(allgroups, gname)
}
} else {
resp.AddWarning("no group DN configured; only policies from locally-defined groups available")
}
for _, gname := range allgroups {
group, err := b.Group(req.Storage, gname)
// Retrieve policies
var policies []string
for _, groupName := range allGroups {
group, err := b.Group(req.Storage, groupName)
if err == nil && group != nil {
policies = append(policies, group.Policies...)
}
@ -192,15 +150,140 @@ func (b *backend) Login(req *logical.Request, username string, password string)
if len(policies) == 0 {
errStr := "user is not a member of any authorized group"
if len(resp.Warnings()) > 0 {
errStr = fmt.Sprintf("%s; additionally, %s", errStr, resp.Warnings()[0])
if len(ldapResponse.Warnings()) > 0 {
errStr = fmt.Sprintf("%s; additionally, %s", errStr, ldapResponse.Warnings()[0])
}
resp.Data["error"] = errStr
return nil, resp, nil
ldapResponse.Data["error"] = errStr
return nil, ldapResponse, nil
}
return policies, resp, nil
return policies, ldapResponse, nil
}
func getBindDN(cfg *ConfigEntry, c *ldap.Conn, username string) (string, error) {
bindDN := ""
if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
if err := c.Bind(cfg.BindDN, cfg.BindPassword); err != nil {
return bindDN, fmt.Errorf("LDAP bind (service) failed: %v", err)
}
result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username)),
})
if err != nil {
return bindDN, fmt.Errorf("LDAP search for binddn failed: %v", err)
}
if len(result.Entries) != 1 {
return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique")
}
bindDN = result.Entries[0].DN
} else {
if cfg.UPNDomain != "" {
bindDN = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
} else {
bindDN = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
}
}
return bindDN, nil
}
func getUserDN(cfg *ConfigEntry, c *ldap.Conn, bindDN string) (string, error) {
userDN := ""
if cfg.UPNDomain != "" {
// Find the distinguished name for the user if userPrincipalName used for login
result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(bindDN)),
})
if err != nil {
return userDN, fmt.Errorf("LDAP search failed for detecting user: %v", err)
}
for _, e := range result.Entries {
userDN = e.DN
}
} else {
userDN = bindDN
}
return userDN, nil
}
func getLdapGroups(cfg *ConfigEntry, c *ldap.Conn, userDN string, username string) ([]string, error) {
// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
ldapMap := make(map[string]bool)
// Fetch the optional memberOf property values on the user object
// This is the most common method used in Active Directory setup to retrieve the groups
result, err := c.Search(&ldap.SearchRequest{
BaseDN: userDN,
Scope: 0, // base scope to fetch only the userDN
Filter: "(cn=*)", // bogus filter, required to fetch the CN from userDN
Attributes: []string{
"memberOf",
},
})
// this check remains in case something happens with the ldap query or connection
if err != nil {
return nil, fmt.Errorf("LDAP fetch of distinguishedName=%s failed: %v", userDN, err)
}
// if there are more than one entry, we consider the results irrelevant and ignore them
if len(result.Entries) == 1 {
for _, attr := range result.Entries[0].Attributes {
// Find the groups the user is member of from the 'memberOf' attribute extracting the CN
if attr.Name == "memberOf" {
for _, value := range attr.Values {
memberOfDN, err := ldap.ParseDN(value)
if err != nil || len(memberOfDN.RDNs) == 0 {
continue
}
for _, rdn := range memberOfDN.RDNs {
for _, rdnTypeAndValue := range rdn.Attributes {
if strings.EqualFold(rdnTypeAndValue.Type, "CN") {
ldapMap[rdnTypeAndValue.Value] = true
}
}
}
}
}
}
}
// Find groups by searching in groupDN for any of the memberUid, member or uniqueMember attributes
// and retrieving the CN in the DN result
if cfg.GroupDN != "" {
result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.GroupDN,
Scope: 2, // subtree
Filter: fmt.Sprintf("(|(memberUid=%s)(member=%s)(uniqueMember=%s))", ldap.EscapeFilter(username), ldap.EscapeFilter(userDN), ldap.EscapeFilter(userDN)),
})
if err != nil {
return nil, fmt.Errorf("LDAP search failed: %v", err)
}
for _, e := range result.Entries {
dn, err := ldap.ParseDN(e.DN)
if err != nil || len(dn.RDNs) == 0 {
continue
}
for _, rdn := range dn.RDNs {
for _, rdnTypeAndValue := range rdn.Attributes {
if strings.EqualFold(rdnTypeAndValue.Type, "CN") {
ldapMap[rdnTypeAndValue.Value] = true
}
}
}
}
}
ldapGroups := make([]string, len(ldapMap))
for key, _ := range ldapMap {
ldapGroups = append(ldapGroups, key)
}
return ldapGroups, nil
}
const backendHelp = `

View file

@ -2,6 +2,7 @@ package ldap
import (
"fmt"
"reflect"
"testing"
"time"
@ -38,6 +39,8 @@ func TestBackend_basic(t *testing.T) {
testAccStepGroup(t, "engineers", "bar"),
testAccStepUser(t, "tesla", "engineers"),
testAccStepLogin(t, "tesla", "password"),
testAccStepGroupList(t, []string{"engineers", "scientists"}),
testAccStepUserList(t, []string{"tesla"}),
},
})
}
@ -321,3 +324,39 @@ func TestLDAPEscape(t *testing.T) {
}
}
}
func testAccStepGroupList(t *testing.T, groups []string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ListOperation,
Path: "groups",
Check: func(resp *logical.Response) error {
if resp.IsError() {
return fmt.Errorf("Got error response: %#v", *resp)
}
exp := groups
if !reflect.DeepEqual(exp, resp.Data["keys"].([]string)) {
return fmt.Errorf("expected:\n%#v\ngot:\n%#v\n", exp, resp.Data["keys"])
}
return nil
},
}
}
func testAccStepUserList(t *testing.T, users []string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ListOperation,
Path: "users",
Check: func(resp *logical.Response) error {
if resp.IsError() {
return fmt.Errorf("Got error response: %#v", *resp)
}
exp := users
if !reflect.DeepEqual(exp, resp.Data["keys"].([]string)) {
return fmt.Errorf("expected:\n%#v\ngot:\n%#v\n", exp, resp.Data["keys"])
}
return nil
},
}
}

View file

@ -8,6 +8,19 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathGroupsList(b *backend) *framework.Path {
return &framework.Path{
Pattern: "groups/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathGroupList,
},
HelpSynopsis: pathGroupHelpSyn,
HelpDescription: pathGroupHelpDesc,
}
}
func pathGroups(b *backend) *framework.Path {
return &framework.Path{
Pattern: `groups/(?P<name>.+)`,
@ -94,6 +107,15 @@ func (b *backend) pathGroupWrite(
return nil, nil
}
func (b *backend) pathGroupList(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groups, err := req.Storage.List("group/")
if err != nil {
return nil, err
}
return logical.ListResponse(groups), nil
}
type GroupEntry struct {
Policies []string
}

View file

@ -1,6 +1,7 @@
package ldap
import (
"fmt"
"sort"
"strings"
@ -83,7 +84,7 @@ func (b *backend) pathLoginRenew(
}
if !policyutil.EquivalentPolicies(loginPolicies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil
return nil, fmt.Errorf("policies have changed, not renewing")
}
return framework.LeaseExtend(0, 0, b.System())(req, d)

View file

@ -7,6 +7,19 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathUsersList(b *backend) *framework.Path {
return &framework.Path{
Pattern: "users/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathUserList,
},
HelpSynopsis: pathUserHelpSyn,
HelpDescription: pathUserHelpDesc,
}
}
func pathUsers(b *backend) *framework.Path {
return &framework.Path{
Pattern: `users/(?P<name>.+)`,
@ -25,7 +38,7 @@ func pathUsers(b *backend) *framework.Path {
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.DeleteOperation: b.pathUserDelete,
logical.ReadOperation: b.pathUserRead,
logical.UpdateOperation: b.pathUserWrite,
logical.UpdateOperation: b.pathUserWrite,
},
HelpSynopsis: pathUserHelpSyn,
@ -99,6 +112,15 @@ func (b *backend) pathUserWrite(
return nil, nil
}
func (b *backend) pathUserList(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
if err != nil {
return nil, err
}
return logical.ListResponse(users), nil
}
type UserEntry struct {
Groups []string
}

View file

@ -94,7 +94,7 @@ func (b *backend) pathLoginRenew(
}
if !policyutil.EquivalentPolicies(user.Policies, req.Auth.Policies) {
return logical.ErrorResponse("policies have changed, not renewing"), nil
return nil, fmt.Errorf("policies have changed, not renewing")
}
return framework.LeaseExtend(user.TTL, user.MaxTTL, b.System())(req, d)

View file

@ -17,12 +17,6 @@ func Backend() *framework.Backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
Root: []string{
"config/*",
},
},
Paths: []*framework.Path{
pathConfigRoot(),
pathConfigLease(&b),

View file

@ -6,6 +6,7 @@ import (
"fmt"
"log"
"os"
"strings"
"testing"
"time"
@ -13,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
@ -40,15 +42,21 @@ func TestBackend_basic(t *testing.T) {
func TestBackend_basicSTS(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
AcceptanceTest: true,
PreCheck: func() { testAccPreCheck(t) },
Backend: getBackend(t),
PreCheck: func() {
testAccPreCheck(t)
createRole(t)
},
Backend: getBackend(t),
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWritePolicy(t, "test", testPolicy),
testAccStepReadSTS(t, "test"),
testAccStepWriteArnPolicyRef(t, "test", testPolicyArn),
testAccStepReadSTSWithArnPolicy(t, "test"),
testAccStepWriteArnRoleRef(t, testRoleName),
testAccStepReadSTS(t, testRoleName),
},
Teardown: teardown,
})
}
@ -84,6 +92,123 @@ func testAccPreCheck(t *testing.T) {
log.Println("[INFO] Test: Using us-west-2 as test region")
os.Setenv("AWS_DEFAULT_REGION", "us-west-2")
}
if v := os.Getenv("AWS_ACCOUNT_ID"); v == "" {
accountId, err := getAccountId()
if err != nil {
t.Fatal("AWS_ACCOUNT_ID could not be read from iam:GetUser for acceptance tests")
}
log.Printf("[INFO] Test: Used %s as AWS_ACCOUNT_ID", accountId)
os.Setenv("AWS_ACCOUNT_ID", accountId)
}
}
func getAccountId() (string, error) {
creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"),
os.Getenv("AWS_SECRET_ACCESS_KEY"),
"")
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := iam.New(session.New(awsConfig))
params := &iam.GetUserInput{}
res, err := svc.GetUser(params)
if err != nil {
return "", err
}
// split "arn:aws:iam::012345678912:user/username"
accountId := strings.Split(*res.User.Arn, ":")[4]
return accountId, nil
}
const testRoleName = "Vault-Acceptance-Test-AWS-Assume-Role"
func createRole(t *testing.T) {
const testRoleAssumePolicy = `{
"Version": "2012-10-17",
"Statement": [
{
"Effect":"Allow",
"Principal": {
"AWS": "arn:aws:iam::%s:root"
},
"Action": "sts:AssumeRole"
}
]
}
`
creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), "")
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := iam.New(session.New(awsConfig))
trustPolicy := fmt.Sprintf(testRoleAssumePolicy, os.Getenv("AWS_ACCOUNT_ID"))
params := &iam.CreateRoleInput{
AssumeRolePolicyDocument: aws.String(trustPolicy),
RoleName: aws.String(testRoleName),
Path: aws.String("/"),
}
log.Printf("[INFO] AWS CreateRole: %s", testRoleName)
_, err := svc.CreateRole(params)
if err != nil {
t.Fatal("AWS CreateRole failed: %v", err)
}
attachment := &iam.AttachRolePolicyInput{
PolicyArn: aws.String(testPolicyArn),
RoleName: aws.String(testRoleName), // Required
}
_, err = svc.AttachRolePolicy(attachment)
if err != nil {
t.Fatal("AWS CreateRole failed: %v", err)
}
// Sleep sometime because AWS is eventually consistent
log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...")
time.Sleep(10 * time.Second)
}
func teardown() error {
creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), "")
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := iam.New(session.New(awsConfig))
attachment := &iam.DetachRolePolicyInput{
PolicyArn: aws.String(testPolicyArn),
RoleName: aws.String(testRoleName), // Required
}
_, err := svc.DetachRolePolicy(attachment)
params := &iam.DeleteRoleInput{
RoleName: aws.String(testRoleName),
}
log.Printf("[INFO] AWS DeleteRole: %s", testRoleName)
_, err = svc.DeleteRole(params)
if err != nil {
log.Printf("[WARN] AWS DeleteRole failed: %v", err)
}
return err
}
func testAccStepConfig(t *testing.T) logicaltest.TestStep {
@ -178,7 +303,7 @@ func testAccStepReadSTSWithArnPolicy(t *testing.T, name string) logicaltest.Test
ErrorOk: true,
Check: func(resp *logical.Response) error {
if resp.Data["error"] !=
"Can't generate STS credentials for a managed policy; use an inline policy instead" {
"Can't generate STS credentials for a managed policy; use a role to assume or an inline policy instead" {
t.Fatalf("bad: %v", resp)
}
return nil
@ -317,3 +442,13 @@ func testAccStepReadArnPolicy(t *testing.T, name string, value string) logicalte
},
}
}
func testAccStepWriteArnRoleRef(t *testing.T, roleName string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/" + roleName,
Data: map[string]interface{}{
"arn": fmt.Sprintf("arn:aws:iam::%s:role/%s", os.Getenv("AWS_ACCOUNT_ID"), roleName),
},
}
}

View file

@ -48,15 +48,23 @@ func (b *backend) pathSTSRead(
}
policyValue := string(policy.Value)
if strings.HasPrefix(policyValue, "arn:") {
return logical.ErrorResponse(
"Can't generate STS credentials for a managed policy; use an inline policy instead"),
logical.ErrInvalidRequest
if strings.Contains(policyValue, ":role/") {
return b.assumeRole(
req.Storage,
req.DisplayName, policyName, policyValue,
ttl,
)
} else {
return logical.ErrorResponse(
"Can't generate STS credentials for a managed policy; use a role to assume or an inline policy instead"),
logical.ErrInvalidRequest
}
}
// Use the helper to create the secret
return b.secretTokenCreate(
req.Storage,
req.DisplayName, policyName, policyValue,
&ttl,
ttl,
)
}

View file

@ -71,7 +71,7 @@ func genUsername(displayName, policyName, userType string) (ret string, warning
func (b *backend) secretTokenCreate(s logical.Storage,
displayName, policyName, policy string,
lifeTimeInSeconds *int64) (*logical.Response, error) {
lifeTimeInSeconds int64) (*logical.Response, error) {
STSClient, err := clientSTS(s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
@ -83,7 +83,7 @@ func (b *backend) secretTokenCreate(s logical.Storage,
&sts.GetFederationTokenInput{
Name: aws.String(username),
Policy: aws.String(policy),
DurationSeconds: lifeTimeInSeconds,
DurationSeconds: &lifeTimeInSeconds,
})
if err != nil {
@ -111,6 +111,48 @@ func (b *backend) secretTokenCreate(s logical.Storage,
return resp, nil
}
func (b *backend) assumeRole(s logical.Storage,
displayName, policyName, policy string,
lifeTimeInSeconds int64) (*logical.Response, error) {
STSClient, err := clientSTS(s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
username, usernameWarning := genUsername(displayName, policyName, "iam_user")
tokenResp, err := STSClient.AssumeRole(
&sts.AssumeRoleInput{
RoleSessionName: aws.String(username),
RoleArn: aws.String(policy),
DurationSeconds: &lifeTimeInSeconds,
})
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error assuming role: %s", err)), nil
}
resp := b.Secret(SecretAccessKeyType).Response(map[string]interface{}{
"access_key": *tokenResp.Credentials.AccessKeyId,
"secret_key": *tokenResp.Credentials.SecretAccessKey,
"security_token": *tokenResp.Credentials.SessionToken,
}, map[string]interface{}{
"username": username,
"policy": policy,
"is_sts": true,
})
// Set the secret TTL to appropriately match the expiration of the token
resp.Secret.TTL = tokenResp.Credentials.Expiration.Sub(time.Now())
if usernameWarning != "" {
resp.AddWarning(usernameWarning)
}
return resp, nil
}
func (b *backend) secretAccessKeysCreate(
s logical.Storage,
displayName, policyName string, policy string) (*logical.Response, error) {

View file

@ -9,15 +9,9 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return Backend().Setup(conf)
}
func Backend() *framework.Backend {
func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
PathsSpecial: &logical.Paths{
Root: []string{
"config/*",
},
},
Paths: []*framework.Path{
pathConfigAccess(),
pathRoles(),
@ -29,7 +23,7 @@ func Backend() *framework.Backend {
},
}
return b.Backend
return &b
}
type backend struct {

View file

@ -8,6 +8,7 @@ import (
"log"
"os"
"os/exec"
"reflect"
"strings"
"testing"
"time"
@ -18,6 +19,55 @@ import (
"github.com/mitchellh/mapstructure"
)
func TestBackend_config_access(t *testing.T) {
if os.Getenv(logicaltest.TestEnvVar) == "" {
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
return
}
accessConfig, process := testStartConsulServer(t)
defer testStopConsulServer(t, process)
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b := Backend()
_, err := b.Setup(config)
if err != nil {
t.Fatal(err)
}
confReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/access",
Storage: storage,
Data: accessConfig,
}
resp, err := b.HandleRequest(confReq)
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
}
confReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(confReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
}
expected := map[string]interface{}{
"address": "127.0.0.1:8500",
"scheme": "http",
}
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
if resp.Data["token"] != nil {
t.Fatalf("token should not be set in the response")
}
}
func TestBackend_basic(t *testing.T) {
if os.Getenv(logicaltest.TestEnvVar) == "" {
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
@ -40,6 +90,106 @@ func TestBackend_basic(t *testing.T) {
})
}
func TestBackend_renew_revoke(t *testing.T) {
if os.Getenv(logicaltest.TestEnvVar) == "" {
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
return
}
config, process := testStartConsulServer(t)
defer testStopConsulServer(t, process)
beConfig := logical.TestBackendConfig()
beConfig.StorageView = &logical.InmemStorage{}
b, _ := Factory(beConfig)
req := &logical.Request{
Storage: beConfig.StorageView,
Operation: logical.UpdateOperation,
Path: "config/access",
Data: config,
}
resp, err := b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
req.Path = "roles/test"
req.Data = map[string]interface{}{
"policy": base64.StdEncoding.EncodeToString([]byte(testPolicy)),
"lease": "6h",
}
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
req.Operation = logical.ReadOperation
req.Path = "creds/test"
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
if resp == nil || resp.IsError() {
t.Fatal("resp nil or error")
}
generatedSecret := resp.Secret
generatedSecret.IssueTime = time.Now()
generatedSecret.TTL = 6 * time.Hour
var d struct {
Token string `mapstructure:"token"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
t.Fatal(err)
}
log.Printf("[WARN] Generated token: %s", d.Token)
// Build a client and verify that the credentials work
apiConfig := api.DefaultConfig()
apiConfig.Address = config["address"].(string)
apiConfig.Token = d.Token
client, err := api.NewClient(apiConfig)
if err != nil {
t.Fatal(err)
}
log.Printf("[WARN] Verifying that the generated token works...")
_, err = client.KV().Put(&api.KVPair{
Key: "foo",
Value: []byte("bar"),
}, nil)
if err != nil {
t.Fatal(err)
}
req.Operation = logical.RenewOperation
req.Secret = generatedSecret
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response from renew")
}
req.Operation = logical.RevokeOperation
resp, err = b.HandleRequest(req)
if err != nil {
t.Fatal(err)
}
log.Printf("[WARN] Verifying that the generated token does not work...")
_, err = client.KV().Put(&api.KVPair{
Key: "foo",
Value: []byte("bar"),
}, nil)
if err == nil {
t.Fatal("expected error")
}
}
func TestBackend_management(t *testing.T) {
if os.Getenv(logicaltest.TestEnvVar) == "" {
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))

View file

@ -7,26 +7,23 @@ import (
"github.com/hashicorp/vault/logical"
)
func client(s logical.Storage) (*api.Client, error) {
entry, err := s.Get("config/access")
if err != nil {
return nil, err
func client(s logical.Storage) (*api.Client, error, error) {
conf, userErr, intErr := readConfigAccess(s)
if intErr != nil {
return nil, nil, intErr
}
if entry == nil {
return nil, fmt.Errorf(
"root credentials haven't been configured. Please configure\n" +
"them at the '/root' endpoint")
if userErr != nil {
return nil, userErr, nil
}
if conf == nil {
return nil, nil, fmt.Errorf("no error received but no configuration found")
}
var conf accessConfig
if err := entry.DecodeJSON(&conf); err != nil {
return nil, fmt.Errorf("error reading root configuration: %s", err)
}
consulConf := api.DefaultConfig()
consulConf := api.DefaultNonPooledConfig()
consulConf.Address = conf.Address
consulConf.Scheme = conf.Scheme
consulConf.Token = conf.Token
return api.NewClient(consulConf)
client, err := api.NewClient(consulConf)
return client, nil, err
}

View file

@ -1,6 +1,8 @@
package consul
import (
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -31,11 +33,52 @@ func pathConfigAccess() *framework.Path {
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: pathConfigAccessRead,
logical.UpdateOperation: pathConfigAccessWrite,
},
}
}
func readConfigAccess(storage logical.Storage) (*accessConfig, error, error) {
entry, err := storage.Get("config/access")
if err != nil {
return nil, nil, err
}
if entry == nil {
return nil, fmt.Errorf(
"Access credentials for the backend itself haven't been configured. Please configure them at the '/config/access' endpoint"),
nil
}
conf := &accessConfig{}
if err := entry.DecodeJSON(conf); err != nil {
return nil, nil, fmt.Errorf("error reading consul access configuration: %s", err)
}
return conf, nil, nil
}
func pathConfigAccessRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
conf, userErr, intErr := readConfigAccess(req.Storage)
if intErr != nil {
return nil, intErr
}
if userErr != nil {
return logical.ErrorResponse(userErr.Error()), nil
}
if conf == nil {
return nil, fmt.Errorf("no user error reported but consul access configuration not found")
}
return &logical.Response{
Data: map[string]interface{}{
"address": conf.Address,
"scheme": conf.Scheme,
},
}, nil
}
func pathConfigAccessWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := logical.StorageEntryJSON("config/access", accessConfig{

View file

@ -47,8 +47,11 @@ func (b *backend) pathTokenRead(
}
// Get the consul client
c, err := client(req.Storage)
if err != nil {
c, userErr, intErr := client(req.Storage)
if intErr != nil {
return nil, intErr
}
if userErr != nil {
return logical.ErrorResponse(err.Error()), nil
}
@ -67,7 +70,9 @@ func (b *backend) pathTokenRead(
// Use the helper to create the secret
s := b.Secret(SecretTokenType).Response(map[string]interface{}{
"token": token,
}, nil)
}, map[string]interface{}{
"token": token,
})
s.Secret.TTL = result.Lease
return s, nil

View file

@ -32,14 +32,27 @@ func (b *backend) secretTokenRenew(
func secretTokenRevoke(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
c, err := client(req.Storage)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
c, userErr, intErr := client(req.Storage)
if intErr != nil {
return nil, intErr
}
if userErr != nil {
// Returning logical.ErrorResponse from revocation function is risky
return nil, userErr
}
_, err = c.ACL().Destroy(d.Get("token").(string), nil)
tokenRaw, ok := req.Secret.InternalData["token"]
if !ok {
// We return nil here because this is a pre-0.5.3 problem and there is
// nothing we can do about it. We already can't revoke the lease
// properly if it has been renewed and this is documented pre-0.5.3
// behavior with a security bulletin about it.
return nil, nil
}
_, err := c.ACL().Destroy(tokenRaw.(string), nil)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
return nil, err
}
return nil, nil

View file

@ -141,12 +141,10 @@ func (b *backend) secretCredsRevoke(
// can't drop if not all database users are dropped
if rows.Err() != nil {
return logical.ErrorResponse(fmt.Sprintf(
"could not generate sql statements for all rows: %v", rows.Err())), nil
return nil, fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err())
}
if lastStmtError != nil {
return logical.ErrorResponse(fmt.Sprintf(
"could not perform all sql statements: %v", lastStmtError)), nil
return nil, fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
}
// Drop this login

View file

@ -20,12 +20,6 @@ func Backend() *framework.Backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
Root: []string{
"config/*",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),

View file

@ -271,6 +271,15 @@ func (b *backend) pathRoleRead(
Data: structs.New(role).Map(),
}
if resp.Data == nil {
return nil, fmt.Errorf("error converting role data to response")
}
// These values are deprecated and the entries are migrated on read
delete(resp.Data, "lease")
delete(resp.Data, "lease_max")
delete(resp.Data, "allowed_base_domain")
return resp, nil
}

View file

@ -38,12 +38,12 @@ reference`,
func (b *backend) secretCredsRevoke(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
if req.Secret == nil {
return nil, fmt.Errorf("Secret is nil in request")
return nil, fmt.Errorf("secret is nil in request")
}
serialInt, ok := req.Secret.InternalData["serial_number"]
if !ok {
return nil, fmt.Errorf("Could not find serial in internal secret data")
return nil, fmt.Errorf("could not find serial in internal secret data")
}
serial := strings.Replace(strings.ToLower(serialInt.(string)), "-", ":", -1)

View file

@ -19,12 +19,6 @@ func Backend() *framework.Backend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
Root: []string{
"config/*",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),

View file

@ -171,10 +171,10 @@ func (b *backend) secretCredsRevoke(
// can't drop if not all privileges are revoked
if rows.Err() != nil {
return logical.ErrorResponse(fmt.Sprintf("could not generate revocation statements for all rows: %v", rows.Err())), nil
return nil, fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
}
if lastStmtError != nil {
return logical.ErrorResponse(fmt.Sprintf("could not perform all revocation statements: %v", lastStmtError)), nil
return nil, fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
}
// Drop this user

View file

@ -0,0 +1,10 @@
# RabbitMQ Backend
## Testing
There are unit and integration RabbitMQ backend tests. Unit tests can be run by `go test`. Integration tests require setting the following environment variables:
```
RABBITMQ_CONNECTION_URI=
RABBITMQ_USERNAME=
RABBITMQ_PASSWORD=
```

View file

@ -0,0 +1,125 @@
package rabbitmq
import (
"fmt"
"strings"
"sync"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/michaelklishin/rabbit-hole"
)
// Factory creates and configures the backend
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return Backend().Setup(conf)
}
// Creates a new backend with all the paths and secrets belonging to it
func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathCreds(&b),
pathRoles(&b),
},
Secrets: []*framework.Secret{
secretCreds(&b),
},
Clean: b.resetClient,
}
return &b
}
type backend struct {
*framework.Backend
client *rabbithole.Client
lock sync.RWMutex
}
// DB returns the database connection.
func (b *backend) Client(s logical.Storage) (*rabbithole.Client, error) {
b.lock.RLock()
// If we already have a client, return it
if b.client != nil {
b.lock.RUnlock()
return b.client, nil
}
b.lock.RUnlock()
// Otherwise, attempt to make connection
entry, err := s.Get("config/connection")
if err != nil {
return nil, err
}
if entry == nil {
return nil, fmt.Errorf("configure the client connection with config/connection first")
}
var connConfig connectionConfig
if err := entry.DecodeJSON(&connConfig); err != nil {
return nil, err
}
b.lock.Lock()
defer b.lock.Unlock()
// If the client was creted during the lock switch, return it
if b.client != nil {
return b.client, nil
}
b.client, err = rabbithole.NewClient(connConfig.URI, connConfig.Username, connConfig.Password)
if err != nil {
return nil, err
}
// Use a default pooled transport so there would be no leaked file descriptors
b.client.SetTransport(cleanhttp.DefaultPooledTransport())
return b.client, nil
}
// resetClient forces a connection next time Client() is called.
func (b *backend) resetClient() {
b.lock.Lock()
defer b.lock.Unlock()
b.client = nil
}
// Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
const backendHelp = `
The RabbitMQ backend dynamically generates RabbitMQ users.
After mounting this backend, configure it using the endpoints within
the "config/" path.
`

View file

@ -0,0 +1,218 @@
package rabbitmq
import (
"encoding/json"
"fmt"
"log"
"os"
"testing"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/michaelklishin/rabbit-hole"
"github.com/mitchellh/mapstructure"
)
func TestBackend_basic(t *testing.T) {
if os.Getenv(logicaltest.TestEnvVar) == "" {
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
return
}
b, _ := Factory(logical.TestBackendConfig())
logicaltest.Test(t, logicaltest.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepRole(t),
testAccStepReadCreds(t, b, "web"),
},
})
}
func TestBackend_roleCrud(t *testing.T) {
if os.Getenv(logicaltest.TestEnvVar) == "" {
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
return
}
b, _ := Factory(logical.TestBackendConfig())
logicaltest.Test(t, logicaltest.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepRole(t),
testAccStepReadRole(t, "web", "administrator", `{"/": {"configure": ".*", "write": ".*", "read": ".*"}}`),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web", "", ""),
},
})
}
const (
uriEnv = "RABBITMQ_CONNECTION_URI"
usernameEnv = "RABBITMQ_USERNAME"
passwordEnv = "RABBITMQ_PASSWORD"
)
func mustSet(name string) string {
return fmt.Sprintf("%s must be set for acceptance tests", name)
}
func testAccPreCheck(t *testing.T) {
if uri := os.Getenv(uriEnv); uri == "" {
t.Fatal(mustSet(uriEnv))
}
if username := os.Getenv(usernameEnv); username == "" {
t.Fatal(mustSet(usernameEnv))
}
if password := os.Getenv(passwordEnv); password == "" {
t.Fatal(mustSet(passwordEnv))
}
}
func testAccStepConfig(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: map[string]interface{}{
"connection_uri": os.Getenv(uriEnv),
"username": os.Getenv(usernameEnv),
"password": os.Getenv(passwordEnv),
},
}
}
func testAccStepRole(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/web",
Data: map[string]interface{}{
"tags": "administrator",
"vhosts": `{"/": {"configure": ".*", "write": ".*", "read": ".*"}}`,
},
}
}
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "roles/" + n,
}
}
func testAccStepReadCreds(t *testing.T, b logical.Backend, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "creds/" + name,
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)
uri := os.Getenv(uriEnv)
client, err := rabbithole.NewClient(uri, d.Username, d.Password)
if err != nil {
t.Fatal(err)
}
_, err = client.ListVhosts()
if err != nil {
t.Fatalf("unable to list vhosts with generated credentials: %s", err)
}
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.RevokeOperation,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("Error on resp: %#v", *resp)
}
}
client, err = rabbithole.NewClient(uri, d.Username, d.Password)
if err != nil {
t.Fatal(err)
}
_, err = client.ListVhosts()
if err == nil {
t.Fatalf("expected to fail listing vhosts: %s", err)
}
return nil
},
}
}
func testAccStepReadRole(t *testing.T, name, tags, rawVHosts string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if tags == "" && rawVHosts == "" {
return nil
}
return fmt.Errorf("bad: %#v", resp)
}
var d struct {
Tags string `mapstructure:"tags"`
VHosts map[string]vhostPermission `mapstructure:"vhosts"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.Tags != tags {
return fmt.Errorf("bad: %#v", resp)
}
var vhosts map[string]vhostPermission
if err := json.Unmarshal([]byte(rawVHosts), &vhosts); err != nil {
return fmt.Errorf("bad expected vhosts %#v: %s", vhosts, err)
}
for host, permission := range vhosts {
actualPermission, ok := d.VHosts[host]
if !ok {
return fmt.Errorf("expected vhost: %s", host)
}
if actualPermission.Configure != permission.Configure {
return fmt.Errorf("expected permission %s to be %s, got %s", "configure", permission.Configure, actualPermission.Configure)
}
if actualPermission.Write != permission.Write {
return fmt.Errorf("expected permission %s to be %s, got %s", "write", permission.Write, actualPermission.Write)
}
if actualPermission.Read != permission.Read {
return fmt.Errorf("expected permission %s to be %s, got %s", "read", permission.Read, actualPermission.Read)
}
}
return nil
},
}
}

View file

@ -0,0 +1,118 @@
package rabbitmq
import (
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/michaelklishin/rabbit-hole"
)
func pathConfigConnection(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/connection",
Fields: map[string]*framework.FieldSchema{
"connection_uri": &framework.FieldSchema{
Type: framework.TypeString,
Description: "RabbitMQ Management URI",
},
"username": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Username of a RabbitMQ management administrator",
},
"password": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Password of the provided RabbitMQ management user",
},
"verify_connection": &framework.FieldSchema{
Type: framework.TypeBool,
Default: true,
Description: `If set, connection_uri is verified by actually connecting to the RabbitMQ management API`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionUpdate,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}
func (b *backend) pathConnectionUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
uri := data.Get("connection_uri").(string)
if uri == "" {
return logical.ErrorResponse("missing connection_uri"), nil
}
username := data.Get("username").(string)
if username == "" {
return logical.ErrorResponse("missing username"), nil
}
password := data.Get("password").(string)
if password == "" {
return logical.ErrorResponse("missing password"), nil
}
// Don't check the connection_url if verification is disabled
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Create RabbitMQ management client
client, err := rabbithole.NewClient(uri, username, password)
if err != nil {
return nil, fmt.Errorf("failed to create client: %s", err)
}
// Verify that configured credentials is capable of listing
if _, err = client.ListUsers(); err != nil {
return nil, fmt.Errorf("failed to validate the connection: %s", err)
}
}
// Store it
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
URI: uri,
Username: username,
Password: password,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
// Reset the client connection
b.resetClient()
return nil, nil
}
// connectionConfig contains the information required to make a connection to a RabbitMQ node
type connectionConfig struct {
// URI of the RabbitMQ server
URI string `json:"connection_uri"`
// Username which has 'administrator' tag attached to it
Username string `json:"username"`
// Password for the Username
Password string `json:"password"`
}
const pathConfigConnectionHelpSyn = `
Configure the connection URI, username, and password to talk to RabbitMQ management HTTP API.
`
const pathConfigConnectionHelpDesc = `
This path configures the connection properties used to connect to RabbitMQ management HTTP API.
The "connection_uri" parameter is a string that is used to connect to the API. The "username"
and "password" parameters are strings that are used as credentials to the API. The "verify_connection"
parameter is a boolean that is used to verify whether the provided connection URI, username, and password
are valid.
The URI looks like:
"http://localhost:15672"
`

View file

@ -0,0 +1,83 @@
package rabbitmq
import (
"time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathConfigLease(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/lease",
Fields: map[string]*framework.FieldSchema{
"ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Default: 0,
Description: "Duration before which the issued credentials needs renewal",
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Default: 0,
Description: `Duration after which the issued credentials should not be allowed to be renewed`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathLeaseRead,
logical.UpdateOperation: b.pathLeaseUpdate,
},
HelpSynopsis: pathConfigLeaseHelpSyn,
HelpDescription: pathConfigLeaseHelpDesc,
}
}
// Sets the lease configuration parameters
func (b *backend) pathLeaseUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
TTL: time.Second * time.Duration(d.Get("ttl").(int)),
MaxTTL: time.Second * time.Duration(d.Get("max_ttl").(int)),
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
// Returns the lease configuration parameters
func (b *backend) pathLeaseRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
return nil, nil
}
lease.TTL = lease.TTL / time.Second
lease.MaxTTL = lease.MaxTTL / time.Second
return &logical.Response{
Data: structs.New(lease).Map(),
}, nil
}
// Lease configuration information for the secrets issued by this backend
type configLease struct {
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
MaxTTL time.Duration `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"`
}
var pathConfigLeaseHelpSyn = "Configure the lease parameters for generated credentials"
var pathConfigLeaseHelpDesc = `
Sets the ttl and max_ttl values for the secrets to be issued by this backend.
Both ttl and max_ttl takes in an integer number of seconds as input as well as
inputs like "1h".
`

View file

@ -0,0 +1,53 @@
package rabbitmq
import (
"testing"
"time"
"github.com/hashicorp/vault/logical"
)
func TestBackend_config_lease_RU(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b := Backend()
if _, err = b.Setup(config); err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"ttl": "10h",
"max_ttl": "20h",
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/lease",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr:%s", resp, err)
}
if resp != nil {
t.Fatal("expected a nil response")
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: resp: %#v\nerr:%s", resp, err)
}
if resp == nil {
t.Fatal("expected a response")
}
if resp.Data["ttl"].(time.Duration) != 36000 {
t.Fatalf("bad: ttl: expected:36000 actual:%d", resp.Data["ttl"].(time.Duration))
}
if resp.Data["max_ttl"].(time.Duration) != 72000 {
t.Fatalf("bad: ttl: expected:72000 actual:%d", resp.Data["ttl"].(time.Duration))
}
}

View file

@ -0,0 +1,120 @@
package rabbitmq
import (
"fmt"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/michaelklishin/rabbit-hole"
)
func pathCreds(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathCredsRead,
},
HelpSynopsis: pathRoleCreateReadHelpSyn,
HelpDescription: pathRoleCreateReadHelpDesc,
}
}
// Issues the credential based on the role name
func (b *backend) pathCredsRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if name == "" {
return logical.ErrorResponse("missing name"), nil
}
// Get the role
role, err := b.Role(req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
}
// Ensure username is unique
uuidVal, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
username := fmt.Sprintf("%s-%s", req.DisplayName, uuidVal)
password, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// Get the client configuration
client, err := b.Client(req.Storage)
if err != nil {
return nil, err
}
if client == nil {
return logical.ErrorResponse("failed to get the client"), nil
}
// Register the generated credentials in the backend, with the RabbitMQ server
if _, err = client.PutUser(username, rabbithole.UserSettings{
Password: password,
Tags: role.Tags,
}); err != nil {
return nil, fmt.Errorf("failed to create a new user with the generated credentials")
}
// If the role had vhost permissions specified, assign those permissions
// to the created username for respective vhosts.
for vhost, permission := range role.VHosts {
if _, err := client.UpdatePermissionsIn(vhost, username, rabbithole.Permissions{
Configure: permission.Configure,
Write: permission.Write,
Read: permission.Read,
}); err != nil {
// Delete the user because it's in an unknown state
if _, rmErr := client.DeleteUser(username); rmErr != nil {
return nil, fmt.Errorf("failed to delete user:%s, err: %s. %s", username, err, rmErr)
}
return nil, fmt.Errorf("failed to update permissions to the %s user. err:%s", username, err)
}
}
// Return the secret
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username,
"password": password,
}, map[string]interface{}{
"username": username,
})
// Determine if we have a lease
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
if lease != nil {
resp.Secret.TTL = lease.TTL
}
return resp, nil
}
const pathRoleCreateReadHelpSyn = `
Request RabbitMQ credentials for a certain role.
`
const pathRoleCreateReadHelpDesc = `
This path reads RabbitMQ credentials for a certain role. The
RabbitMQ credentials will be generated on demand and will be automatically
revoked when the lease is up.
`

View file

@ -0,0 +1,182 @@
package rabbitmq
import (
"encoding/json"
"fmt"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathListRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role.",
},
"tags": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Comma-separated list of tags for this role.",
},
"vhosts": &framework.FieldSchema{
Type: framework.TypeString,
Description: "A map of virtual hosts to permissions.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleRead,
logical.UpdateOperation: b.pathRoleUpdate,
logical.DeleteOperation: b.pathRoleDelete,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
// Reads the role configuration from the storage
func (b *backend) Role(s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get("role/" + n)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
// Deletes an existing role
func (b *backend) pathRoleDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if name == "" {
return logical.ErrorResponse("missing name"), nil
}
return nil, req.Storage.Delete("role/" + name)
}
// Reads an existing role
func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if name == "" {
return logical.ErrorResponse("missing name"), nil
}
role, err := b.Role(req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}
return &logical.Response{
Data: structs.New(role).Map(),
}, nil
}
// Lists all the roles registered with the backend
func (b *backend) pathRoleList(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
roles, err := req.Storage.List("role/")
if err != nil {
return nil, err
}
return logical.ListResponse(roles), nil
}
// Registers a new role with the backend
func (b *backend) pathRoleUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
if name == "" {
return logical.ErrorResponse("missing name"), nil
}
tags := d.Get("tags").(string)
rawVHosts := d.Get("vhosts").(string)
if tags == "" && rawVHosts == "" {
return logical.ErrorResponse("both tags and vhosts not specified"), nil
}
var vhosts map[string]vhostPermission
if len(rawVHosts) > 0 {
err := json.Unmarshal([]byte(rawVHosts), &vhosts)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to unmarshal vhosts: %s", err)), nil
}
}
// Store it
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
Tags: tags,
VHosts: vhosts,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
// Role that defines the capabilities of the credentials issued against it
type roleEntry struct {
Tags string `json:"tags" structs:"tags" mapstructure:"tags"`
VHosts map[string]vhostPermission `json:"vhosts" structs:"vhosts" mapstructure:"vhosts"`
}
// Structure representing the permissions of a vhost
type vhostPermission struct {
Configure string `json:"configure" structs:"configure" mapstructure:"configure"`
Write string `json:"write" structs:"write" mapstructure:"write"`
Read string `json:"read" structs:"read" mapstructure:"read"`
}
const pathRoleHelpSyn = `
Manage the roles that can be created with this backend.
`
const pathRoleHelpDesc = `
This path lets you manage the roles that can be created with this backend.
The "tags" parameter customizes the tags used to create the role.
This is a comma separated list of strings. The "vhosts" parameter customizes
the virtual hosts that this user will be associated with. This is a JSON object
passed as a string in the form:
{
"vhostOne": {
"configure": ".*",
"write": ".*",
"read": ".*"
},
"vhostTwo": {
"configure": ".*",
"write": ".*",
"read": ".*"
}
}
`

View file

@ -0,0 +1 @@
package rabbitmq

View file

@ -0,0 +1,67 @@
package rabbitmq
import (
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
// SecretCredsType is the key for this backend's secrets.
const SecretCredsType = "creds"
func secretCreds(b *backend) *framework.Secret {
return &framework.Secret{
Type: SecretCredsType,
Fields: map[string]*framework.FieldSchema{
"username": &framework.FieldSchema{
Type: framework.TypeString,
Description: "RabbitMQ username",
},
"password": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Password for the RabbitMQ username",
},
},
Renew: b.secretCredsRenew,
Revoke: b.secretCredsRevoke,
}
}
// Renew the previously issued secret
func (b *backend) secretCredsRenew(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the lease information
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
lease = &configLease{}
}
return framework.LeaseExtend(lease.TTL, lease.MaxTTL, b.System())(req, d)
}
// Revoke the previously issued secret
func (b *backend) secretCredsRevoke(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username := usernameRaw.(string)
// Get our connection
client, err := b.Client(req.Storage)
if err != nil {
return nil, err
}
if _, err = client.DeleteUser(username); err != nil {
return nil, fmt.Errorf("could not delete user: %s", err)
}
return nil, nil
}

View file

@ -21,7 +21,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return b.Setup(conf)
}
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
func Backend(conf *logical.BackendConfig) (*backend, error) {
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
HashFunc: salt.SHA256Hash,
})
@ -35,10 +35,6 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
Root: []string{
"config/*",
"keys/*",
},
Unauthenticated: []string{
"verify",
},
@ -59,7 +55,7 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
secretOTP(&b),
},
}
return b.Backend, nil
return &b, nil
}
const backendHelp = `

View file

@ -61,6 +61,120 @@ oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
`
)
func TestBackend_allowed_users(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
_, err = b.Setup(config)
if err != nil {
t.Fatal(err)
}
roleData := map[string]interface{}{
"key_type": "otp",
"default_user": "ubuntu",
"cidr_list": "52.207.235.245/16",
"allowed_users": "test",
}
roleReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/role1",
Storage: config.StorageView,
Data: roleData,
}
resp, err := b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
}
credsData := map[string]interface{}{
"ip": "52.207.235.245",
"username": "ubuntu",
}
credsReq := &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "creds/role1",
Data: credsData,
}
resp, err = b.HandleRequest(credsReq)
if err != nil || (resp != nil && resp.IsError()) || resp == nil {
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
}
if resp.Data["key"] == "" ||
resp.Data["key_type"] != "otp" ||
resp.Data["ip"] != "52.207.235.245" ||
resp.Data["username"] != "ubuntu" {
t.Fatalf("failed to create credential: resp:%#v", resp)
}
credsData["username"] = "test"
resp, err = b.HandleRequest(credsReq)
if err != nil || (resp != nil && resp.IsError()) || resp == nil {
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
}
if resp.Data["key"] == "" ||
resp.Data["key_type"] != "otp" ||
resp.Data["ip"] != "52.207.235.245" ||
resp.Data["username"] != "test" {
t.Fatalf("failed to create credential: resp:%#v", resp)
}
credsData["username"] = "random"
resp, err = b.HandleRequest(credsReq)
if err != nil || resp == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("expected failure: resp:%#v err:%s", resp, err)
}
delete(roleData, "allowed_users")
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
}
credsData["username"] = "ubuntu"
resp, err = b.HandleRequest(credsReq)
if err != nil || (resp != nil && resp.IsError()) || resp == nil {
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
}
if resp.Data["key"] == "" ||
resp.Data["key_type"] != "otp" ||
resp.Data["ip"] != "52.207.235.245" ||
resp.Data["username"] != "ubuntu" {
t.Fatalf("failed to create credential: resp:%#v", resp)
}
credsData["username"] = "test"
resp, err = b.HandleRequest(credsReq)
if err != nil || resp == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("expected failure: resp:%#v err:%s", resp, err)
}
roleData["allowed_users"] = "*"
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
}
resp, err = b.HandleRequest(credsReq)
if err != nil || (resp != nil && resp.IsError()) || resp == nil {
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
}
if resp.Data["key"] == "" ||
resp.Data["key_type"] != "otp" ||
resp.Data["ip"] != "52.207.235.245" ||
resp.Data["username"] != "test" {
t.Fatalf("failed to create credential: resp:%#v", resp)
}
}
func testingFactory(conf *logical.BackendConfig) (logical.Backend, error) {
_, err := vault.StartSSHHostTestServer()
if err != nil {

View file

@ -79,8 +79,10 @@ func (b *backend) pathCredsCreateWrite(
// is the default username in the role. If neither is true, then
// that username is not allowed to generate a credential.
if err != nil && username != role.DefaultUser {
return logical.ErrorResponse("Username is not present is allowed users list."), nil
return logical.ErrorResponse("Username is not present is allowed users list"), nil
}
} else if username != role.DefaultUser {
return logical.ErrorResponse("Username has to be either in allowed users list or has to be a default username"), nil
}
// Validate the IP address
@ -285,12 +287,22 @@ func validateIP(ip, roleName, cidrList, excludeCidrList string, zeroAddressRoles
// Checks if the username supplied by the user is present in the list of
// allowed users registered which creation of role.
func validateUsername(username, allowedUsers string) error {
if allowedUsers == "" {
return fmt.Errorf("username not in allowed users list")
}
// Role was explicitly configured to allow any username.
if allowedUsers == "*" {
return nil
}
userList := strings.Split(allowedUsers, ",")
for _, user := range userList {
if user == username {
if strings.TrimSpace(user) == username {
return nil
}
}
return fmt.Errorf("username not in allowed users list")
}

View file

@ -575,7 +575,7 @@ func TestPolicyFuzzing(t *testing.T) {
}
func testPolicyFuzzingCommon(t *testing.T, be *backend) {
storage := &logical.LockingInmemStorage{}
storage := &logical.InmemStorage{}
wg := sync.WaitGroup{}
funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}

View file

@ -9,7 +9,7 @@ import (
"github.com/hashicorp/vault/version"
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
credAws "github.com/hashicorp/vault/builtin/credential/aws"
credAwsEc2 "github.com/hashicorp/vault/builtin/credential/aws-ec2"
credCert "github.com/hashicorp/vault/builtin/credential/cert"
credGitHub "github.com/hashicorp/vault/builtin/credential/github"
credLdap "github.com/hashicorp/vault/builtin/credential/ldap"
@ -22,6 +22,7 @@ import (
"github.com/hashicorp/vault/builtin/logical/mysql"
"github.com/hashicorp/vault/builtin/logical/pki"
"github.com/hashicorp/vault/builtin/logical/postgresql"
"github.com/hashicorp/vault/builtin/logical/rabbitmq"
"github.com/hashicorp/vault/builtin/logical/ssh"
"github.com/hashicorp/vault/builtin/logical/transit"
@ -64,7 +65,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
},
CredentialBackends: map[string]logical.Factory{
"cert": credCert.Factory,
"aws": credAws.Factory,
"aws-ec2": credAwsEc2.Factory,
"app-id": credAppId.Factory,
"github": credGitHub.Factory,
"userpass": credUserpass.Factory,
@ -80,6 +81,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
"mssql": mssql.Factory,
"mysql": mysql.Factory,
"ssh": ssh.Factory,
"rabbitmq": rabbitmq.Factory,
},
ShutdownCh: command.MakeShutdownCh(),
SighupCh: command.MakeSighupCh(),
@ -171,6 +173,12 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
}, nil
},
"unwrap": func() (cli.Command, error) {
return &command.UnwrapCommand{
Meta: *metaPtr,
}, nil
},
"list": func() (cli.Command, error) {
return &command.ListCommand{
Meta: *metaPtr,

View file

@ -21,6 +21,7 @@ func HelpFunc(commands map[string]cli.CommandFactory) string {
"write": struct{}{},
"server": struct{}{},
"status": struct{}{},
"unwrap": struct{}{},
}
// Determine the maximum key length, and classify based on type

View file

@ -129,12 +129,17 @@ func (t TableFormatter) OutputSecret(ui cli.Ui, secret, s *api.Secret) error {
input = append(input, fmt.Sprintf("Key %s Value", config.Delim))
input = append(input, fmt.Sprintf("--- %s -----", config.Delim))
if s.LeaseDuration > 0 {
if s.LeaseID != "" {
input = append(input, fmt.Sprintf("lease_id %s %s", config.Delim, s.LeaseID))
input = append(input, fmt.Sprintf(
"lease_duration %s %d", config.Delim, s.LeaseDuration))
} else {
input = append(input, fmt.Sprintf(
"refresh_interval %s %d", config.Delim, s.LeaseDuration))
}
input = append(input, fmt.Sprintf(
"lease_duration %s %d", config.Delim, s.LeaseDuration))
if s.LeaseID != "" {
input = append(input, fmt.Sprintf(
"lease_renewable %s %s", config.Delim, strconv.FormatBool(s.Renewable)))
@ -152,6 +157,12 @@ func (t TableFormatter) OutputSecret(ui cli.Ui, secret, s *api.Secret) error {
}
}
if s.WrapInfo != nil {
input = append(input, fmt.Sprintf("wrapping_token: %s %s", config.Delim, s.WrapInfo.Token))
input = append(input, fmt.Sprintf("wrapping_token_ttl: %s %d", config.Delim, s.WrapInfo.TTL))
input = append(input, fmt.Sprintf("wrapping_token_creation_time: %s %s", config.Delim, s.WrapInfo.CreationTime.String()))
}
keys := make([]string, 0, len(s.Data))
for k := range s.Data {
keys = append(keys, k)

View file

@ -3,8 +3,6 @@ package command
import (
"flag"
"fmt"
"os"
"reflect"
"strings"
"github.com/hashicorp/vault/api"
@ -63,23 +61,7 @@ func (c *ReadCommand) Run(args []string) int {
// Handle single field output
if field != "" {
if val, ok := secret.Data[field]; ok {
// c.Ui.Output() prints a CR character which in this case is
// not desired. Since Vault CLI currently only uses BasicUi,
// which writes to standard output, os.Stdout is used here to
// directly print the message. If mitchellh/cli exposes method
// to print without CR, this check needs to be removed.
if reflect.TypeOf(c.Ui).String() == "*cli.BasicUi" {
fmt.Fprintf(os.Stdout, fmt.Sprintf("%v", val))
} else {
c.Ui.Output(fmt.Sprintf("%v", val))
}
return 0
} else {
c.Ui.Error(fmt.Sprintf(
"Field %s not present in secret", field))
return 1
}
return PrintRawField(c.Ui, secret, field)
}
return OutputSecret(c.Ui, format, secret)

View file

@ -44,6 +44,8 @@ type ServerCommand struct {
meta.Meta
logger *log.Logger
ReloadFuncs map[string][]server.ReloadFunc
}
@ -63,11 +65,11 @@ func (c *ServerCommand) Run(args []string) int {
return 1
}
if os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") != "" {
if os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") != "" && devRootTokenID == "" {
devRootTokenID = os.Getenv("VAULT_DEV_ROOT_TOKEN_ID")
}
if os.Getenv("VAULT_DEV_LISTEN_ADDRESS") != "" {
if os.Getenv("VAULT_DEV_LISTEN_ADDRESS") != "" && devListenAddress == "" {
devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS")
}
@ -136,7 +138,7 @@ func (c *ServerCommand) Run(args []string) int {
// Create a logger. We wrap it in a gated writer so that it doesn't
// start logging too early.
logGate := &gatedwriter.Writer{Writer: os.Stderr}
logger := log.New(&logutils.LevelFilter{
c.logger = log.New(&logutils.LevelFilter{
Levels: []logutils.LogLevel{
"TRACE", "DEBUG", "INFO", "WARN", "ERR"},
MinLevel: logutils.LogLevel(strings.ToUpper(logLevel)),
@ -150,7 +152,7 @@ func (c *ServerCommand) Run(args []string) int {
// Initialize the backend
backend, err := physical.NewBackend(
config.Backend.Type, logger, config.Backend.Config)
config.Backend.Type, c.logger, config.Backend.Config)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing backend of type %s: %s",
@ -179,7 +181,7 @@ func (c *ServerCommand) Run(args []string) int {
AuditBackends: c.AuditBackends,
CredentialBackends: c.CredentialBackends,
LogicalBackends: c.LogicalBackends,
Logger: logger,
Logger: c.logger,
DisableCache: config.DisableCache,
DisableMlock: config.DisableMlock,
MaxLeaseTTL: config.MaxLeaseTTL,
@ -190,7 +192,7 @@ func (c *ServerCommand) Run(args []string) int {
var ok bool
if config.HABackend != nil {
habackend, err := physical.NewBackend(
config.HABackend.Type, logger, config.HABackend.Config)
config.HABackend.Type, c.logger, config.HABackend.Config)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing backend of type %s: %s",
@ -299,14 +301,14 @@ func (c *ServerCommand) Run(args []string) int {
sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery)
if ok {
activeFunc := func() bool {
if isLeader, _, err := core.Leader(); err != nil {
if isLeader, _, err := core.Leader(); err == nil {
return isLeader
}
return false
}
sealedFunc := func() bool {
if sealed, err := core.Sealed(); err != nil {
if sealed, err := core.Sealed(); err == nil {
return sealed
}
return true
@ -322,7 +324,7 @@ func (c *ServerCommand) Run(args []string) int {
// Initialize the listeners
lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners {
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config)
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, logGate)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing listener of type %s: %s",
@ -351,6 +353,13 @@ func (c *ServerCommand) Run(args []string) int {
}
}
// Make sure we close all listeners from this point on
defer func() {
for _, ln := range lns {
ln.Close()
}
}()
infoKeys = append(infoKeys, "version")
info["version"] = version.GetVersion().String()
@ -368,9 +377,6 @@ func (c *ServerCommand) Run(args []string) int {
c.Ui.Output("")
if verifyOnly {
for _, listener := range lns {
listener.Close()
}
return 0
}
@ -410,10 +416,6 @@ func (c *ServerCommand) Run(args []string) int {
}
}
for _, listener := range lns {
listener.Close()
}
return 0
}

View file

@ -200,6 +200,7 @@ func ParseConfig(d string) (*Config, error) {
}
valid := []string{
"atlas",
"backend",
"ha_backend",
"listener",
@ -414,6 +415,8 @@ func parseHABackends(result *Config, list *ast.ObjectList) error {
}
func parseListeners(result *Config, list *ast.ObjectList) error {
var foundAtlas bool
listeners := make([]*Listener, 0, len(list.Items))
for _, item := range list.Items {
key := "listener"
@ -423,10 +426,14 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
valid := []string{
"address",
"endpoint",
"infrastructure",
"node_id",
"tls_disable",
"tls_cert_file",
"tls_key_file",
"tls_min_version",
"token",
}
if err := checkHCLKeys(item.Val, valid); err != nil {
return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key))
@ -437,8 +444,27 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key))
}
lnType := strings.ToLower(key)
if lnType == "atlas" {
if foundAtlas {
return multierror.Prefix(fmt.Errorf("only one listener of type 'atlas' is permitted"), fmt.Sprintf("listeners.%s", key))
} else {
foundAtlas = true
if m["token"] == "" {
return multierror.Prefix(fmt.Errorf("'token' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key))
}
if m["infrastructure"] == "" {
return multierror.Prefix(fmt.Errorf("'infrastructure' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key))
}
if m["node_id"] == "" {
return multierror.Prefix(fmt.Errorf("'node_id' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key))
}
}
}
listeners = append(listeners, &Listener{
Type: strings.ToLower(key),
Type: lnType,
Config: m,
})
}

View file

@ -15,6 +15,15 @@ func TestLoadConfigFile(t *testing.T) {
expected := &Config{
Listeners: []*Listener{
&Listener{
Type: "atlas",
Config: map[string]string{
"token": "foobar",
"infrastructure": "foo/bar",
"endpoint": "https://foo.bar:1111",
"node_id": "foo_node",
},
},
&Listener{
Type: "tcp",
Config: map[string]string{
@ -72,6 +81,15 @@ func TestLoadConfigFile_json(t *testing.T) {
"address": "127.0.0.1:443",
},
},
&Listener{
Type: "atlas",
Config: map[string]string{
"token": "foobar",
"infrastructure": "foo/bar",
"endpoint": "https://foo.bar:1111",
"node_id": "foo_node",
},
},
},
Backend: &Backend{

View file

@ -6,17 +6,19 @@ import (
_ "crypto/sha512"
"crypto/tls"
"fmt"
"io"
"net"
"strconv"
"sync"
)
// ListenerFactory is the factory function to create a listener.
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFunc, error)
type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, ReloadFunc, error)
// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
"tcp": tcpListenerFactory,
"tcp": tcpListenerFactory,
"atlas": atlasListenerFactory,
}
// tlsLookup maps the tls_min_version configuration to the internal value
@ -28,13 +30,13 @@ var tlsLookup = map[string]uint16{
// NewListener creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
f, ok := BuiltinListeners[t]
if !ok {
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
}
return f(config)
return f(config, logger)
}
func listenerWrapTLS(

View file

@ -0,0 +1,60 @@
package server
import (
"io"
"net"
"github.com/hashicorp/scada-client/scada"
"github.com/hashicorp/vault/version"
)
type SCADAListener struct {
ln net.Listener
scadaProvider *scada.Provider
}
func (s *SCADAListener) Accept() (net.Conn, error) {
return s.ln.Accept()
}
func (s *SCADAListener) Close() error {
s.scadaProvider.Shutdown()
return s.ln.Close()
}
func (s *SCADAListener) Addr() net.Addr {
return s.ln.Addr()
}
func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
scadaConfig := &scada.Config{
Service: "vault",
Version: version.GetVersion().String(),
ResourceType: "vault-cluster",
Meta: map[string]string{
"node_id": config["node_id"],
},
Atlas: scada.AtlasConfig{
Endpoint: config["endpoint"],
Infrastructure: config["infrastructure"],
Token: config["token"],
},
}
provider, list, err := scada.NewHTTPProvider(scadaConfig, logger)
if err != nil {
return nil, nil, nil, err
}
ln := &SCADAListener{
ln: list,
scadaProvider: provider,
}
props := map[string]string{
"addr": "Atlas/SCADA",
"infrastructure": scadaConfig.Atlas.Infrastructure,
}
return listenerWrapTLS(ln, props, config)
}

View file

@ -1,11 +1,12 @@
package server
import (
"io"
"net"
"time"
)
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
addr, ok := config["address"]
if !ok {
addr = "127.0.0.1:8200"

View file

@ -16,7 +16,7 @@ func TestTCPListener(t *testing.T) {
ln, _, _, err := tcpListenerFactory(map[string]string{
"address": "127.0.0.1:0",
"tls_disable": "1",
})
}, nil)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -52,7 +52,7 @@ func TestTCPListener_tls(t *testing.T) {
"address": "127.0.0.1:0",
"tls_cert_file": wd + "reload_foo.pem",
"tls_key_file": wd + "reload_foo.key",
})
}, nil)
if err != nil {
t.Fatalf("err: %s", err)
}

View file

@ -3,6 +3,13 @@ disable_mlock = true
statsd_addr = "bar"
statsite_addr = "foo"
listener "atlas" {
token = "foobar"
infrastructure = "foo/bar"
endpoint = "https://foo.bar:1111"
node_id = "foo_node"
}
listener "tcp" {
address = "127.0.0.1:443"
}

View file

@ -1,17 +1,24 @@
{
"listener":{
"tcp":{
"address":"127.0.0.1:443"
}
},
"backend":{
"consul":{
"foo":"bar"
}
},
"telemetry":{
"statsite_address":"baz"
},
"max_lease_ttl":"10h",
"default_lease_ttl":"10h"
"listener": [{
"tcp": {
"address": "127.0.0.1:443"
}
}, {
"atlas": {
"token": "foobar",
"infrastructure": "foo/bar",
"endpoint": "https://foo.bar:1111",
"node_id": "foo_node"
}
}],
"backend": {
"consul": {
"foo": "bar"
}
},
"telemetry": {
"statsite_address": "baz"
},
"max_lease_ttl": "10h",
"default_lease_ttl": "10h"
}

View file

@ -1,13 +1,13 @@
package command
import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"os"
"os/exec"
"os/user"
"strconv"
"strings"
"github.com/hashicorp/vault/builtin/logical/ssh"
@ -27,15 +27,17 @@ type SSHCredentialResp struct {
Key string `mapstructure:"key"`
Username string `mapstructure:"username"`
IP string `mapstructure:"ip"`
Port int `mapstructure:"port"`
Port string `mapstructure:"port"`
}
func (c *SSHCommand) Run(args []string) int {
var role, mountPoint, format string
var role, mountPoint, format, userKnownHostsFile, strictHostKeyChecking string
var noExec bool
var sshCmdArgs []string
var sshDynamicKeyFileName string
flags := c.Meta.FlagSet("ssh", meta.FlagSetDefault)
flags.StringVar(&strictHostKeyChecking, "strict-host-key-checking", "", "")
flags.StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "")
flags.StringVar(&format, "format", "table", "")
flags.StringVar(&role, "role", "", "")
flags.StringVar(&mountPoint, "mount-point", "ssh", "")
@ -45,6 +47,27 @@ func (c *SSHCommand) Run(args []string) int {
if err := flags.Parse(args); err != nil {
return 1
}
// If the flag is already set then it takes the precedence. If the flag is not
// set, try setting it from env var.
if os.Getenv("VAULT_SSH_STRICT_HOST_KEY_CHECKING") != "" && strictHostKeyChecking == "" {
strictHostKeyChecking = os.Getenv("VAULT_SSH_STRICT_HOST_KEY_CHECKING")
}
// Assign default value if both flag and env var are not set
if strictHostKeyChecking == "" {
strictHostKeyChecking = "ask"
}
// If the flag is already set then it takes the precedence. If the flag is not
// set, try setting it from env var.
if os.Getenv("VAULT_SSH_USER_KNOWN_HOSTS_FILE") != "" && userKnownHostsFile == "" {
userKnownHostsFile = os.Getenv("VAULT_SSH_USER_KNOWN_HOSTS_FILE")
}
// Assign default value if both flag and env var are not set
if userKnownHostsFile == "" {
userKnownHostsFile = "~/.ssh/known_hosts"
}
args = flags.Args()
if len(args) < 1 {
c.Ui.Error("ssh expects at least one argument")
@ -123,14 +146,16 @@ func (c *SSHCommand) Run(args []string) int {
return OutputSecret(c.Ui, format, keySecret)
}
// Port comes back as a json.Number which mapstructure doesn't like, so convert it
if keySecret.Data["port"] != nil {
keySecret.Data["port"] = keySecret.Data["port"].(json.Number).String()
}
var resp SSHCredentialResp
if err := mapstructure.Decode(keySecret.Data, &resp); err != nil {
c.Ui.Error(fmt.Sprintf("Error parsing the credential response:%s", err))
return 1
}
port := strconv.Itoa(resp.Port)
if resp.KeyType == ssh.KeyTypeDynamic {
if len(resp.Key) == 0 {
c.Ui.Error(fmt.Sprintf("Invalid key"))
@ -148,7 +173,7 @@ func (c *SSHCommand) Run(args []string) int {
// Feel free to try and remove this dependency.
sshpassPath, err := exec.LookPath("sshpass")
if err == nil {
sshCmdArgs = append(sshCmdArgs, []string{"-p", string(resp.Key), "ssh", "-p", port, username + "@" + ip.String()}...)
sshCmdArgs = append(sshCmdArgs, []string{"-p", string(resp.Key), "ssh", "-o UserKnownHostsFile=" + userKnownHostsFile, "-o StrictHostKeyChecking=" + strictHostKeyChecking, "-p", resp.Port, username + "@" + ip.String()}...)
sshCmd := exec.Command(sshpassPath, sshCmdArgs...)
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
@ -161,7 +186,7 @@ func (c *SSHCommand) Run(args []string) int {
c.Ui.Output("OTP for the session is " + resp.Key)
c.Ui.Output("[Note: Install 'sshpass' to automate typing in OTP]")
}
sshCmdArgs = append(sshCmdArgs, []string{"-p", port, username + "@" + ip.String()}...)
sshCmdArgs = append(sshCmdArgs, []string{"-o UserKnownHostsFile=" + userKnownHostsFile, "-o StrictHostKeyChecking=" + strictHostKeyChecking, "-p", resp.Port, username + "@" + ip.String()}...)
sshCmd := exec.Command("ssh", sshCmdArgs...)
sshCmd.Stdin = os.Stdin
@ -259,24 +284,38 @@ General Options:
` + meta.GeneralOptionsUsage() + `
SSH Options:
-role Role to be used to create the key.
Each IP is associated with a role. To see the associated
roles with IP, use "lookup" endpoint. If you are certain
that there is only one role associated with the IP, you can
skip mentioning the role. It will be chosen by default. If
there are no roles associated with the IP, register the
CIDR block of that IP using the "roles/" endpoint.
-role Role to be used to create the key.
Each IP is associated with a role. To see the associated
roles with IP, use "lookup" endpoint. If you are certain
that there is only one role associated with the IP, you can
skip mentioning the role. It will be chosen by default. If
there are no roles associated with the IP, register the
CIDR block of that IP using the "roles/" endpoint.
-no-exec Shows the credentials but does not establish connection.
-no-exec Shows the credentials but does not establish connection.
-mount-point Mount point of SSH backend. If the backend is mounted at
'ssh', which is the default as well, this parameter can be
skipped.
-mount-point Mount point of SSH backend. If the backend is mounted at
'ssh', which is the default as well, this parameter can be
skipped.
-format If no-exec option is enabled, then the credentials will be
printed out and SSH connection will not be established. The
format of the output can be 'json' or 'table'. JSON output
is useful when writing scripts. Default is 'table'.
-format If no-exec option is enabled, then the credentials will be
printed out and SSH connection will not be established. The
format of the output can be 'json' or 'table'. JSON output
is useful when writing scripts. Default is 'table'.
-strict-host-key-checking This option corresponds to StrictHostKeyChecking of SSH configuration.
If 'sshpass' is employed to enable automated login, then if host key
is not "known" to the client, 'vault ssh' command will fail. Set this
option to "no" to bypass the host key checking. Defaults to "ask".
Can also be specified with VAULT_SSH_STRICT_HOST_KEY_CHECKING environment
variable.
-user-known-hosts-file This option corresponds to UserKnownHostsFile of SSH configuration.
Assigns the file to use for storing the host keys. If this option is
set to "/dev/null" along with "-strict-host-key-checking=no", both
warnings and host key checking can be avoided while establishing the
connection. Defaults to "~/.ssh/known_hosts". Can also be specified
with VAULT_SSH_USER_KNOWN_HOSTS_FILE environment variable.
`
return strings.TrimSpace(helpText)
}

View file

@ -123,8 +123,8 @@ Token Options:
-lease="1h" Deprecated; use "-ttl" instead.
-ttl="1h" TTL to associate with the token. This option enables
the tokens to be renewable.
-ttl="1h" Initial TTL to associate with the token; renewals can
extend this value.
-metadata="key=value" Metadata to associate with the token. This shows
up in the audit log. This can be specified multiple

98
command/unwrap.go Normal file
View file

@ -0,0 +1,98 @@
package command
import (
"flag"
"fmt"
"strings"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/meta"
)
// UnwrapCommand is a Command that behaves like ReadCommand but specifically
// for unwrapping cubbyhole-wrapped secrets
type UnwrapCommand struct {
meta.Meta
}
func (c *UnwrapCommand) Run(args []string) int {
var format string
var field string
var err error
var secret *api.Secret
var flags *flag.FlagSet
flags = c.Meta.FlagSet("unwrap", meta.FlagSetDefault)
flags.StringVar(&format, "format", "table", "")
flags.StringVar(&field, "field", "", "")
flags.Usage = func() { c.Ui.Error(c.Help()) }
if err := flags.Parse(args); err != nil {
return 1
}
args = flags.Args()
if len(args) != 1 || len(args[0]) == 0 {
c.Ui.Error("Unwrap expects one argument: the ID of the wrapping token")
flags.Usage()
return 1
}
tokenID := args[0]
_, err = uuid.ParseUUID(tokenID)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Given token could not be parsed as a UUID: %s", err))
return 1
}
client, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing client: %s", err))
return 2
}
secret, err = client.Logical().Unwrap(tokenID)
if err != nil {
c.Ui.Error(err.Error())
return 1
}
if secret == nil {
c.Ui.Error("Secret returned was nil")
return 1
}
// Handle single field output
if field != "" {
return PrintRawField(c.Ui, secret, field)
}
return OutputSecret(c.Ui, format, secret)
}
func (c *UnwrapCommand) Synopsis() string {
return "Unwrap a wrapped secret"
}
func (c *UnwrapCommand) Help() string {
helpText := `
Usage: vault unwrap [options] <wrapping token ID>
Unwrap a wrapped secret.
Unwraps the data wrapped by the given token ID. The returned result is the
same as a 'read' operation on a non-wrapped secret.
General Options:
` + meta.GeneralOptionsUsage() + `
Read Options:
-format=table The format for output. By default it is a whitespace-
delimited table. This can also be json or yaml.
-field=field If included, the raw value of the specified field
will be output raw to stdout.
`
return strings.TrimSpace(helpText)
}

74
command/unwrap_test.go Normal file
View file

@ -0,0 +1,74 @@
package command
import (
"testing"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)
func TestUnwrap(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &UnwrapCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
"-field", "zip",
}
// Run once so the client is setup, ignore errors
c.Run(args)
// Get the client so we can write data
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
wrapLookupFunc := func(method, path string) string {
if method == "GET" && path == "secret/foo" {
return "60s"
}
return ""
}
client.SetWrappingLookupFunc(wrapLookupFunc)
data := map[string]interface{}{"zip": "zap"}
if _, err := client.Logical().Write("secret/foo", data); err != nil {
t.Fatalf("err: %s", err)
}
outer, err := client.Logical().Read("secret/foo")
if err != nil {
t.Fatalf("err: %s", err)
}
if outer == nil {
t.Fatal("outer response was nil")
}
if outer.WrapInfo == nil {
t.Fatal("outer wrapinfo was nil, response was %#v", *outer)
}
args = append(args, outer.WrapInfo.Token)
// Run the read
if code := c.Run(args); code != 0 {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
}
output := ui.OutputWriter.String()
if output != "zap\n" {
t.Fatalf("unexpectd output:\n%s", output)
}
}

View file

@ -1,6 +1,14 @@
package command
import "github.com/hashicorp/vault/command/token"
import (
"fmt"
"os"
"reflect"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/token"
"github.com/mitchellh/cli"
)
// DefaultTokenHelper returns the token helper that is configured for Vault.
func DefaultTokenHelper() (token.TokenHelper, error) {
@ -20,3 +28,43 @@ func DefaultTokenHelper() (token.TokenHelper, error) {
}
return &token.ExternalTokenHelper{BinaryPath: path}, nil
}
func PrintRawField(ui cli.Ui, secret *api.Secret, field string) int {
var val interface{}
switch field {
case "wrapping_token":
if secret.WrapInfo != nil {
val = secret.WrapInfo.Token
}
case "wrapping_token_ttl":
if secret.WrapInfo != nil {
val = secret.WrapInfo.TTL
}
case "wrapping_token_creation_time":
if secret.WrapInfo != nil {
val = secret.WrapInfo.CreationTime.String()
}
case "refresh_interval":
val = secret.LeaseDuration
default:
val = secret.Data[field]
}
if val != nil {
// c.Ui.Output() prints a CR character which in this case is
// not desired. Since Vault CLI currently only uses BasicUi,
// which writes to standard output, os.Stdout is used here to
// directly print the message. If mitchellh/cli exposes method
// to print without CR, this check needs to be removed.
if reflect.TypeOf(ui).String() == "*cli.BasicUi" {
fmt.Fprintf(os.Stdout, fmt.Sprintf("%v", val))
} else {
ui.Output(fmt.Sprintf("%v", val))
}
return 0
} else {
ui.Error(fmt.Sprintf(
"Field %s not present in secret", field))
return 1
}
}

109
command/wrapping_test.go Normal file
View file

@ -0,0 +1,109 @@
package command
import (
"os"
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)
func TestWrapping_Env(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &TokenLookupCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
}
// Run it once for client
c.Run(args)
// Create a new token for us to use
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{
Lease: "1h",
})
if err != nil {
t.Fatalf("err: %s", err)
}
prevWrapTTLEnv := os.Getenv(api.EnvVaultWrapTTL)
os.Setenv(api.EnvVaultWrapTTL, "5s")
defer func() {
os.Setenv(api.EnvVaultWrapTTL, prevWrapTTLEnv)
}()
// Now when we do a lookup-self the response should be wrapped
args = append(args, resp.Auth.ClientToken)
resp, err = client.Auth().Token().LookupSelf()
if err != nil {
t.Fatalf("err: %s", err)
}
if resp == nil {
t.Fatal("nil response")
}
if resp.WrapInfo == nil {
t.Fatal("nil wrap info")
}
if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 {
t.Fatal("did not get token or ttl wrong")
}
}
func TestWrapping_Flag(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &TokenLookupCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
"-wrap-ttl", "5s",
}
// Run it once for client
c.Run(args)
// Create a new token for us to use
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{
Lease: "1h",
})
if err != nil {
t.Fatalf("err: %s", err)
}
if resp == nil {
t.Fatal("nil response")
}
if resp.WrapInfo == nil {
t.Fatal("nil wrap info")
}
if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 {
t.Fatal("did not get token or ttl wrong")
}
}

View file

@ -7,6 +7,12 @@ import (
"github.com/hashicorp/vault/helper/strutil"
)
// ParsePolicies parses a comma-delimited list of policies.
// The resulting collection will have no duplicate elements.
// If 'root' policy was present in the list of policies, then
// all other policies will be ignored, the result will contain
// just the 'root'. In cases where 'root' is not present, if
// 'default' policy is not already present, it will be added.
func ParsePolicies(policiesRaw string) []string {
if policiesRaw == "" {
return []string{"default"}
@ -14,10 +20,18 @@ func ParsePolicies(policiesRaw string) []string {
policies := strings.Split(policiesRaw, ",")
return SanitizePolicies(policies)
return SanitizePolicies(policies, true)
}
func SanitizePolicies(policies []string) []string {
// SanitizePolicies performs the common input validation tasks
// which are performed on the list of policies across Vault.
// The resulting collection will have no duplicate elements.
// If 'root' policy was present in the list of policies, then
// all other policies will be ignored, the result will contain
// just the 'root'. In cases where 'root' is not present, if
// 'default' policy is not already present, it will be added
// if addDefault is set to true.
func SanitizePolicies(policies []string, addDefault bool) []string {
defaultFound := false
for i, p := range policies {
policies[i] = strings.ToLower(strings.TrimSpace(p))
@ -38,7 +52,7 @@ func SanitizePolicies(policies []string) []string {
}
// Always add 'default' except only if the policies contain 'root'.
if len(policies) == 0 || !defaultFound {
if addDefault && (len(policies) == 0 || !defaultFound) {
policies = append(policies, "default")
}

View file

@ -2,6 +2,21 @@ package policyutil
import "testing"
func TestSanitizePolicies(t *testing.T) {
expected := []string{"foo", "bar"}
actual := SanitizePolicies([]string{"foo", "bar"}, false)
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
// If 'default' is already added, do not remove it.
expected = []string{"foo", "bar", "default"}
actual = SanitizePolicies([]string{"foo", "bar", "default"}, false)
if !EquivalentPolicies(expected, actual) {
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
}
}
func TestParsePolicies(t *testing.T) {
expected := []string{"foo", "bar", "default"}
actual := ParsePolicies("foo,bar")

View file

@ -6,15 +6,23 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
// AuthHeaderName is the name of the header containing the token.
const AuthHeaderName = "X-Vault-Token"
const (
// AuthHeaderName is the name of the header containing the token.
AuthHeaderName = "X-Vault-Token"
// WrapHeaderName is the name of the header containing a directive to wrap the
// response.
WrapTTLHeaderName = "X-Vault-Wrap-TTL"
)
// Handler returns an http.Handler for the API. This can be used on
// its own to mount the Vault API within another web server.
@ -85,7 +93,7 @@ func parseRequest(r *http.Request, out interface{}) error {
// case of an error.
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
resp, err := core.HandleRequest(r)
if err == vault.ErrStandby {
if errwrap.Contains(err, vault.ErrStandby.Error()) {
respondStandby(core, w, rawReq.URL)
return resp, false
}
@ -153,13 +161,44 @@ func requestAuth(r *http.Request, req *logical.Request) *logical.Request {
return req
}
// requestWrapTTL adds the WrapTTL value to the logical.Request if it
// exists.
func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, error) {
// First try for the header value
wrapTTL := r.Header.Get(WrapTTLHeaderName)
if wrapTTL == "" {
return req, nil
}
// If it has an allowed suffix parse as a duration string
if strings.HasSuffix(wrapTTL, "s") || strings.HasSuffix(wrapTTL, "m") || strings.HasSuffix(wrapTTL, "h") {
dur, err := time.ParseDuration(wrapTTL)
if err != nil {
return req, err
}
req.WrapTTL = dur
} else {
// Parse as a straight number of seconds
seconds, err := strconv.ParseInt(wrapTTL, 10, 64)
if err != nil {
return req, err
}
req.WrapTTL = time.Duration(seconds) * time.Second
}
if int64(req.WrapTTL) < 0 {
return req, fmt.Errorf("requested wrap ttl cannot be negative")
}
return req, nil
}
// Determines the type of the error being returned and sets the HTTP
// status code appropriately
func respondErrorStatus(w http.ResponseWriter, err error) {
status := http.StatusInternalServerError
switch {
// Keep adding more error types here to appropriate the status codes
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
case err != nil && errwrap.ContainsType(err, new(vault.StatusBadRequest)):
status = http.StatusBadRequest
}
respondError(w, status, err)
@ -167,7 +206,7 @@ func respondErrorStatus(w http.ResponseWriter, err error) {
func respondError(w http.ResponseWriter, status int, err error) {
// Adjust status code when sealed
if err == vault.ErrSealed {
if errwrap.Contains(err, vault.ErrSealed.Error()) {
status = http.StatusServiceUnavailable
}
@ -194,19 +233,19 @@ func respondCommon(w http.ResponseWriter, resp *logical.Response, err error) boo
}
if resp.IsError() {
var statusCode int
statusCode := http.StatusBadRequest
switch err {
case logical.ErrPermissionDenied:
statusCode = http.StatusForbidden
case logical.ErrUnsupportedOperation:
statusCode = http.StatusMethodNotAllowed
case logical.ErrUnsupportedPath:
statusCode = http.StatusNotFound
case logical.ErrInvalidRequest:
statusCode = http.StatusBadRequest
default:
statusCode = http.StatusBadRequest
if err != nil {
switch err {
case logical.ErrPermissionDenied:
statusCode = http.StatusForbidden
case logical.ErrUnsupportedOperation:
statusCode = http.StatusMethodNotAllowed
case logical.ErrUnsupportedPath:
statusCode = http.StatusNotFound
case logical.ErrInvalidRequest:
statusCode = http.StatusBadRequest
}
}
err := fmt.Errorf("%s", resp.Data["error"].(string))

View file

@ -1,10 +1,12 @@
package http
import (
"bytes"
"errors"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"github.com/hashicorp/go-cleanhttp"
@ -64,6 +66,33 @@ func TestSysMounts_headerAuth(t *testing.T) {
}
}
// We use this test to verify header auth wrapping
func TestSysMounts_headerAuth_Wrapped(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := TestServer(t, core)
defer ln.Close()
req, err := http.NewRequest("GET", addr+"/v1/sys/mounts", nil)
if err != nil {
t.Fatalf("err: %s", err)
}
req.Header.Set(AuthHeaderName, token)
req.Header.Set(WrapTTLHeaderName, "60s")
client := cleanhttp.DefaultClient()
resp, err := client.Do(req)
if err != nil {
t.Fatalf("err: %s", err)
}
testResponseStatus(t, resp, 200)
buf := bytes.NewBuffer(nil)
buf.ReadFrom(resp.Body)
if strings.TrimSpace(buf.String()) != "null" {
t.Fatalf("bad: %v", buf.String())
}
}
func TestHandler_sealed(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := TestServer(t, core)

View file

@ -7,77 +7,89 @@ import (
"strconv"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
type PrepareRequestFunc func(req *logical.Request) error
func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) {
// Determine the path...
if !strings.HasPrefix(r.URL.Path, "/v1/") {
return nil, http.StatusNotFound, nil
}
path := r.URL.Path[len("/v1/"):]
if path == "" {
return nil, http.StatusNotFound, nil
}
// Determine the operation
var op logical.Operation
switch r.Method {
case "DELETE":
op = logical.DeleteOperation
case "GET":
op = logical.ReadOperation
// Need to call ParseForm to get query params loaded
queryVals := r.URL.Query()
listStr := queryVals.Get("list")
if listStr != "" {
list, err := strconv.ParseBool(listStr)
if err != nil {
return nil, http.StatusBadRequest, nil
}
if list {
op = logical.ListOperation
}
}
case "POST", "PUT":
op = logical.UpdateOperation
case "LIST":
op = logical.ListOperation
default:
return nil, http.StatusMethodNotAllowed, nil
}
// Parse the request if we can
var data map[string]interface{}
if op == logical.UpdateOperation {
err := parseRequest(r, &data)
if err == io.EOF {
data = nil
err = nil
}
if err != nil {
return nil, http.StatusBadRequest, err
}
}
var err error
req := requestAuth(r, &logical.Request{
Operation: op,
Path: path,
Data: data,
Connection: getConnection(r),
})
req, err = requestWrapTTL(r, req)
if err != nil {
return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
}
return req, 0, nil
}
func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback PrepareRequestFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Determine the path...
if !strings.HasPrefix(r.URL.Path, "/v1/") {
respondError(w, http.StatusNotFound, nil)
return
}
path := r.URL.Path[len("/v1/"):]
if path == "" {
respondError(w, http.StatusNotFound, nil)
req, statusCode, err := buildLogicalRequest(w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return
}
// Determine the operation
var op logical.Operation
switch r.Method {
case "DELETE":
op = logical.DeleteOperation
case "GET":
op = logical.ReadOperation
// Need to call ParseForm to get query params loaded
queryVals := r.URL.Query()
listStr := queryVals.Get("list")
if listStr != "" {
list, err := strconv.ParseBool(listStr)
if err != nil {
respondError(w, http.StatusBadRequest, nil)
}
if list {
op = logical.ListOperation
}
}
case "POST", "PUT":
op = logical.UpdateOperation
case "LIST":
op = logical.ListOperation
default:
respondError(w, http.StatusMethodNotAllowed, nil)
return
}
// Parse the request if we can
var data map[string]interface{}
if op == logical.UpdateOperation {
err := parseRequest(r, &data)
if err == io.EOF {
data = nil
err = nil
}
if err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
}
req := requestAuth(r, &logical.Request{
Operation: op,
Path: path,
Data: data,
Connection: getConnection(r),
})
// Certain endpoints may require changes to the request object.
// They will have a callback registered to do the needful.
// Invoking it before proceeding.
// Certain endpoints may require changes to the request object. They
// will have a callback registered to do the needed operations, so
// invoke it before proceeding.
if prepareRequestCallback != nil {
if err := prepareRequestCallback(req); err != nil {
respondError(w, http.StatusInternalServerError, err)
@ -93,7 +105,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
return
}
switch {
case op == logical.ReadOperation:
case req.Operation == logical.ReadOperation:
if resp == nil {
respondError(w, http.StatusNotFound, nil)
return
@ -101,7 +113,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
// Basically: if we have empty "keys" or no keys at all, 404. This
// provides consistency with GET.
case op == logical.ListOperation:
case req.Operation == logical.ListOperation:
if resp == nil || len(resp.Data) == 0 {
respondError(w, http.StatusNotFound, nil)
return
@ -123,7 +135,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
}
// Build the proper response
respondLogical(w, r, path, dataOnly, resp)
respondLogical(w, r, req.Path, dataOnly, resp)
})
}
@ -148,34 +160,22 @@ func respondLogical(w http.ResponseWriter, r *http.Request, path string, dataOnl
return
}
logicalResp := &LogicalResponse{
Data: resp.Data,
Warnings: resp.Warnings(),
}
if resp.Secret != nil {
logicalResp.LeaseID = resp.Secret.LeaseID
logicalResp.Renewable = resp.Secret.Renewable
logicalResp.LeaseDuration = int(resp.Secret.TTL.Seconds())
}
// If we have authentication information, then
// set up the result structure.
if resp.Auth != nil {
logicalResp.Auth = &Auth{
ClientToken: resp.Auth.ClientToken,
Accessor: resp.Auth.Accessor,
Policies: resp.Auth.Policies,
Metadata: resp.Auth.Metadata,
LeaseDuration: int(resp.Auth.TTL.Seconds()),
Renewable: resp.Auth.Renewable,
if resp.WrapInfo != nil && resp.WrapInfo.Token != "" {
httpResp = logical.HTTPResponse{
WrapInfo: &logical.HTTPWrapInfo{
Token: resp.WrapInfo.Token,
TTL: int(resp.WrapInfo.TTL.Seconds()),
CreationTime: resp.WrapInfo.CreationTime,
},
}
} else {
httpResp = logical.SanitizeResponse(resp)
}
httpResp = logicalResp
}
// Respond
respondOk(w, httpResp)
return
}
// respondRaw is used when the response is using HTTPContentType and HTTPRawBody
@ -246,21 +246,3 @@ func getConnection(r *http.Request) (connection *logical.Connection) {
}
return
}
type LogicalResponse struct {
LeaseID string `json:"lease_id"`
Renewable bool `json:"renewable"`
LeaseDuration int `json:"lease_duration"`
Data map[string]interface{} `json:"data"`
Warnings []string `json:"warnings"`
Auth *Auth `json:"auth"`
}
type Auth struct {
ClientToken string `json:"client_token"`
Accessor string `json:"accessor"`
Policies []string `json:"policies"`
Metadata map[string]string `json:"metadata"`
LeaseDuration int `json:"lease_duration"`
Renewable bool `json:"renewable"`
}

View file

@ -40,8 +40,9 @@ func TestLogical(t *testing.T) {
"data": map[string]interface{}{
"data": "bar",
},
"auth": nil,
"warnings": nilWarnings,
"auth": nil,
"wrap_info": nil,
"warnings": nilWarnings,
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -140,8 +141,9 @@ func TestLogical_StandbyRedirect(t *testing.T) {
"role": "",
"explicit_max_ttl": float64(0),
},
"warnings": nilWarnings,
"auth": nil,
"warnings": nilWarnings,
"wrap_info": nil,
"auth": nil,
}
testResponseStatus(t, resp, 200)
@ -178,6 +180,7 @@ func TestLogical_CreateToken(t *testing.T) {
"renewable": false,
"lease_duration": float64(0),
"data": nil,
"wrap_info": nil,
"auth": map[string]interface{}{
"policies": []interface{}{"root"},
"metadata": nil,

View file

@ -3,6 +3,7 @@ package http
import (
"net/http"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/vault"
)
@ -20,7 +21,7 @@ func handleSysLeader(core *vault.Core) http.Handler {
func handleSysLeaderGet(core *vault.Core, w http.ResponseWriter, r *http.Request) {
haEnabled := true
isLeader, address, err := core.Leader()
if err == vault.ErrHANotEnabled {
if errwrap.Contains(err, vault.ErrHANotEnabled.Error()) {
haEnabled = false
err = nil
}

View file

@ -17,8 +17,8 @@ func TestSysPolicies(t *testing.T) {
var actual map[string]interface{}
expected := map[string]interface{}{
"policies": []interface{}{"default", "root"},
"keys": []interface{}{"default", "root"},
"policies": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
"keys": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -62,14 +62,19 @@ func TestSysWritePolicy(t *testing.T) {
var actual map[string]interface{}
expected := map[string]interface{}{
"policies": []interface{}{"default", "foo", "root"},
"keys": []interface{}{"default", "foo", "root"},
"policies": []interface{}{"cubbyhole-response-wrapping", "default", "foo", "root"},
"keys": []interface{}{"cubbyhole-response-wrapping", "default", "foo", "root"},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: got\n%#v\nexpected\n%#v\n", actual, expected)
}
resp = testHttpPost(t, token, addr+"/v1/sys/policy/cubbyhole-response-wrapping", map[string]interface{}{
"rules": ``,
})
testResponseStatus(t, resp, 400)
}
func TestSysDeletePolicy(t *testing.T) {
@ -86,12 +91,17 @@ func TestSysDeletePolicy(t *testing.T) {
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/foo")
testResponseStatus(t, resp, 204)
// Also attempt to delete these since they should not be allowed (ignore
// responses, if they exist later that's sufficient)
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/default")
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/cubbyhole-response-wrapping")
resp = testHttpGet(t, token, addr+"/v1/sys/policy")
var actual map[string]interface{}
expected := map[string]interface{}{
"policies": []interface{}{"default", "root"},
"keys": []interface{}{"default", "root"},
"policies": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
"keys": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)

View file

@ -13,19 +13,21 @@ import (
func handleSysSeal(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "PUT":
case "POST":
req, statusCode, err := buildLogicalRequest(w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return
}
switch req.Operation {
case logical.UpdateOperation:
default:
respondError(w, http.StatusMethodNotAllowed, nil)
return
}
// Get the auth for the request so we can access the token directly
req := requestAuth(r, &logical.Request{})
// Seal with the token above
if err := core.Seal(req.ClientToken); err != nil {
if err := core.SealWithRequest(req); err != nil {
respondError(w, http.StatusInternalServerError, err)
return
}
@ -36,19 +38,21 @@ func handleSysSeal(core *vault.Core) http.Handler {
func handleSysStepDown(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "PUT":
case "POST":
req, statusCode, err := buildLogicalRequest(w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return
}
switch req.Operation {
case logical.UpdateOperation:
default:
respondError(w, http.StatusMethodNotAllowed, nil)
return
}
// Get the auth for the request so we can access the token directly
req := requestAuth(r, &logical.Request{})
// Seal with the token above
if err := core.StepDown(req.ClientToken); err != nil {
if err := core.StepDown(req); err != nil {
respondError(w, http.StatusInternalServerError, err)
return
}

View file

@ -1,58 +0,0 @@
package logical
import (
"strings"
"sync"
)
// LockingInmemStorage implements Storage and stores all data in memory.
type LockingInmemStorage struct {
sync.RWMutex
Data map[string]*StorageEntry
once sync.Once
}
func (s *LockingInmemStorage) List(prefix string) ([]string, error) {
s.once.Do(s.init)
s.RLock()
defer s.RUnlock()
var result []string
for k, _ := range s.Data {
if strings.HasPrefix(k, prefix) {
result = append(result, k)
}
}
return result, nil
}
func (s *LockingInmemStorage) Get(key string) (*StorageEntry, error) {
s.once.Do(s.init)
s.RLock()
defer s.RUnlock()
return s.Data[key], nil
}
func (s *LockingInmemStorage) Put(entry *StorageEntry) error {
s.once.Do(s.init)
s.Lock()
defer s.Unlock()
s.Data[entry.Key] = entry
return nil
}
func (s *LockingInmemStorage) Delete(k string) error {
s.once.Do(s.init)
s.Lock()
defer s.Unlock()
delete(s.Data, k)
return nil
}
func (s *LockingInmemStorage) init() {
s.Data = make(map[string]*StorageEntry)
}

View file

@ -1,13 +0,0 @@
package logical
import (
"testing"
)
// Note: This uses the normal TestStorage, but the best way to exercise this is
// to run transit's unit tests, which spawn 1000 goroutines to hammer the
// backend for 10 seconds with this as the storage.
func TestLockingInmemStorage(t *testing.T) {
TestStorage(t, new(LockingInmemStorage))
}

View file

@ -3,6 +3,7 @@ package logical
import (
"errors"
"fmt"
"time"
)
// Request is a struct that stores the parameters and context
@ -52,6 +53,10 @@ type Request struct {
// paths relative to itself. The `Path` is effectively the client
// request path with the MountPoint trimmed off.
MountPoint string
// WrapTTL contains the requested TTL of the token used to wrap the
// response in a cubbyhole.
WrapTTL time.Duration
}
// Get returns a data field and guards for nil Data

View file

@ -3,6 +3,7 @@ package logical
import (
"fmt"
"reflect"
"time"
"github.com/mitchellh/copystructure"
)
@ -26,6 +27,19 @@ const (
HTTPStatusCode = "http_status_code"
)
type WrapInfo struct {
// Setting to non-zero specifies that the response should be wrapped.
// Specifies the desired TTL of the wrapping token.
TTL time.Duration
// The token containing the wrapped response
Token string
// The creation time. This can be used with the TTL to figure out an
// expected expiration.
CreationTime time.Time
}
// Response is a struct that stores the response of a request.
// It is used to abstract the details of the higher level request protocol.
type Response struct {
@ -54,6 +68,9 @@ type Response struct {
// Vault (backend, core, etc.) to add warnings without accidentally
// replacing what exists.
warnings []string
// Information for wrapping the response in a cubbyhole
WrapInfo *WrapInfo
}
func init() {
@ -74,7 +91,7 @@ func init() {
if input.Auth != nil {
retAuth, err := copystructure.Copy(input.Auth)
if err != nil {
return nil, fmt.Errorf("error copying Secret: %v", err)
return nil, fmt.Errorf("error copying Auth: %v", err)
}
ret.Auth = retAuth.(*Auth)
}
@ -82,7 +99,7 @@ func init() {
if input.Data != nil {
retData, err := copystructure.Copy(&input.Data)
if err != nil {
return nil, fmt.Errorf("error copying Secret: %v", err)
return nil, fmt.Errorf("error copying Data: %v", err)
}
ret.Data = retData.(map[string]interface{})
}
@ -93,6 +110,14 @@ func init() {
}
}
if input.WrapInfo != nil {
retWrapInfo, err := copystructure.Copy(input.WrapInfo)
if err != nil {
return nil, fmt.Errorf("error copying WrapInfo: %v", err)
}
ret.WrapInfo = retWrapInfo.(*WrapInfo)
}
return &ret, nil
}
}
@ -115,6 +140,11 @@ func (r *Response) ClearWarnings() {
r.warnings = make([]string, 0, 1)
}
// Copies the warnings from the other response to this one
func (r *Response) CloneWarnings(other *Response) {
r.warnings = other.warnings
}
// IsError returns true if this response seems to indicate an error.
func (r *Response) IsError() bool {
return r != nil && len(r.Data) == 1 && r.Data["error"] != nil

60
logical/sanitize.go Normal file
View file

@ -0,0 +1,60 @@
package logical
import "time"
// This logic was pulled from the http package so that it can be used for
// encoding wrapped responses as well. It simply translates the logical request
// to an http response, with the values we want and omitting the values we
// don't.
func SanitizeResponse(input *Response) *HTTPResponse {
logicalResp := &HTTPResponse{
Data: input.Data,
Warnings: input.Warnings(),
}
if input.Secret != nil {
logicalResp.LeaseID = input.Secret.LeaseID
logicalResp.Renewable = input.Secret.Renewable
logicalResp.LeaseDuration = int(input.Secret.TTL.Seconds())
}
// If we have authentication information, then
// set up the result structure.
if input.Auth != nil {
logicalResp.Auth = &HTTPAuth{
ClientToken: input.Auth.ClientToken,
Accessor: input.Auth.Accessor,
Policies: input.Auth.Policies,
Metadata: input.Auth.Metadata,
LeaseDuration: int(input.Auth.TTL.Seconds()),
Renewable: input.Auth.Renewable,
}
}
return logicalResp
}
type HTTPResponse struct {
LeaseID string `json:"lease_id"`
Renewable bool `json:"renewable"`
LeaseDuration int `json:"lease_duration"`
Data map[string]interface{} `json:"data"`
WrapInfo *HTTPWrapInfo `json:"wrap_info"`
Warnings []string `json:"warnings"`
Auth *HTTPAuth `json:"auth"`
}
type HTTPAuth struct {
ClientToken string `json:"client_token"`
Accessor string `json:"accessor"`
Policies []string `json:"policies"`
Metadata map[string]string `json:"metadata"`
LeaseDuration int `json:"lease_duration"`
Renewable bool `json:"renewable"`
}
type HTTPWrapInfo struct {
Token string `json:"token"`
TTL int `json:"ttl"`
CreationTime time.Time `json:"creation_time"`
}

View file

@ -1,13 +1,14 @@
package logical
import (
"strings"
"sync"
"github.com/hashicorp/vault/physical"
)
// InmemStorage implements Storage and stores all data in memory.
type InmemStorage struct {
Data map[string]*StorageEntry
phys *physical.InmemBackend
once sync.Once
}
@ -15,33 +16,38 @@ type InmemStorage struct {
func (s *InmemStorage) List(prefix string) ([]string, error) {
s.once.Do(s.init)
var result []string
for k, _ := range s.Data {
if strings.HasPrefix(k, prefix) {
result = append(result, k)
}
}
return result, nil
return s.phys.List(prefix)
}
func (s *InmemStorage) Get(key string) (*StorageEntry, error) {
s.once.Do(s.init)
return s.Data[key], nil
entry, err := s.phys.Get(key)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
return &StorageEntry{
Key: entry.Key,
Value: entry.Value,
}, nil
}
func (s *InmemStorage) Put(entry *StorageEntry) error {
s.once.Do(s.init)
s.Data[entry.Key] = entry
return nil
physEntry := &physical.Entry{
Key: entry.Key,
Value: entry.Value,
}
return s.phys.Put(physEntry)
}
func (s *InmemStorage) Delete(k string) error {
s.once.Do(s.init)
delete(s.Data, k)
return nil
return s.phys.Delete(k)
}
func (s *InmemStorage) init() {
s.Data = make(map[string]*StorageEntry)
s.phys = physical.NewInmem(nil)
}

View file

@ -9,6 +9,7 @@ import (
"sort"
"testing"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
@ -302,8 +303,10 @@ func Test(t TestT, c TestCase) {
if err == nil && resp.IsError() {
err = fmt.Errorf("Erroneous response:\n\n%#v", resp)
}
if err != nil && err != logical.ErrUnsupportedOperation {
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
if err != nil {
if !errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()) {
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
}
}
// If we have any failed revokes, log it.

Some files were not shown because too many files have changed in this diff Show more