Merge remote-tracking branch 'origin/master' into pr-1425
This commit is contained in:
commit
da6371ffc3
85
CHANGELOG.md
85
CHANGELOG.md
|
@ -1,4 +1,4 @@
|
|||
## 0.5.3 (Unreleased)
|
||||
## 0.6.0 (Unreleased)
|
||||
|
||||
SECURITY:
|
||||
|
||||
|
@ -38,13 +38,23 @@ DEPRECATIONS/BREAKING CHANGES:
|
|||
|
||||
FEATURES:
|
||||
|
||||
* **AWS EC2 Auth Backend**: Provides a recure introduction mechanism for AWS EC2
|
||||
instances allowing automated retrieval of Vault tokens. Unlike most Vault
|
||||
authentication backends, this backend does not require first-deploying,
|
||||
* **AWS EC2 Auth Backend**: Provides a secure introduction mechanism for AWS
|
||||
EC2 instances allowing automated retrieval of Vault tokens. Unlike most
|
||||
Vault authentication backends, this backend does not require first deploying
|
||||
or provisioning security-sensitive credentials (tokens, username/password,
|
||||
client certificates,etc). Instead, it treats AWS as a Trusted Third Party
|
||||
client certificates, etc). Instead, it treats AWS as a Trusted Third Party
|
||||
and uses the cryptographically signed dynamic metadata information that
|
||||
uniquely represents each EC2 instance.
|
||||
uniquely represents each EC2 instance. [Vault
|
||||
Enterprise](https://www.hashicorp.com/vault.html) customers have access to a
|
||||
turnkey client that speaks the backend API and makes access to a Vault token
|
||||
easy.
|
||||
* **Response Wrapping**: Nearly any response within Vault can now be wrapped
|
||||
inside a single-use, time-limited token's cubbyhole, taking the [Cubbyhole
|
||||
Authentication
|
||||
Principles](https://www.hashicorp.com/blog/vault-cubbyhole-principles.html)
|
||||
mechanism to its logical conclusion. Retrieving the original response is as
|
||||
simple as a single API command or the new `vault unwrap` command. This makes
|
||||
secret distribution easier and more secure, including secure introduction.
|
||||
* **Azure Physical Backend**: You can now use Azure blob object storage as
|
||||
your Vault physical data store [GH-1266]
|
||||
* **Consul Backend Health Checks**: The Consul backend will automatically
|
||||
|
@ -58,12 +68,18 @@ FEATURES:
|
|||
system- or mount-set values. This is useful, for instance, when the max TTL
|
||||
of the system or the `auth/token` mount must be set high to accommodate
|
||||
certain needs but you want more granular restrictions on tokens being issued
|
||||
directly from `auth/token`. [GH-1399]
|
||||
directly from the Token authentication backend at `auth/token`. [GH-1399]
|
||||
* **RabbitMQ Secret Backend**: Vault can now generate credentials for
|
||||
RabbitMQ. Vhosts and tags can be defined within roles. [GH-788]
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
* audit: Add the DisplayName value to the copy of the Request object embedded
|
||||
in the associated Response, to match the original Request object [GH-1387]
|
||||
* audit: Enable auditing of the `seal` and `step-down` commands [GH-1435]
|
||||
* backends: Remove most `root`/`sudo` paths in favor of normal ACL mechanisms.
|
||||
A particular exception are any current MFA paths. A few paths in `token` and
|
||||
`sys` also require `root` or `sudo`. [GH-1478]
|
||||
* command/auth: Restore the previous authenticated token if the `auth` command
|
||||
fails to authenticate the provided token [GH-1233]
|
||||
* command/write: `-format` and `-field` can now be used with the `write`
|
||||
|
@ -78,13 +94,21 @@ IMPROVEMENTS:
|
|||
backend [GH-1404]
|
||||
* credential/ldap: If `groupdn` is not configured, skip searching LDAP and
|
||||
only return policies for local groups, plus a warning [GH-1283]
|
||||
* credential/ldap: `vault list` support for users and groups [GH-1270]
|
||||
* credential/ldap: Support for the `memberOf` attribute for group membership
|
||||
searching [GH-1245]
|
||||
* credential/userpass: Add list support for users [GH-911]
|
||||
* credential/userpass: Remove user configuration paths from requiring sudo, in
|
||||
favor of normal ACL mechanisms [GH-1312]
|
||||
* credential/token: Sanitize policies and add `default` policies in appropriate
|
||||
places [GH-1235]
|
||||
* secret/aws: Use chain credentials to allow environment/EC2 instance/shared
|
||||
providers [GH-307]
|
||||
* secret/aws: Support for STS AssumeRole functionality [GH-1318]
|
||||
* secret/pki: Added `exclude_cn_from_sans` field to prevent adding the CN to
|
||||
DNS or Email Subject Alternate Names [GH-1220]
|
||||
* secret/consul: Reading consul access configuration supported. The response
|
||||
will contain non-sensitive information only [GH-1445]
|
||||
* sys/capabilities: Enforce ACL checks for requests that query the capabilities
|
||||
of a token on a given path [GH-1221]
|
||||
|
||||
|
@ -97,6 +121,11 @@ BUG FIXES:
|
|||
* command/various: Tell the JSON decoder to not convert all numbers to floats;
|
||||
fixes some various places where numbers were showing up in scientific
|
||||
notation
|
||||
* command/server: Prioritized `devRootTokenID` and `devListenAddress` flags
|
||||
over their respective env vars [GH-1480]
|
||||
* command/ssh: Provided option to disable host key checking. The automated
|
||||
variant of `vault ssh` command uses `sshpass` which was failing to handle
|
||||
host key checking presented by the `ssh` binary. [GH-1473]
|
||||
* core: Properly persist mount-tuned TTLs for auth backends [GH-1371]
|
||||
* core: Don't accidentally crosswire SIGINT to the reload handler [GH-1372]
|
||||
* credential/github: Make organization comparison case-insensitive during
|
||||
|
@ -114,9 +143,51 @@ BUG FIXES:
|
|||
* credential/various: Fix renewal conditions when `default` policy is not
|
||||
contained in the backend config [GH-1256]
|
||||
* physical/s3: Don't panic in certain error cases from bad S3 responses [GH-1353]
|
||||
* secret/consul: Use non-pooled Consul API client to avoid leaving files open
|
||||
[GH-1428]
|
||||
* secret/pki: Don't check whether a certificate is destined to be a CA
|
||||
certificate if sign-verbatim endpoint is used [GH-1250]
|
||||
|
||||
## 0.5.3 (May 27th, 2016)
|
||||
|
||||
SECURITY:
|
||||
|
||||
* Consul ACL Token Revocation: An issue was reported to us indicating that
|
||||
generated Consul ACL tokens were not being properly revoked. Upon
|
||||
investigation, we found that this behavior was reproducible in a specific
|
||||
scenario: when a generated lease for a Consul ACL token had been renewed
|
||||
prior to revocation. In this case, the generated token was not being
|
||||
properly persisted internally through the renewal function, leading to an
|
||||
error during revocation due to the missing token. Unfortunately, this was
|
||||
coded as a user error rather than an internal error, and the revocation
|
||||
logic was expecting internal errors if revocation failed. As a result, the
|
||||
revocation logic believed the revocation to have succeeded when it in fact
|
||||
failed, causing the lease to be dropped while the token was still valid
|
||||
within Consul. In this release, the Consul backend properly persists the
|
||||
token through renewals, and the revocation logic has been changed to
|
||||
consider any error type to have been a failure to revoke, causing the lease
|
||||
to persist and attempt to be revoked later.
|
||||
|
||||
We have written an example shell script that searches through Consul's ACL
|
||||
tokens and looks for those generated by Vault, which can be used as a template
|
||||
for a revocation script as deemed necessary for any particular security
|
||||
response. The script is available at
|
||||
https://gist.github.com/jefferai/6233c2963f9407a858d84f9c27d725c0
|
||||
|
||||
Please note that any outstanding leases for Consul tokens produced prior to
|
||||
0.5.3 that have been renewed will continue to exhibit this behavior. As a
|
||||
result, we recommend either revoking all tokens produced by the backend and
|
||||
issuing new ones, or if needed, a more advanced variant of the provided example
|
||||
could use the timestamp embedded in each generated token's name to decide which
|
||||
tokens are too old and should be deleted. This could then be run periodically
|
||||
up until the maximum lease time for any outstanding pre-0.5.3 tokens has
|
||||
expired.
|
||||
|
||||
This is a security-only release. There are no other code changes since 0.5.2.
|
||||
The binaries have one additional change: they are built against Go 1.6.1 rather
|
||||
than Go 1.6, as Go 1.6.1 contains two security fixes to the Go programming
|
||||
language itself.
|
||||
|
||||
## 0.5.2 (March 16th, 2016)
|
||||
|
||||
FEATURES:
|
||||
|
|
|
@ -6,7 +6,8 @@ Vault [![Build Status](https://travis-ci.org/hashicorp/vault.svg)](https://travi
|
|||
|
||||
- Website: https://www.vaultproject.io
|
||||
- IRC: `#vault-tool` on Freenode
|
||||
- Mailing list: [Google Groups](https://groups.google.com/group/vault-tool)
|
||||
- Announcement list: [Google Groups](https://groups.google.com/group/hashicorp-announce)
|
||||
- Discussion list: [Google Groups](https://groups.google.com/group/vault-tool)
|
||||
|
||||
![Vault](https://raw.githubusercontent.com/hashicorp/vault/master/website/source/assets/images/logo-big.png?token=AAAFE8XmW6YF5TNuk3cosDGBK-sUGPEjks5VSAa2wA%3D%3D)
|
||||
|
||||
|
|
|
@ -23,11 +23,19 @@ const EnvVaultClientCert = "VAULT_CLIENT_CERT"
|
|||
const EnvVaultClientKey = "VAULT_CLIENT_KEY"
|
||||
const EnvVaultInsecure = "VAULT_SKIP_VERIFY"
|
||||
const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME"
|
||||
const EnvVaultWrapTTL = "VAULT_WRAP_TTL"
|
||||
|
||||
var (
|
||||
errRedirect = errors.New("redirect")
|
||||
)
|
||||
|
||||
// WrappingLookupFunc is a function that, given an HTTP verb and a path,
|
||||
// returns an optional string duration to be used for response wrapping (e.g.
|
||||
// "15s", or simply "15"). The path will not begin with "/v1/" or "v1/" or "/",
|
||||
// however, end-of-path forward slashes are not trimmed, so must match your
|
||||
// called path precisely.
|
||||
type WrappingLookupFunc func(operation, path string) string
|
||||
|
||||
// Config is used to configure the creation of the client.
|
||||
type Config struct {
|
||||
// Address is the address of the Vault server. This should be a complete
|
||||
|
@ -155,9 +163,10 @@ func (c *Config) ReadEnvironment() error {
|
|||
// Client is the client to the Vault API. Create a client with
|
||||
// NewClient.
|
||||
type Client struct {
|
||||
addr *url.URL
|
||||
config *Config
|
||||
token string
|
||||
addr *url.URL
|
||||
config *Config
|
||||
token string
|
||||
wrappingLookupFunc WrappingLookupFunc
|
||||
}
|
||||
|
||||
// NewClient returns a new client for the given configuration.
|
||||
|
@ -166,7 +175,6 @@ type Client struct {
|
|||
// automatically added to the client. Otherwise, you must manually call
|
||||
// `SetToken()`.
|
||||
func NewClient(c *Config) (*Client, error) {
|
||||
|
||||
u, err := url.Parse(c.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -200,6 +208,12 @@ func NewClient(c *Config) (*Client, error) {
|
|||
return client, nil
|
||||
}
|
||||
|
||||
// SetWrappingLookupFunc sets a lookup function that returns desired wrap TTLs
|
||||
// for a given operation and path
|
||||
func (c *Client) SetWrappingLookupFunc(lookupFunc WrappingLookupFunc) {
|
||||
c.wrappingLookupFunc = lookupFunc
|
||||
}
|
||||
|
||||
// Token returns the access token being used by this client. It will
|
||||
// return the empty string if there is no token set.
|
||||
func (c *Client) Token() string {
|
||||
|
@ -232,6 +246,19 @@ func (c *Client) NewRequest(method, path string) *Request {
|
|||
Params: make(map[string][]string),
|
||||
}
|
||||
|
||||
if c.wrappingLookupFunc != nil {
|
||||
var lookupPath string
|
||||
switch {
|
||||
case strings.HasPrefix(path, "/v1/"):
|
||||
lookupPath = strings.TrimPrefix(path, "/v1/")
|
||||
case strings.HasPrefix(path, "v1/"):
|
||||
lookupPath = strings.TrimPrefix(path, "v1/")
|
||||
default:
|
||||
lookupPath = path
|
||||
}
|
||||
req.WrapTTL = c.wrappingLookupFunc(method, lookupPath)
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,15 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
wrappedResponseLocation = "cubbyhole/response"
|
||||
)
|
||||
|
||||
// Logical is used to perform logical backend operations on Vault.
|
||||
type Logical struct {
|
||||
c *Client
|
||||
|
@ -80,3 +90,34 @@ func (c *Logical) Delete(path string) (*Secret, error) {
|
|||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
|
||||
origToken := c.c.Token()
|
||||
defer c.c.SetToken(origToken)
|
||||
|
||||
c.c.SetToken(wrappingToken)
|
||||
|
||||
secret, err := c.Read(wrappedResponseLocation)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading %s: %s", wrappedResponseLocation, err)
|
||||
}
|
||||
if secret == nil {
|
||||
return nil, fmt.Errorf("no value found at %s", wrappedResponseLocation)
|
||||
}
|
||||
if secret.Data == nil {
|
||||
return nil, fmt.Errorf("\"data\" not found in wrapping response")
|
||||
}
|
||||
if _, ok := secret.Data["response"]; !ok {
|
||||
return nil, fmt.Errorf("\"response\" not found in wrapping response \"data\" map")
|
||||
}
|
||||
|
||||
wrappedSecret := new(Secret)
|
||||
buf := bytes.NewBufferString(secret.Data["response"].(string))
|
||||
dec := json.NewDecoder(buf)
|
||||
dec.UseNumber()
|
||||
if err := dec.Decode(wrappedSecret); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling wrapped secret: %s", err)
|
||||
}
|
||||
|
||||
return wrappedSecret, nil
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ type Request struct {
|
|||
URL *url.URL
|
||||
Params url.Values
|
||||
ClientToken string
|
||||
WrapTTL string
|
||||
Obj interface{}
|
||||
Body io.Reader
|
||||
BodySize int64
|
||||
|
@ -62,5 +63,9 @@ func (r *Request) ToHTTP() (*http.Request, error) {
|
|||
req.Header.Set("X-Vault-Token", r.ClientToken)
|
||||
}
|
||||
|
||||
if len(r.WrapTTL) != 0 {
|
||||
req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Secret is the structure returned for every secret within Vault.
|
||||
|
@ -23,6 +24,18 @@ type Secret struct {
|
|||
// Auth, if non-nil, means that there was authentication information
|
||||
// attached to this response.
|
||||
Auth *SecretAuth `json:"auth,omitempty"`
|
||||
|
||||
// WrapInfo, if non-nil, means that the initial response was wrapped in the
|
||||
// cubbyhole of the given token (which has a TTL of the given number of
|
||||
// seconds)
|
||||
WrapInfo *SecretWrapInfo `json:"wrap_info,omitempty"`
|
||||
}
|
||||
|
||||
// SecretWrapInfo contains wrapping information if we have it.
|
||||
type SecretWrapInfo struct {
|
||||
Token string `json:"token"`
|
||||
TTL int `json:"ttl"`
|
||||
CreationTime time.Time `json:"creation_time"`
|
||||
}
|
||||
|
||||
// SecretAuth is the structure containing auth information if we have it.
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseSecret(t *testing.T) {
|
||||
|
@ -17,9 +18,16 @@ func TestParseSecret(t *testing.T) {
|
|||
},
|
||||
"warnings": [
|
||||
"a warning!"
|
||||
]
|
||||
],
|
||||
"wrap_info": {
|
||||
"token": "token",
|
||||
"ttl": 60,
|
||||
"creation_time": "2016-06-07T15:52:10-04:00"
|
||||
}
|
||||
}`)
|
||||
|
||||
rawTime, _ := time.Parse(time.RFC3339, "2016-06-07T15:52:10-04:00")
|
||||
|
||||
secret, err := ParseSecret(strings.NewReader(raw))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
|
@ -35,8 +43,13 @@ func TestParseSecret(t *testing.T) {
|
|||
Warnings: []string{
|
||||
"a warning!",
|
||||
},
|
||||
WrapInfo: &SecretWrapInfo{
|
||||
Token: "token",
|
||||
TTL: 60,
|
||||
CreationTime: rawTime,
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(secret, expected) {
|
||||
t.Fatalf("bad: %#v %#v", secret, expected)
|
||||
t.Fatalf("bad:\ngot\n%#v\nexpected\n%#v\n", secret, expected)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,6 +46,7 @@ func (f *FormatJSON) FormatRequest(
|
|||
Path: req.Path,
|
||||
Data: req.Data,
|
||||
RemoteAddr: getRemoteAddr(req),
|
||||
WrapTTL: int(req.WrapTTL / time.Second),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@ -86,6 +87,15 @@ func (f *FormatJSON) FormatResponse(
|
|||
}
|
||||
}
|
||||
|
||||
var respWrapInfo *JSONWrapInfo
|
||||
if resp.WrapInfo != nil {
|
||||
respWrapInfo = &JSONWrapInfo{
|
||||
TTL: int(resp.WrapInfo.TTL / time.Second),
|
||||
Token: resp.WrapInfo.Token,
|
||||
CreationTime: resp.WrapInfo.CreationTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode!
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(&JSONResponseEntry{
|
||||
|
@ -104,6 +114,7 @@ func (f *FormatJSON) FormatResponse(
|
|||
Path: req.Path,
|
||||
Data: req.Data,
|
||||
RemoteAddr: getRemoteAddr(req),
|
||||
WrapTTL: int(req.WrapTTL / time.Second),
|
||||
},
|
||||
|
||||
Response: JSONResponse{
|
||||
|
@ -111,6 +122,7 @@ func (f *FormatJSON) FormatResponse(
|
|||
Secret: respSecret,
|
||||
Data: resp.Data,
|
||||
Redirect: resp.Redirect,
|
||||
WrapInfo: respWrapInfo,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@ -140,6 +152,7 @@ type JSONRequest struct {
|
|||
Path string `json:"path"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
RemoteAddr string `json:"remote_address"`
|
||||
WrapTTL int `json:"wrap_ttl"`
|
||||
}
|
||||
|
||||
type JSONResponse struct {
|
||||
|
@ -147,6 +160,7 @@ type JSONResponse struct {
|
|||
Secret *JSONSecret `json:"secret,emitempty"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Redirect string `json:"redirect"`
|
||||
WrapInfo *JSONWrapInfo `json:"wrap_info,omitempty"`
|
||||
}
|
||||
|
||||
type JSONAuth struct {
|
||||
|
@ -161,6 +175,12 @@ type JSONSecret struct {
|
|||
LeaseID string `json:"lease_id"`
|
||||
}
|
||||
|
||||
type JSONWrapInfo struct {
|
||||
TTL int `json:"ttl"`
|
||||
Token string `json:"token"`
|
||||
CreationTime time.Time `json:"creation_time"`
|
||||
}
|
||||
|
||||
// getRemoteAddr safely gets the remote address avoiding a nil pointer
|
||||
func getRemoteAddr(req *logical.Request) string {
|
||||
if req != nil && req.Connection != nil {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"errors"
|
||||
|
||||
|
@ -26,6 +27,7 @@ func TestFormatJSON_formatRequest(t *testing.T) {
|
|||
Connection: &logical.Connection{
|
||||
RemoteAddr: "127.0.0.1",
|
||||
},
|
||||
WrapTTL: 60 * time.Second,
|
||||
},
|
||||
errors.New("this is an error"),
|
||||
testFormatJSONReqBasicStr,
|
||||
|
@ -64,5 +66,5 @@ func TestFormatJSON_formatRequest(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
const testFormatJSONReqBasicStr = `{"time":"2015-08-05T13:45:46Z","type":"request","auth":{"display_name":"","policies":["root"],"metadata":null},"request":{"operation":"update","path":"/foo","data":null,"remote_address":"127.0.0.1"},"error":"this is an error"}
|
||||
const testFormatJSONReqBasicStr = `{"time":"2015-08-05T13:45:46Z","type":"request","auth":{"display_name":"","policies":["root"],"metadata":null},"request":{"operation":"update","path":"/foo","data":null,"wrap_ttl":60,"remote_address":"127.0.0.1"},"error":"this is an error"}
|
||||
`
|
||||
|
|
|
@ -67,12 +67,25 @@ func Hash(salter *salt.Salt, raw interface{}) error {
|
|||
}
|
||||
}
|
||||
|
||||
if s.WrapInfo != nil {
|
||||
if err := Hash(salter, s.WrapInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
data, err := HashStructure(s.Data, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.Data = data.(map[string]interface{})
|
||||
|
||||
case *logical.WrapInfo:
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.Token = fn(s.Token)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -44,6 +44,7 @@ func TestCopy_request(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
WrapTTL: 60 * time.Second,
|
||||
}
|
||||
arg := expected
|
||||
|
||||
|
@ -66,6 +67,11 @@ func TestCopy_response(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
WrapInfo: &logical.WrapInfo{
|
||||
TTL: 60,
|
||||
Token: "foo",
|
||||
CreationTime: time.Now(),
|
||||
},
|
||||
}
|
||||
arg := expected
|
||||
|
||||
|
@ -131,11 +137,21 @@ func TestHash(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
WrapInfo: &logical.WrapInfo{
|
||||
TTL: 60,
|
||||
Token: "bar",
|
||||
CreationTime: now,
|
||||
},
|
||||
},
|
||||
&logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
},
|
||||
WrapInfo: &logical.WrapInfo{
|
||||
TTL: 60,
|
||||
Token: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
CreationTime: now,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
@ -124,7 +124,7 @@ func (b *backend) pathLoginRenew(
|
|||
return nil, err
|
||||
}
|
||||
if !policyutil.EquivalentPolicies(mapPolicies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies do not match"), nil
|
||||
return nil, fmt.Errorf("policies do not match")
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(0, 0, b.System())(req, d)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
@ -51,7 +51,7 @@ type backend struct {
|
|||
EC2ClientsMap map[string]*ec2.EC2
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
|
@ -96,7 +96,7 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
|||
},
|
||||
}
|
||||
|
||||
return b.Backend, nil
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// periodicFunc performs the tasks that the backend wishes to do periodically.
|
||||
|
@ -108,13 +108,12 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
|||
// Tidying of blacklist and whitelist are by default enabled. This can be
|
||||
// changed using `config/tidy/roletags` and `config/tidy/identities` endpoints.
|
||||
func (b *backend) periodicFunc(req *logical.Request) error {
|
||||
|
||||
// Run the tidy operations for the first time. Then run it when current
|
||||
// time matches the nextTidyTime.
|
||||
if b.nextTidyTime.IsZero() || !time.Now().UTC().Before(b.nextTidyTime) {
|
||||
// safety_buffer defaults to 180 days for roletag blacklist
|
||||
safety_buffer := 15552000
|
||||
tidyBlacklistConfigEntry, err := b.configTidyRoleTags(req.Storage)
|
||||
tidyBlacklistConfigEntry, err := b.lockedConfigTidyRoleTags(req.Storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -135,7 +134,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
|
|||
|
||||
// reset the safety_buffer to 72h
|
||||
safety_buffer = 259200
|
||||
tidyWhitelistConfigEntry, err := b.configTidyIdentities(req.Storage)
|
||||
tidyWhitelistConfigEntry, err := b.lockedConfigTidyIdentities(req.Storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -161,7 +160,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
|
|||
}
|
||||
|
||||
const backendHelp = `
|
||||
AWS auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
|
||||
aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
|
||||
created nonce to authenticates the EC2 instance with Vault.
|
||||
|
||||
Authentication is backed by a preconfigured role in the backend. The role
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
@ -6,75 +6,23 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/ec2"
|
||||
"github.com/hashicorp/vault/helper/policyutil"
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
)
|
||||
|
||||
func createBackend(conf *logical.BackendConfig) (*backend, error) {
|
||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b := &backend{
|
||||
// Setting the periodic func to be run once in an hour.
|
||||
// If there is a real need, this can be made configurable.
|
||||
tidyCooldownPeriod: time.Hour,
|
||||
Salt: salt,
|
||||
EC2ClientsMap: make(map[string]*ec2.EC2),
|
||||
}
|
||||
|
||||
b.Backend = &framework.Backend{
|
||||
PeriodicFunc: b.periodicFunc,
|
||||
AuthRenew: b.pathLoginRenew,
|
||||
Help: backendHelp,
|
||||
PathsSpecial: &logical.Paths{
|
||||
Unauthenticated: []string{
|
||||
"login",
|
||||
},
|
||||
},
|
||||
Paths: []*framework.Path{
|
||||
pathLogin(b),
|
||||
pathListRole(b),
|
||||
pathListRoles(b),
|
||||
pathRole(b),
|
||||
pathRoleTag(b),
|
||||
pathConfigClient(b),
|
||||
pathConfigCertificate(b),
|
||||
pathConfigTidyRoletagBlacklist(b),
|
||||
pathConfigTidyIdentityWhitelist(b),
|
||||
pathListCertificates(b),
|
||||
pathListRoletagBlacklist(b),
|
||||
pathRoletagBlacklist(b),
|
||||
pathTidyRoletagBlacklist(b),
|
||||
pathListIdentityWhitelist(b),
|
||||
pathIdentityWhitelist(b),
|
||||
pathTidyIdentityWhitelist(b),
|
||||
},
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
|
||||
// create a backend
|
||||
config := logical.TestBackendConfig()
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -98,7 +46,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
|
|||
}
|
||||
|
||||
// read the created role entry
|
||||
roleEntry, err := b.awsRole(storage, "abcd-123")
|
||||
roleEntry, err := b.lockedAWSRole(storage, "abcd-123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -165,7 +113,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
|
|||
}
|
||||
|
||||
// get the entry of the newly created role entry
|
||||
roleEntry2, err := b.awsRole(storage, "ami-6789")
|
||||
roleEntry2, err := b.lockedAWSRole(storage, "ami-6789")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -293,11 +241,11 @@ func TestBackend_ConfigTidyIdentities(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -347,11 +295,11 @@ func TestBackend_ConfigTidyRoleTags(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -401,11 +349,11 @@ func TestBackend_TidyIdentities(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -426,11 +374,11 @@ func TestBackend_TidyRoleTags(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -451,11 +399,11 @@ func TestBackend_ConfigClient(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -588,11 +536,11 @@ func TestBackend_pathConfigCertificate(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -739,11 +687,11 @@ func TestBackend_pathRole(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -865,11 +813,11 @@ func TestBackend_parseAndVerifyRoleTagValue(t *testing.T) {
|
|||
config := logical.TestBackendConfig()
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -942,11 +890,11 @@ func TestBackend_PathRoleTag(t *testing.T) {
|
|||
config := logical.TestBackendConfig()
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1006,11 +954,11 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = storage
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1098,7 +1046,7 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
|
|||
}
|
||||
|
||||
// try to read the deleted entry
|
||||
tagEntry, err := b.blacklistRoleTagEntry(storage, tag)
|
||||
tagEntry, err := b.lockedBlacklistRoleTagEntry(storage, tag)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1135,11 +1083,11 @@ func TestBackendAcc_LoginAndWhitelistIdentity(t *testing.T) {
|
|||
storage := &logical.InmemStorage{}
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = storage
|
||||
b, err := createBackend(config)
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Backend.Setup(config)
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -24,7 +24,7 @@ func (b *backend) getClientConfig(s logical.Storage, region string) (*aws.Config
|
|||
}
|
||||
|
||||
// Read the configured secret key and access key
|
||||
config, err := b.clientConfigEntryInternal(s)
|
||||
config, err := b.nonLockedClientConfigEntry(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
|
@ -96,7 +96,7 @@ func (b *backend) pathConfigCertificateExistenceCheck(req *logical.Request, data
|
|||
return false, fmt.Errorf("missing cert_name")
|
||||
}
|
||||
|
||||
entry, err := b.awsPublicCertificateEntry(req.Storage, certName)
|
||||
entry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -161,7 +161,7 @@ func (b *backend) awsPublicCertificates(s logical.Storage) ([]*x509.Certificate,
|
|||
|
||||
// Iterate through each certificate, parse and append it to a slice.
|
||||
for _, cert := range registeredCerts {
|
||||
certEntry, err := b.awsPublicCertificateEntryInternal(s, cert)
|
||||
certEntry, err := b.nonLockedAWSPublicCertificateEntry(s, cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -180,15 +180,15 @@ func (b *backend) awsPublicCertificates(s logical.Storage) ([]*x509.Certificate,
|
|||
|
||||
// awsPublicCertificate is used to get the configured AWS Public Key that is used
|
||||
// to verify the PKCS#7 signature of the instance identity document.
|
||||
func (b *backend) awsPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
func (b *backend) lockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.awsPublicCertificateEntryInternal(s, certName)
|
||||
return b.nonLockedAWSPublicCertificateEntry(s, certName)
|
||||
}
|
||||
|
||||
// Internal version of the above that does no locking
|
||||
func (b *backend) awsPublicCertificateEntryInternal(s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
func (b *backend) nonLockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
|
||||
entry, err := s.Get("config/certificate/" + certName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -227,7 +227,7 @@ func (b *backend) pathConfigCertificateRead(
|
|||
return logical.ErrorResponse("missing cert_name"), nil
|
||||
}
|
||||
|
||||
certificateEntry, err := b.awsPublicCertificateEntry(req.Storage, certName)
|
||||
certificateEntry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -253,7 +253,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(
|
|||
defer b.configMutex.Unlock()
|
||||
|
||||
// Check if there is already a certificate entry registered.
|
||||
certEntry, err := b.awsPublicCertificateEntryInternal(req.Storage, certName)
|
||||
certEntry, err := b.nonLockedAWSPublicCertificateEntry(req.Storage, certName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"github.com/fatih/structs"
|
||||
|
@ -48,23 +48,23 @@ func pathConfigClient(b *backend) *framework.Path {
|
|||
func (b *backend) pathConfigClientExistenceCheck(
|
||||
req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
|
||||
entry, err := b.clientConfigEntry(req.Storage)
|
||||
entry, err := b.lockedClientConfigEntry(req.Storage)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return entry != nil, nil
|
||||
}
|
||||
|
||||
// Fetch the client configuration required to access the AWS API.
|
||||
func (b *backend) clientConfigEntry(s logical.Storage) (*clientConfig, error) {
|
||||
// Fetch the client configuration required to access the AWS API, after acquiring an exclusive lock.
|
||||
func (b *backend) lockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.clientConfigEntryInternal(s)
|
||||
return b.nonLockedClientConfigEntry(s)
|
||||
}
|
||||
|
||||
// Internal version that does no locking
|
||||
func (b *backend) clientConfigEntryInternal(s logical.Storage) (*clientConfig, error) {
|
||||
// Fetch the client configuration required to access the AWS API.
|
||||
func (b *backend) nonLockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
|
||||
entry, err := s.Get("config/client")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -82,7 +82,7 @@ func (b *backend) clientConfigEntryInternal(s logical.Storage) (*clientConfig, e
|
|||
|
||||
func (b *backend) pathConfigClientRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
clientConfig, err := b.clientConfigEntry(req.Storage)
|
||||
clientConfig, err := b.lockedClientConfigEntry(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ func (b *backend) pathConfigClientCreateUpdate(
|
|||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
configEntry, err := b.clientConfigEntryInternal(req.Storage)
|
||||
configEntry, err := b.nonLockedClientConfigEntry(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -193,7 +193,7 @@ Configure the client credentials that are used to query instance details from AW
|
|||
`
|
||||
|
||||
const pathConfigClientHelpDesc = `
|
||||
AWS auth backend makes DescribeInstances API call to retrieve information regarding
|
||||
the instance that performs login. The aws_secret_key and aws_access_key registered with Vault should have the
|
||||
permissions to make the API call.
|
||||
aws-ec2 auth backend makes DescribeInstances API call to retrieve information regarding
|
||||
the instance that performs login. The aws_secret_key and aws_access_key registered with
|
||||
Vault should have the permissions to make the API call.
|
||||
`
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -44,21 +44,21 @@ expiration, before it is removed from the backend storage.`,
|
|||
}
|
||||
|
||||
func (b *backend) pathConfigTidyIdentityWhitelistExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.configTidyIdentities(req.Storage)
|
||||
entry, err := b.lockedConfigTidyIdentities(req.Storage)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return entry != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) configTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
|
||||
func (b *backend) lockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.configTidyIdentitiesInternal(s)
|
||||
return b.nonLockedConfigTidyIdentities(s)
|
||||
}
|
||||
|
||||
func (b *backend) configTidyIdentitiesInternal(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
|
||||
func (b *backend) nonLockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
|
||||
entry, err := s.Get(identityWhitelistConfigPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -78,7 +78,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(req *logical.Reque
|
|||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
configEntry, err := b.configTidyIdentitiesInternal(req.Storage)
|
||||
configEntry, err := b.nonLockedConfigTidyIdentities(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(req *logical.Reque
|
|||
}
|
||||
|
||||
func (b *backend) pathConfigTidyIdentityWhitelistRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
clientConfig, err := b.configTidyIdentities(req.Storage)
|
||||
clientConfig, err := b.lockedConfigTidyIdentities(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -46,21 +46,21 @@ Defaults to 4320h (180 days).`,
|
|||
}
|
||||
|
||||
func (b *backend) pathConfigTidyRoletagBlacklistExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.configTidyRoleTags(req.Storage)
|
||||
entry, err := b.lockedConfigTidyRoleTags(req.Storage)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return entry != nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) configTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
|
||||
func (b *backend) lockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
|
||||
b.configMutex.RLock()
|
||||
defer b.configMutex.RUnlock()
|
||||
|
||||
return b.configTidyRoleTagsInternal(s)
|
||||
return b.nonLockedConfigTidyRoleTags(s)
|
||||
}
|
||||
|
||||
func (b *backend) configTidyRoleTagsInternal(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
|
||||
func (b *backend) nonLockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
|
||||
entry, err := s.Get(roletagBlacklistConfigPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -81,7 +81,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(req *logical.Reques
|
|||
b.configMutex.Lock()
|
||||
defer b.configMutex.Unlock()
|
||||
|
||||
configEntry, err := b.configTidyRoleTagsInternal(req.Storage)
|
||||
configEntry, err := b.nonLockedConfigTidyRoleTags(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(req *logical.Reques
|
|||
}
|
||||
|
||||
func (b *backend) pathConfigTidyRoletagBlacklistRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
clientConfig, err := b.configTidyRoleTags(req.Storage)
|
||||
clientConfig, err := b.lockedConfigTidyRoleTags(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"time"
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
@ -236,7 +236,7 @@ func (b *backend) pathLoginUpdate(
|
|||
}
|
||||
|
||||
// Get the entry for the role used by the instance.
|
||||
roleEntry, err := b.awsRole(req.Storage, roleName)
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -442,7 +442,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, identityDoc *identityDoc
|
|||
}
|
||||
|
||||
// Check if the role tag is blacklisted.
|
||||
blacklistEntry, err := b.blacklistRoleTagEntry(s, rTagValue)
|
||||
blacklistEntry, err := b.lockedBlacklistRoleTagEntry(s, rTagValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -478,7 +478,7 @@ func (b *backend) pathLoginRenew(
|
|||
// Cross check that the instance is still in 'running' state
|
||||
_, err := b.validateInstance(req.Storage, instanceID, region)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to verify instance ID: %s", err)), nil
|
||||
return nil, fmt.Errorf("failed to verify instance ID: %s", err)
|
||||
}
|
||||
|
||||
storedIdentity, err := whitelistIdentityEntry(req.Storage, instanceID)
|
||||
|
@ -487,12 +487,12 @@ func (b *backend) pathLoginRenew(
|
|||
}
|
||||
|
||||
// Ensure that role entry is not deleted.
|
||||
roleEntry, err := b.awsRole(req.Storage, storedIdentity.Role)
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, storedIdentity.Role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if roleEntry == nil {
|
||||
return logical.ErrorResponse("role entry not found"), nil
|
||||
return nil, fmt.Errorf("role entry not found")
|
||||
}
|
||||
|
||||
// If the login was made using the role tag, then max_ttl from tag
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -54,7 +54,7 @@ using the AMI ID specified by this parameter.`,
|
|||
"disallow_reauthentication": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Default: false,
|
||||
Description: "If set, only allows a single token to be granted per instance ID. In order to perform a fresh login, the entry in whitelist for the instance ID needs to be cleared using 'auth/aws/identity-whitelist/<instance_id>' endpoint.",
|
||||
Description: "If set, only allows a single token to be granted per instance ID. In order to perform a fresh login, the entry in whitelist for the instance ID needs to be cleared using 'auth/aws-ec2/identity-whitelist/<instance_id>' endpoint.",
|
||||
},
|
||||
},
|
||||
|
||||
|
@ -101,7 +101,7 @@ func pathListRoles(b *backend) *framework.Path {
|
|||
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
|
||||
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
|
||||
func (b *backend) pathRoleExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) {
|
||||
entry, err := b.awsRole(req.Storage, strings.ToLower(data.Get("role").(string)))
|
||||
entry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -109,14 +109,14 @@ func (b *backend) pathRoleExistenceCheck(req *logical.Request, data *framework.F
|
|||
}
|
||||
|
||||
// awsRole is used to get the information registered for the given AMI ID.
|
||||
func (b *backend) awsRole(s logical.Storage, role string) (*awsRoleEntry, error) {
|
||||
func (b *backend) lockedAWSRole(s logical.Storage, role string) (*awsRoleEntry, error) {
|
||||
b.roleMutex.RLock()
|
||||
defer b.roleMutex.RUnlock()
|
||||
|
||||
return b.awsRoleInternal(s, role)
|
||||
return b.nonLockedAWSRole(s, role)
|
||||
}
|
||||
|
||||
func (b *backend) awsRoleInternal(s logical.Storage, role string) (*awsRoleEntry, error) {
|
||||
func (b *backend) nonLockedAWSRole(s logical.Storage, role string) (*awsRoleEntry, error) {
|
||||
entry, err := s.Get("role/" + strings.ToLower(role))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -162,7 +162,7 @@ func (b *backend) pathRoleList(
|
|||
// pathRoleRead is used to view the information registered for a given AMI ID.
|
||||
func (b *backend) pathRoleRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
roleEntry, err := b.awsRole(req.Storage, strings.ToLower(data.Get("role").(string)))
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -196,7 +196,7 @@ func (b *backend) pathRoleCreateUpdate(
|
|||
b.roleMutex.Lock()
|
||||
defer b.roleMutex.Unlock()
|
||||
|
||||
roleEntry, err := b.awsRoleInternal(req.Storage, roleName)
|
||||
roleEntry, err := b.nonLockedAWSRole(req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
|
@ -54,7 +54,7 @@ If set, the created tag can only be used by the instance with the given ID.`,
|
|||
"disallow_reauthentication": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Default: false,
|
||||
Description: "If set, only allows a single token to be granted per instance ID. In order to perform a fresh login, the entry in whitelist for the instance ID needs to be cleared using the 'auth/aws/identity-whitelist/<instance_id>' endpoint.",
|
||||
Description: "If set, only allows a single token to be granted per instance ID. In order to perform a fresh login, the entry in whitelist for the instance ID needs to be cleared using the 'auth/aws-ec2/identity-whitelist/<instance_id>' endpoint.",
|
||||
},
|
||||
},
|
||||
|
||||
|
@ -78,7 +78,7 @@ func (b *backend) pathRoleTagUpdate(
|
|||
}
|
||||
|
||||
// Fetch the role entry
|
||||
roleEntry, err := b.awsRole(req.Storage, roleName)
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -346,7 +346,7 @@ func (b *backend) parseAndVerifyRoleTagValue(s logical.Storage, tag string) (*ro
|
|||
return nil, fmt.Errorf("missing role name")
|
||||
}
|
||||
|
||||
roleEntry, err := b.awsRole(s, rTag.Role)
|
||||
roleEntry, err := b.lockedAWSRole(s, rTag.Role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
@ -72,14 +72,14 @@ func (b *backend) pathRoletagBlacklistsList(
|
|||
|
||||
// Fetch an entry from the role tag blacklist for a given tag.
|
||||
// This method takes a role tag in its original form and not a base64 encoded form.
|
||||
func (b *backend) blacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
|
||||
func (b *backend) lockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
|
||||
b.blacklistMutex.RLock()
|
||||
defer b.blacklistMutex.RUnlock()
|
||||
|
||||
return b.blacklistRoleTagEntryInternal(s, tag)
|
||||
return b.nonLockedBlacklistRoleTagEntry(s, tag)
|
||||
}
|
||||
|
||||
func (b *backend) blacklistRoleTagEntryInternal(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
|
||||
func (b *backend) nonLockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
|
||||
entry, err := s.Get("blacklist/roletag/" + base64.StdEncoding.EncodeToString([]byte(tag)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -119,7 +119,7 @@ func (b *backend) pathRoletagBlacklistRead(
|
|||
return logical.ErrorResponse("missing role_tag"), nil
|
||||
}
|
||||
|
||||
entry, err := b.blacklistRoleTagEntry(req.Storage, tag)
|
||||
entry, err := b.lockedBlacklistRoleTagEntry(req.Storage, tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -166,7 +166,7 @@ func (b *backend) pathRoletagBlacklistUpdate(
|
|||
}
|
||||
|
||||
// Get the entry for the role mentioned in the role tag.
|
||||
roleEntry, err := b.awsRole(req.Storage, rTag.Role)
|
||||
roleEntry, err := b.lockedAWSRole(req.Storage, rTag.Role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -178,7 +178,7 @@ func (b *backend) pathRoletagBlacklistUpdate(
|
|||
defer b.blacklistMutex.Unlock()
|
||||
|
||||
// Check if the role tag is already blacklisted. If yes, update it.
|
||||
blEntry, err := b.blacklistRoleTagEntryInternal(req.Storage, tag)
|
||||
blEntry, err := b.nonLockedBlacklistRoleTagEntry(req.Storage, tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -1,4 +1,4 @@
|
|||
package aws
|
||||
package awsec2
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -101,15 +101,6 @@ func connectionState(t *testing.T, serverCAPath, serverCertPath, serverKeyPath,
|
|||
return connState
|
||||
}
|
||||
|
||||
func failOnError(t *testing.T, resp *logical.Response, err error) {
|
||||
if resp != nil && resp.IsError() {
|
||||
t.Fatalf("error returned in response: %s", resp.Data["error"])
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
storage := &logical.InmemStorage{}
|
||||
|
@ -140,7 +131,9 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
|
|||
}
|
||||
|
||||
resp, err := b.HandleRequest(certReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Connection state is presenting the client Non-CA cert and its key.
|
||||
// This is exactly what is registered at the backend.
|
||||
|
@ -155,7 +148,9 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
|
|||
}
|
||||
// Login should succeed.
|
||||
resp, err = b.HandleRequest(loginReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Register a CRL containing the issued client certificate used above.
|
||||
issuedCRL, err := ioutil.ReadFile(testIssuedCertCRL)
|
||||
|
@ -172,7 +167,9 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
|
|||
Data: crlData,
|
||||
}
|
||||
resp, err = b.HandleRequest(crlReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Attempt login with the same connection state but with the CRL registered
|
||||
resp, err = b.HandleRequest(loginReq)
|
||||
|
@ -214,7 +211,9 @@ func TestBackend_CRLs(t *testing.T) {
|
|||
}
|
||||
|
||||
resp, err := b.HandleRequest(certReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Connection state is presenting the client CA cert and its key.
|
||||
// This is exactly what is registered at the backend.
|
||||
|
@ -228,7 +227,9 @@ func TestBackend_CRLs(t *testing.T) {
|
|||
},
|
||||
}
|
||||
resp, err = b.HandleRequest(loginReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Now, without changing the registered client CA cert, present from
|
||||
// the client side, a cert issued using the registered CA.
|
||||
|
@ -237,7 +238,9 @@ func TestBackend_CRLs(t *testing.T) {
|
|||
|
||||
// Attempt login with the updated connection
|
||||
resp, err = b.HandleRequest(loginReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Register a CRL containing the issued client certificate used above.
|
||||
issuedCRL, err := ioutil.ReadFile(testIssuedCertCRL)
|
||||
|
@ -255,7 +258,9 @@ func TestBackend_CRLs(t *testing.T) {
|
|||
Data: crlData,
|
||||
}
|
||||
resp, err = b.HandleRequest(crlReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Attempt login with the revoked certificate.
|
||||
resp, err = b.HandleRequest(loginReq)
|
||||
|
@ -273,7 +278,9 @@ func TestBackend_CRLs(t *testing.T) {
|
|||
}
|
||||
certData["certificate"] = clientCA2
|
||||
resp, err = b.HandleRequest(certReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Test login using a different client CA cert pair.
|
||||
connState = connectionState(t, serverCAPath, serverCertPath, serverKeyPath, testRootCACertPath2, testRootCAKeyPath2)
|
||||
|
@ -281,7 +288,9 @@ func TestBackend_CRLs(t *testing.T) {
|
|||
|
||||
// Attempt login with the updated connection
|
||||
resp, err = b.HandleRequest(loginReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Register a CRL containing the root CA certificate used above.
|
||||
rootCRL, err := ioutil.ReadFile(testRootCertCRL)
|
||||
|
@ -290,7 +299,9 @@ func TestBackend_CRLs(t *testing.T) {
|
|||
}
|
||||
crlData["crl"] = rootCRL
|
||||
resp, err = b.HandleRequest(crlReq)
|
||||
failOnError(t, resp, err)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Attempt login with the same connection state but with the CRL registered
|
||||
resp, err = b.HandleRequest(loginReq)
|
||||
|
@ -754,13 +765,7 @@ func Test_Renew(t *testing.T) {
|
|||
}
|
||||
|
||||
resp, err = b.pathLoginRenew(req, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("got nil response from renew")
|
||||
}
|
||||
if !resp.IsError() {
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
|
@ -106,7 +107,7 @@ func (b *backend) pathLoginRenew(
|
|||
|
||||
clientCerts := req.Connection.ConnState.PeerCertificates
|
||||
if len(clientCerts) == 0 {
|
||||
return logical.ErrorResponse("no client certificate found"), nil
|
||||
return nil, fmt.Errorf("no client certificate found")
|
||||
}
|
||||
skid := base64.StdEncoding.EncodeToString(clientCerts[0].SubjectKeyId)
|
||||
akid := base64.StdEncoding.EncodeToString(clientCerts[0].AuthorityKeyId)
|
||||
|
@ -114,7 +115,7 @@ func (b *backend) pathLoginRenew(
|
|||
// Certificate should not only match a registered certificate policy.
|
||||
// Also, the identity of the certificate presented should match the identity of the certificate used during login
|
||||
if req.Auth.InternalData["subject_key_id"] != skid && req.Auth.InternalData["authority_key_id"] != akid {
|
||||
return logical.ErrorResponse("client identity during renewal not matching client identity used during login"), nil
|
||||
return nil, fmt.Errorf("client identity during renewal not matching client identity used during login")
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -129,7 +130,7 @@ func (b *backend) pathLoginRenew(
|
|||
}
|
||||
|
||||
if !policyutil.EquivalentPolicies(cert.Policies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||
return nil, fmt.Errorf("policies have changed, not renewing")
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(cert.TTL, 0, b.System())(req, d)
|
||||
|
|
|
@ -84,7 +84,7 @@ func (b *backend) pathLoginRenew(
|
|||
verifyResp = verifyResponse
|
||||
}
|
||||
if !policyutil.EquivalentPolicies(verifyResp.Policies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies do not match"), nil
|
||||
return nil, fmt.Errorf("policies do not match")
|
||||
}
|
||||
|
||||
config, err := b.Config(req.Storage)
|
||||
|
|
|
@ -3,6 +3,8 @@ package ldap
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"strings"
|
||||
|
||||
"github.com/go-ldap/ldap"
|
||||
"github.com/hashicorp/vault/helper/mfa"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -19,13 +21,7 @@ func Backend() *framework.Backend {
|
|||
Help: backendHelp,
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: append([]string{
|
||||
"config",
|
||||
"groups/*",
|
||||
"users/*",
|
||||
},
|
||||
mfa.MFARootPaths()...,
|
||||
),
|
||||
Root: mfa.MFARootPaths(),
|
||||
|
||||
Unauthenticated: []string{
|
||||
"login/*",
|
||||
|
@ -35,7 +31,9 @@ func Backend() *framework.Backend {
|
|||
Paths: append([]*framework.Path{
|
||||
pathConfig(&b),
|
||||
pathGroups(&b),
|
||||
pathGroupsList(&b),
|
||||
pathUsers(&b),
|
||||
pathUsersList(&b),
|
||||
},
|
||||
mfa.MFAPaths(b.Backend, pathLogin(&b))...,
|
||||
),
|
||||
|
@ -101,90 +99,50 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
|||
if c == nil {
|
||||
return nil, logical.ErrorResponse("invalid connection returned from LDAP dial"), nil
|
||||
}
|
||||
binddn := ""
|
||||
if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
|
||||
if err = c.Bind(cfg.BindDN, cfg.BindPassword); err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind (service) failed: %v", err)), nil
|
||||
}
|
||||
sresult, err := c.Search(&ldap.SearchRequest{
|
||||
BaseDN: cfg.UserDN,
|
||||
Scope: 2, // subtree
|
||||
Filter: fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP search for binddn failed: %v", err)), nil
|
||||
}
|
||||
if len(sresult.Entries) != 1 {
|
||||
return nil, logical.ErrorResponse("LDAP search for binddn 0 or not uniq"), nil
|
||||
}
|
||||
binddn = sresult.Entries[0].DN
|
||||
} else {
|
||||
if cfg.UPNDomain != "" {
|
||||
binddn = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
|
||||
} else {
|
||||
binddn = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
|
||||
}
|
||||
|
||||
bindDN, err := getBindDN(cfg, c, username)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
if err = c.Bind(binddn, password); err != nil {
|
||||
|
||||
if err = c.Bind(bindDN, password); err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP bind failed: %v", err)), nil
|
||||
}
|
||||
|
||||
userdn := ""
|
||||
if cfg.UPNDomain != "" {
|
||||
// Find the distinguished name for the user if userPrincipalName used for login
|
||||
sresult, err := c.Search(&ldap.SearchRequest{
|
||||
BaseDN: cfg.UserDN,
|
||||
Scope: 2, // subtree
|
||||
Filter: fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(binddn)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP search failed: %v", err)), nil
|
||||
}
|
||||
for _, e := range sresult.Entries {
|
||||
userdn = e.DN
|
||||
}
|
||||
} else {
|
||||
userdn = binddn
|
||||
userDN, err := getUserDN(cfg, c, bindDN)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
var allgroups []string
|
||||
var policies []string
|
||||
resp := &logical.Response{
|
||||
ldapGroups, err := getLdapGroups(cfg, c, userDN, username)
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
ldapResponse := &logical.Response{
|
||||
Data: map[string]interface{}{},
|
||||
}
|
||||
if len(ldapGroups) == 0 {
|
||||
errString := fmt.Sprintf(
|
||||
"no LDAP groups found in userDN '%s' or groupDN '%s';only policies from locally-defined groups available",
|
||||
cfg.UserDN,
|
||||
cfg.GroupDN)
|
||||
ldapResponse.AddWarning(errString)
|
||||
}
|
||||
|
||||
// Fetch custom (local) groups the user has been added to
|
||||
var allGroups []string
|
||||
// Import the custom added groups from ldap backend
|
||||
user, err := b.User(req.Storage, username)
|
||||
if err == nil && user != nil {
|
||||
allgroups = append(allgroups, user.Groups...)
|
||||
allGroups = append(allGroups, user.Groups...)
|
||||
}
|
||||
// add the LDAP groups
|
||||
allGroups = append(allGroups, ldapGroups...)
|
||||
|
||||
if cfg.GroupDN != "" {
|
||||
// Enumerate all groups the user is member of. The search filter should
|
||||
// work with both openldap and MS AD standard schemas.
|
||||
sresult, err := c.Search(&ldap.SearchRequest{
|
||||
BaseDN: cfg.GroupDN,
|
||||
Scope: 2, // subtree
|
||||
Filter: fmt.Sprintf("(|(memberUid=%s)(member=%s)(uniqueMember=%s))", ldap.EscapeFilter(username), ldap.EscapeFilter(userdn), ldap.EscapeFilter(userdn)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, logical.ErrorResponse(fmt.Sprintf("LDAP search failed: %v", err)), nil
|
||||
}
|
||||
|
||||
for _, e := range sresult.Entries {
|
||||
dn, err := ldap.ParseDN(e.DN)
|
||||
if err != nil || len(dn.RDNs) == 0 || len(dn.RDNs[0].Attributes) == 0 {
|
||||
continue
|
||||
}
|
||||
gname := dn.RDNs[0].Attributes[0].Value
|
||||
allgroups = append(allgroups, gname)
|
||||
}
|
||||
} else {
|
||||
resp.AddWarning("no group DN configured; only policies from locally-defined groups available")
|
||||
}
|
||||
|
||||
for _, gname := range allgroups {
|
||||
group, err := b.Group(req.Storage, gname)
|
||||
// Retrieve policies
|
||||
var policies []string
|
||||
for _, groupName := range allGroups {
|
||||
group, err := b.Group(req.Storage, groupName)
|
||||
if err == nil && group != nil {
|
||||
policies = append(policies, group.Policies...)
|
||||
}
|
||||
|
@ -192,15 +150,140 @@ func (b *backend) Login(req *logical.Request, username string, password string)
|
|||
|
||||
if len(policies) == 0 {
|
||||
errStr := "user is not a member of any authorized group"
|
||||
if len(resp.Warnings()) > 0 {
|
||||
errStr = fmt.Sprintf("%s; additionally, %s", errStr, resp.Warnings()[0])
|
||||
if len(ldapResponse.Warnings()) > 0 {
|
||||
errStr = fmt.Sprintf("%s; additionally, %s", errStr, ldapResponse.Warnings()[0])
|
||||
}
|
||||
|
||||
resp.Data["error"] = errStr
|
||||
return nil, resp, nil
|
||||
ldapResponse.Data["error"] = errStr
|
||||
return nil, ldapResponse, nil
|
||||
}
|
||||
|
||||
return policies, resp, nil
|
||||
return policies, ldapResponse, nil
|
||||
}
|
||||
|
||||
func getBindDN(cfg *ConfigEntry, c *ldap.Conn, username string) (string, error) {
|
||||
bindDN := ""
|
||||
if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
|
||||
if err := c.Bind(cfg.BindDN, cfg.BindPassword); err != nil {
|
||||
return bindDN, fmt.Errorf("LDAP bind (service) failed: %v", err)
|
||||
}
|
||||
result, err := c.Search(&ldap.SearchRequest{
|
||||
BaseDN: cfg.UserDN,
|
||||
Scope: 2, // subtree
|
||||
Filter: fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username)),
|
||||
})
|
||||
if err != nil {
|
||||
return bindDN, fmt.Errorf("LDAP search for binddn failed: %v", err)
|
||||
}
|
||||
if len(result.Entries) != 1 {
|
||||
return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique")
|
||||
}
|
||||
bindDN = result.Entries[0].DN
|
||||
} else {
|
||||
if cfg.UPNDomain != "" {
|
||||
bindDN = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
|
||||
} else {
|
||||
bindDN = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
|
||||
}
|
||||
}
|
||||
|
||||
return bindDN, nil
|
||||
}
|
||||
|
||||
func getUserDN(cfg *ConfigEntry, c *ldap.Conn, bindDN string) (string, error) {
|
||||
userDN := ""
|
||||
if cfg.UPNDomain != "" {
|
||||
// Find the distinguished name for the user if userPrincipalName used for login
|
||||
result, err := c.Search(&ldap.SearchRequest{
|
||||
BaseDN: cfg.UserDN,
|
||||
Scope: 2, // subtree
|
||||
Filter: fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(bindDN)),
|
||||
})
|
||||
if err != nil {
|
||||
return userDN, fmt.Errorf("LDAP search failed for detecting user: %v", err)
|
||||
}
|
||||
for _, e := range result.Entries {
|
||||
userDN = e.DN
|
||||
}
|
||||
} else {
|
||||
userDN = bindDN
|
||||
}
|
||||
|
||||
return userDN, nil
|
||||
}
|
||||
|
||||
func getLdapGroups(cfg *ConfigEntry, c *ldap.Conn, userDN string, username string) ([]string, error) {
|
||||
// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
|
||||
ldapMap := make(map[string]bool)
|
||||
// Fetch the optional memberOf property values on the user object
|
||||
// This is the most common method used in Active Directory setup to retrieve the groups
|
||||
result, err := c.Search(&ldap.SearchRequest{
|
||||
BaseDN: userDN,
|
||||
Scope: 0, // base scope to fetch only the userDN
|
||||
Filter: "(cn=*)", // bogus filter, required to fetch the CN from userDN
|
||||
Attributes: []string{
|
||||
"memberOf",
|
||||
},
|
||||
})
|
||||
// this check remains in case something happens with the ldap query or connection
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LDAP fetch of distinguishedName=%s failed: %v", userDN, err)
|
||||
}
|
||||
// if there are more than one entry, we consider the results irrelevant and ignore them
|
||||
if len(result.Entries) == 1 {
|
||||
for _, attr := range result.Entries[0].Attributes {
|
||||
// Find the groups the user is member of from the 'memberOf' attribute extracting the CN
|
||||
if attr.Name == "memberOf" {
|
||||
for _, value := range attr.Values {
|
||||
memberOfDN, err := ldap.ParseDN(value)
|
||||
if err != nil || len(memberOfDN.RDNs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rdn := range memberOfDN.RDNs {
|
||||
for _, rdnTypeAndValue := range rdn.Attributes {
|
||||
if strings.EqualFold(rdnTypeAndValue.Type, "CN") {
|
||||
ldapMap[rdnTypeAndValue.Value] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find groups by searching in groupDN for any of the memberUid, member or uniqueMember attributes
|
||||
// and retrieving the CN in the DN result
|
||||
if cfg.GroupDN != "" {
|
||||
result, err := c.Search(&ldap.SearchRequest{
|
||||
BaseDN: cfg.GroupDN,
|
||||
Scope: 2, // subtree
|
||||
Filter: fmt.Sprintf("(|(memberUid=%s)(member=%s)(uniqueMember=%s))", ldap.EscapeFilter(username), ldap.EscapeFilter(userDN), ldap.EscapeFilter(userDN)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LDAP search failed: %v", err)
|
||||
}
|
||||
|
||||
for _, e := range result.Entries {
|
||||
dn, err := ldap.ParseDN(e.DN)
|
||||
if err != nil || len(dn.RDNs) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, rdn := range dn.RDNs {
|
||||
for _, rdnTypeAndValue := range rdn.Attributes {
|
||||
if strings.EqualFold(rdnTypeAndValue.Type, "CN") {
|
||||
ldapMap[rdnTypeAndValue.Value] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ldapGroups := make([]string, len(ldapMap))
|
||||
for key, _ := range ldapMap {
|
||||
ldapGroups = append(ldapGroups, key)
|
||||
}
|
||||
return ldapGroups, nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
|
|
|
@ -2,6 +2,7 @@ package ldap
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -38,6 +39,8 @@ func TestBackend_basic(t *testing.T) {
|
|||
testAccStepGroup(t, "engineers", "bar"),
|
||||
testAccStepUser(t, "tesla", "engineers"),
|
||||
testAccStepLogin(t, "tesla", "password"),
|
||||
testAccStepGroupList(t, []string{"engineers", "scientists"}),
|
||||
testAccStepUserList(t, []string{"tesla"}),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@ -321,3 +324,39 @@ func TestLDAPEscape(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepGroupList(t *testing.T, groups []string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ListOperation,
|
||||
Path: "groups",
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("Got error response: %#v", *resp)
|
||||
}
|
||||
|
||||
exp := groups
|
||||
if !reflect.DeepEqual(exp, resp.Data["keys"].([]string)) {
|
||||
return fmt.Errorf("expected:\n%#v\ngot:\n%#v\n", exp, resp.Data["keys"])
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepUserList(t *testing.T, users []string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ListOperation,
|
||||
Path: "users",
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("Got error response: %#v", *resp)
|
||||
}
|
||||
|
||||
exp := users
|
||||
if !reflect.DeepEqual(exp, resp.Data["keys"].([]string)) {
|
||||
return fmt.Errorf("expected:\n%#v\ngot:\n%#v\n", exp, resp.Data["keys"])
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,19 @@ import (
|
|||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathGroupsList(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "groups/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathGroupList,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathGroupHelpSyn,
|
||||
HelpDescription: pathGroupHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathGroups(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: `groups/(?P<name>.+)`,
|
||||
|
@ -94,6 +107,15 @@ func (b *backend) pathGroupWrite(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathGroupList(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
groups, err := req.Storage.List("group/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logical.ListResponse(groups), nil
|
||||
}
|
||||
|
||||
type GroupEntry struct {
|
||||
Policies []string
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
|
@ -83,7 +84,7 @@ func (b *backend) pathLoginRenew(
|
|||
}
|
||||
|
||||
if !policyutil.EquivalentPolicies(loginPolicies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||
return nil, fmt.Errorf("policies have changed, not renewing")
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(0, 0, b.System())(req, d)
|
||||
|
|
|
@ -7,6 +7,19 @@ import (
|
|||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathUsersList(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "users/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathUserList,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathUserHelpSyn,
|
||||
HelpDescription: pathUserHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathUsers(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: `users/(?P<name>.+)`,
|
||||
|
@ -25,7 +38,7 @@ func pathUsers(b *backend) *framework.Path {
|
|||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.DeleteOperation: b.pathUserDelete,
|
||||
logical.ReadOperation: b.pathUserRead,
|
||||
logical.UpdateOperation: b.pathUserWrite,
|
||||
logical.UpdateOperation: b.pathUserWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathUserHelpSyn,
|
||||
|
@ -99,6 +112,15 @@ func (b *backend) pathUserWrite(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathUserList(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
users, err := req.Storage.List("user/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logical.ListResponse(users), nil
|
||||
}
|
||||
|
||||
type UserEntry struct {
|
||||
Groups []string
|
||||
}
|
||||
|
|
|
@ -94,7 +94,7 @@ func (b *backend) pathLoginRenew(
|
|||
}
|
||||
|
||||
if !policyutil.EquivalentPolicies(user.Policies, req.Auth.Policies) {
|
||||
return logical.ErrorResponse("policies have changed, not renewing"), nil
|
||||
return nil, fmt.Errorf("policies have changed, not renewing")
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(user.TTL, user.MaxTTL, b.System())(req, d)
|
||||
|
|
|
@ -17,12 +17,6 @@ func Backend() *framework.Backend {
|
|||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: []string{
|
||||
"config/*",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigRoot(),
|
||||
pathConfigLease(&b),
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/ec2"
|
||||
"github.com/aws/aws-sdk-go/service/iam"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
|
@ -40,15 +42,21 @@ func TestBackend_basic(t *testing.T) {
|
|||
func TestBackend_basicSTS(t *testing.T) {
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
AcceptanceTest: true,
|
||||
PreCheck: func() { testAccPreCheck(t) },
|
||||
Backend: getBackend(t),
|
||||
PreCheck: func() {
|
||||
testAccPreCheck(t)
|
||||
createRole(t)
|
||||
},
|
||||
Backend: getBackend(t),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t),
|
||||
testAccStepWritePolicy(t, "test", testPolicy),
|
||||
testAccStepReadSTS(t, "test"),
|
||||
testAccStepWriteArnPolicyRef(t, "test", testPolicyArn),
|
||||
testAccStepReadSTSWithArnPolicy(t, "test"),
|
||||
testAccStepWriteArnRoleRef(t, testRoleName),
|
||||
testAccStepReadSTS(t, testRoleName),
|
||||
},
|
||||
Teardown: teardown,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -84,6 +92,123 @@ func testAccPreCheck(t *testing.T) {
|
|||
log.Println("[INFO] Test: Using us-west-2 as test region")
|
||||
os.Setenv("AWS_DEFAULT_REGION", "us-west-2")
|
||||
}
|
||||
|
||||
if v := os.Getenv("AWS_ACCOUNT_ID"); v == "" {
|
||||
accountId, err := getAccountId()
|
||||
if err != nil {
|
||||
t.Fatal("AWS_ACCOUNT_ID could not be read from iam:GetUser for acceptance tests")
|
||||
}
|
||||
log.Printf("[INFO] Test: Used %s as AWS_ACCOUNT_ID", accountId)
|
||||
os.Setenv("AWS_ACCOUNT_ID", accountId)
|
||||
}
|
||||
}
|
||||
|
||||
func getAccountId() (string, error) {
|
||||
creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"),
|
||||
os.Getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
"")
|
||||
|
||||
awsConfig := &aws.Config{
|
||||
Credentials: creds,
|
||||
Region: aws.String("us-east-1"),
|
||||
HTTPClient: cleanhttp.DefaultClient(),
|
||||
}
|
||||
svc := iam.New(session.New(awsConfig))
|
||||
|
||||
params := &iam.GetUserInput{}
|
||||
res, err := svc.GetUser(params)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// split "arn:aws:iam::012345678912:user/username"
|
||||
accountId := strings.Split(*res.User.Arn, ":")[4]
|
||||
return accountId, nil
|
||||
}
|
||||
|
||||
const testRoleName = "Vault-Acceptance-Test-AWS-Assume-Role"
|
||||
|
||||
func createRole(t *testing.T) {
|
||||
const testRoleAssumePolicy = `{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect":"Allow",
|
||||
"Principal": {
|
||||
"AWS": "arn:aws:iam::%s:root"
|
||||
},
|
||||
"Action": "sts:AssumeRole"
|
||||
}
|
||||
]
|
||||
}
|
||||
`
|
||||
creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), "")
|
||||
|
||||
awsConfig := &aws.Config{
|
||||
Credentials: creds,
|
||||
Region: aws.String("us-east-1"),
|
||||
HTTPClient: cleanhttp.DefaultClient(),
|
||||
}
|
||||
svc := iam.New(session.New(awsConfig))
|
||||
trustPolicy := fmt.Sprintf(testRoleAssumePolicy, os.Getenv("AWS_ACCOUNT_ID"))
|
||||
|
||||
params := &iam.CreateRoleInput{
|
||||
AssumeRolePolicyDocument: aws.String(trustPolicy),
|
||||
RoleName: aws.String(testRoleName),
|
||||
Path: aws.String("/"),
|
||||
}
|
||||
|
||||
log.Printf("[INFO] AWS CreateRole: %s", testRoleName)
|
||||
_, err := svc.CreateRole(params)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("AWS CreateRole failed: %v", err)
|
||||
}
|
||||
|
||||
attachment := &iam.AttachRolePolicyInput{
|
||||
PolicyArn: aws.String(testPolicyArn),
|
||||
RoleName: aws.String(testRoleName), // Required
|
||||
}
|
||||
_, err = svc.AttachRolePolicy(attachment)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("AWS CreateRole failed: %v", err)
|
||||
}
|
||||
|
||||
// Sleep sometime because AWS is eventually consistent
|
||||
log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...")
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
|
||||
func teardown() error {
|
||||
creds := credentials.NewStaticCredentials(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), "")
|
||||
|
||||
awsConfig := &aws.Config{
|
||||
Credentials: creds,
|
||||
Region: aws.String("us-east-1"),
|
||||
HTTPClient: cleanhttp.DefaultClient(),
|
||||
}
|
||||
svc := iam.New(session.New(awsConfig))
|
||||
|
||||
attachment := &iam.DetachRolePolicyInput{
|
||||
PolicyArn: aws.String(testPolicyArn),
|
||||
RoleName: aws.String(testRoleName), // Required
|
||||
}
|
||||
_, err := svc.DetachRolePolicy(attachment)
|
||||
|
||||
params := &iam.DeleteRoleInput{
|
||||
RoleName: aws.String(testRoleName),
|
||||
}
|
||||
|
||||
log.Printf("[INFO] AWS DeleteRole: %s", testRoleName)
|
||||
_, err = svc.DeleteRole(params)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[WARN] AWS DeleteRole failed: %v", err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func testAccStepConfig(t *testing.T) logicaltest.TestStep {
|
||||
|
@ -178,7 +303,7 @@ func testAccStepReadSTSWithArnPolicy(t *testing.T, name string) logicaltest.Test
|
|||
ErrorOk: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["error"] !=
|
||||
"Can't generate STS credentials for a managed policy; use an inline policy instead" {
|
||||
"Can't generate STS credentials for a managed policy; use a role to assume or an inline policy instead" {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
return nil
|
||||
|
@ -317,3 +442,13 @@ func testAccStepReadArnPolicy(t *testing.T, name string, value string) logicalte
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepWriteArnRoleRef(t *testing.T, roleName string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/" + roleName,
|
||||
Data: map[string]interface{}{
|
||||
"arn": fmt.Sprintf("arn:aws:iam::%s:role/%s", os.Getenv("AWS_ACCOUNT_ID"), roleName),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,15 +48,23 @@ func (b *backend) pathSTSRead(
|
|||
}
|
||||
policyValue := string(policy.Value)
|
||||
if strings.HasPrefix(policyValue, "arn:") {
|
||||
return logical.ErrorResponse(
|
||||
"Can't generate STS credentials for a managed policy; use an inline policy instead"),
|
||||
logical.ErrInvalidRequest
|
||||
if strings.Contains(policyValue, ":role/") {
|
||||
return b.assumeRole(
|
||||
req.Storage,
|
||||
req.DisplayName, policyName, policyValue,
|
||||
ttl,
|
||||
)
|
||||
} else {
|
||||
return logical.ErrorResponse(
|
||||
"Can't generate STS credentials for a managed policy; use a role to assume or an inline policy instead"),
|
||||
logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
// Use the helper to create the secret
|
||||
return b.secretTokenCreate(
|
||||
req.Storage,
|
||||
req.DisplayName, policyName, policyValue,
|
||||
&ttl,
|
||||
ttl,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -71,7 +71,7 @@ func genUsername(displayName, policyName, userType string) (ret string, warning
|
|||
|
||||
func (b *backend) secretTokenCreate(s logical.Storage,
|
||||
displayName, policyName, policy string,
|
||||
lifeTimeInSeconds *int64) (*logical.Response, error) {
|
||||
lifeTimeInSeconds int64) (*logical.Response, error) {
|
||||
STSClient, err := clientSTS(s)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
|
@ -83,7 +83,7 @@ func (b *backend) secretTokenCreate(s logical.Storage,
|
|||
&sts.GetFederationTokenInput{
|
||||
Name: aws.String(username),
|
||||
Policy: aws.String(policy),
|
||||
DurationSeconds: lifeTimeInSeconds,
|
||||
DurationSeconds: &lifeTimeInSeconds,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -111,6 +111,48 @@ func (b *backend) secretTokenCreate(s logical.Storage,
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) assumeRole(s logical.Storage,
|
||||
displayName, policyName, policy string,
|
||||
lifeTimeInSeconds int64) (*logical.Response, error) {
|
||||
STSClient, err := clientSTS(s)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
username, usernameWarning := genUsername(displayName, policyName, "iam_user")
|
||||
|
||||
tokenResp, err := STSClient.AssumeRole(
|
||||
&sts.AssumeRoleInput{
|
||||
RoleSessionName: aws.String(username),
|
||||
RoleArn: aws.String(policy),
|
||||
DurationSeconds: &lifeTimeInSeconds,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error assuming role: %s", err)), nil
|
||||
}
|
||||
|
||||
resp := b.Secret(SecretAccessKeyType).Response(map[string]interface{}{
|
||||
"access_key": *tokenResp.Credentials.AccessKeyId,
|
||||
"secret_key": *tokenResp.Credentials.SecretAccessKey,
|
||||
"security_token": *tokenResp.Credentials.SessionToken,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
"policy": policy,
|
||||
"is_sts": true,
|
||||
})
|
||||
|
||||
// Set the secret TTL to appropriately match the expiration of the token
|
||||
resp.Secret.TTL = tokenResp.Credentials.Expiration.Sub(time.Now())
|
||||
|
||||
if usernameWarning != "" {
|
||||
resp.AddWarning(usernameWarning)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) secretAccessKeysCreate(
|
||||
s logical.Storage,
|
||||
displayName, policyName string, policy string) (*logical.Response, error) {
|
||||
|
|
|
@ -9,15 +9,9 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
|||
return Backend().Setup(conf)
|
||||
}
|
||||
|
||||
func Backend() *framework.Backend {
|
||||
func Backend() *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: []string{
|
||||
"config/*",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigAccess(),
|
||||
pathRoles(),
|
||||
|
@ -29,7 +23,7 @@ func Backend() *framework.Backend {
|
|||
},
|
||||
}
|
||||
|
||||
return b.Backend
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -18,6 +19,55 @@ import (
|
|||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func TestBackend_config_access(t *testing.T) {
|
||||
if os.Getenv(logicaltest.TestEnvVar) == "" {
|
||||
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
|
||||
return
|
||||
}
|
||||
|
||||
accessConfig, process := testStartConsulServer(t)
|
||||
defer testStopConsulServer(t, process)
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b := Backend()
|
||||
_, err := b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
confReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/access",
|
||||
Storage: storage,
|
||||
Data: accessConfig,
|
||||
}
|
||||
|
||||
resp, err := b.HandleRequest(confReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
|
||||
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
confReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(confReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("failed to write configuration: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
expected := map[string]interface{}{
|
||||
"address": "127.0.0.1:8500",
|
||||
"scheme": "http",
|
||||
}
|
||||
if !reflect.DeepEqual(expected, resp.Data) {
|
||||
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
|
||||
}
|
||||
if resp.Data["token"] != nil {
|
||||
t.Fatalf("token should not be set in the response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
if os.Getenv(logicaltest.TestEnvVar) == "" {
|
||||
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
|
||||
|
@ -40,6 +90,106 @@ func TestBackend_basic(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestBackend_renew_revoke(t *testing.T) {
|
||||
if os.Getenv(logicaltest.TestEnvVar) == "" {
|
||||
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
|
||||
return
|
||||
}
|
||||
|
||||
config, process := testStartConsulServer(t)
|
||||
defer testStopConsulServer(t, process)
|
||||
|
||||
beConfig := logical.TestBackendConfig()
|
||||
beConfig.StorageView = &logical.InmemStorage{}
|
||||
b, _ := Factory(beConfig)
|
||||
|
||||
req := &logical.Request{
|
||||
Storage: beConfig.StorageView,
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/access",
|
||||
Data: config,
|
||||
}
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Path = "roles/test"
|
||||
req.Data = map[string]interface{}{
|
||||
"policy": base64.StdEncoding.EncodeToString([]byte(testPolicy)),
|
||||
"lease": "6h",
|
||||
}
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Operation = logical.ReadOperation
|
||||
req.Path = "creds/test"
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil || resp.IsError() {
|
||||
t.Fatal("resp nil or error")
|
||||
}
|
||||
|
||||
generatedSecret := resp.Secret
|
||||
generatedSecret.IssueTime = time.Now()
|
||||
generatedSecret.TTL = 6 * time.Hour
|
||||
|
||||
var d struct {
|
||||
Token string `mapstructure:"token"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
log.Printf("[WARN] Generated token: %s", d.Token)
|
||||
|
||||
// Build a client and verify that the credentials work
|
||||
apiConfig := api.DefaultConfig()
|
||||
apiConfig.Address = config["address"].(string)
|
||||
apiConfig.Token = d.Token
|
||||
client, err := api.NewClient(apiConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
log.Printf("[WARN] Verifying that the generated token works...")
|
||||
_, err = client.KV().Put(&api.KVPair{
|
||||
Key: "foo",
|
||||
Value: []byte("bar"),
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Operation = logical.RenewOperation
|
||||
req.Secret = generatedSecret
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("got nil response from renew")
|
||||
}
|
||||
|
||||
req.Operation = logical.RevokeOperation
|
||||
resp, err = b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
log.Printf("[WARN] Verifying that the generated token does not work...")
|
||||
_, err = client.KV().Put(&api.KVPair{
|
||||
Key: "foo",
|
||||
Value: []byte("bar"),
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_management(t *testing.T) {
|
||||
if os.Getenv(logicaltest.TestEnvVar) == "" {
|
||||
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
|
||||
|
|
|
@ -7,26 +7,23 @@ import (
|
|||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func client(s logical.Storage) (*api.Client, error) {
|
||||
entry, err := s.Get("config/access")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func client(s logical.Storage) (*api.Client, error, error) {
|
||||
conf, userErr, intErr := readConfigAccess(s)
|
||||
if intErr != nil {
|
||||
return nil, nil, intErr
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, fmt.Errorf(
|
||||
"root credentials haven't been configured. Please configure\n" +
|
||||
"them at the '/root' endpoint")
|
||||
if userErr != nil {
|
||||
return nil, userErr, nil
|
||||
}
|
||||
if conf == nil {
|
||||
return nil, nil, fmt.Errorf("no error received but no configuration found")
|
||||
}
|
||||
|
||||
var conf accessConfig
|
||||
if err := entry.DecodeJSON(&conf); err != nil {
|
||||
return nil, fmt.Errorf("error reading root configuration: %s", err)
|
||||
}
|
||||
|
||||
consulConf := api.DefaultConfig()
|
||||
consulConf := api.DefaultNonPooledConfig()
|
||||
consulConf.Address = conf.Address
|
||||
consulConf.Scheme = conf.Scheme
|
||||
consulConf.Token = conf.Token
|
||||
|
||||
return api.NewClient(consulConf)
|
||||
client, err := api.NewClient(consulConf)
|
||||
return client, nil, err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -31,11 +33,52 @@ func pathConfigAccess() *framework.Path {
|
|||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: pathConfigAccessRead,
|
||||
logical.UpdateOperation: pathConfigAccessWrite,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func readConfigAccess(storage logical.Storage) (*accessConfig, error, error) {
|
||||
entry, err := storage.Get("config/access")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, fmt.Errorf(
|
||||
"Access credentials for the backend itself haven't been configured. Please configure them at the '/config/access' endpoint"),
|
||||
nil
|
||||
}
|
||||
|
||||
conf := &accessConfig{}
|
||||
if err := entry.DecodeJSON(conf); err != nil {
|
||||
return nil, nil, fmt.Errorf("error reading consul access configuration: %s", err)
|
||||
}
|
||||
|
||||
return conf, nil, nil
|
||||
}
|
||||
|
||||
func pathConfigAccessRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
conf, userErr, intErr := readConfigAccess(req.Storage)
|
||||
if intErr != nil {
|
||||
return nil, intErr
|
||||
}
|
||||
if userErr != nil {
|
||||
return logical.ErrorResponse(userErr.Error()), nil
|
||||
}
|
||||
if conf == nil {
|
||||
return nil, fmt.Errorf("no user error reported but consul access configuration not found")
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"address": conf.Address,
|
||||
"scheme": conf.Scheme,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func pathConfigAccessWrite(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := logical.StorageEntryJSON("config/access", accessConfig{
|
||||
|
|
|
@ -47,8 +47,11 @@ func (b *backend) pathTokenRead(
|
|||
}
|
||||
|
||||
// Get the consul client
|
||||
c, err := client(req.Storage)
|
||||
if err != nil {
|
||||
c, userErr, intErr := client(req.Storage)
|
||||
if intErr != nil {
|
||||
return nil, intErr
|
||||
}
|
||||
if userErr != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
|
@ -67,7 +70,9 @@ func (b *backend) pathTokenRead(
|
|||
// Use the helper to create the secret
|
||||
s := b.Secret(SecretTokenType).Response(map[string]interface{}{
|
||||
"token": token,
|
||||
}, nil)
|
||||
}, map[string]interface{}{
|
||||
"token": token,
|
||||
})
|
||||
s.Secret.TTL = result.Lease
|
||||
|
||||
return s, nil
|
||||
|
|
|
@ -32,14 +32,27 @@ func (b *backend) secretTokenRenew(
|
|||
|
||||
func secretTokenRevoke(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
c, err := client(req.Storage)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
c, userErr, intErr := client(req.Storage)
|
||||
if intErr != nil {
|
||||
return nil, intErr
|
||||
}
|
||||
if userErr != nil {
|
||||
// Returning logical.ErrorResponse from revocation function is risky
|
||||
return nil, userErr
|
||||
}
|
||||
|
||||
_, err = c.ACL().Destroy(d.Get("token").(string), nil)
|
||||
tokenRaw, ok := req.Secret.InternalData["token"]
|
||||
if !ok {
|
||||
// We return nil here because this is a pre-0.5.3 problem and there is
|
||||
// nothing we can do about it. We already can't revoke the lease
|
||||
// properly if it has been renewed and this is documented pre-0.5.3
|
||||
// behavior with a security bulletin about it.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, err := c.ACL().Destroy(tokenRaw.(string), nil)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
|
|
|
@ -141,12 +141,10 @@ func (b *backend) secretCredsRevoke(
|
|||
|
||||
// can't drop if not all database users are dropped
|
||||
if rows.Err() != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"could not generate sql statements for all rows: %v", rows.Err())), nil
|
||||
return nil, fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"could not perform all sql statements: %v", lastStmtError)), nil
|
||||
return nil, fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this login
|
||||
|
|
|
@ -20,12 +20,6 @@ func Backend() *framework.Backend {
|
|||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: []string{
|
||||
"config/*",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
|
|
|
@ -271,6 +271,15 @@ func (b *backend) pathRoleRead(
|
|||
Data: structs.New(role).Map(),
|
||||
}
|
||||
|
||||
if resp.Data == nil {
|
||||
return nil, fmt.Errorf("error converting role data to response")
|
||||
}
|
||||
|
||||
// These values are deprecated and the entries are migrated on read
|
||||
delete(resp.Data, "lease")
|
||||
delete(resp.Data, "lease_max")
|
||||
delete(resp.Data, "allowed_base_domain")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -38,12 +38,12 @@ reference`,
|
|||
func (b *backend) secretCredsRevoke(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
if req.Secret == nil {
|
||||
return nil, fmt.Errorf("Secret is nil in request")
|
||||
return nil, fmt.Errorf("secret is nil in request")
|
||||
}
|
||||
|
||||
serialInt, ok := req.Secret.InternalData["serial_number"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not find serial in internal secret data")
|
||||
return nil, fmt.Errorf("could not find serial in internal secret data")
|
||||
}
|
||||
|
||||
serial := strings.Replace(strings.ToLower(serialInt.(string)), "-", ":", -1)
|
||||
|
|
|
@ -19,12 +19,6 @@ func Backend() *framework.Backend {
|
|||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: []string{
|
||||
"config/*",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
|
|
|
@ -171,10 +171,10 @@ func (b *backend) secretCredsRevoke(
|
|||
|
||||
// can't drop if not all privileges are revoked
|
||||
if rows.Err() != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("could not generate revocation statements for all rows: %v", rows.Err())), nil
|
||||
return nil, fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("could not perform all revocation statements: %v", lastStmtError)), nil
|
||||
return nil, fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this user
|
||||
|
|
10
builtin/logical/rabbitmq/README.md
Normal file
10
builtin/logical/rabbitmq/README.md
Normal file
|
@ -0,0 +1,10 @@
|
|||
# RabbitMQ Backend
|
||||
|
||||
## Testing
|
||||
|
||||
There are unit and integration RabbitMQ backend tests. Unit tests can be run by `go test`. Integration tests require setting the following environment variables:
|
||||
```
|
||||
RABBITMQ_CONNECTION_URI=
|
||||
RABBITMQ_USERNAME=
|
||||
RABBITMQ_PASSWORD=
|
||||
```
|
125
builtin/logical/rabbitmq/backend.go
Normal file
125
builtin/logical/rabbitmq/backend.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package rabbitmq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"github.com/michaelklishin/rabbit-hole"
|
||||
)
|
||||
|
||||
// Factory creates and configures the backend
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
return Backend().Setup(conf)
|
||||
}
|
||||
|
||||
// Creates a new backend with all the paths and secrets belonging to it
|
||||
func Backend() *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
pathListRoles(&b),
|
||||
pathCreds(&b),
|
||||
pathRoles(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Clean: b.resetClient,
|
||||
}
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
client *rabbithole.Client
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// DB returns the database connection.
|
||||
func (b *backend) Client(s logical.Storage) (*rabbithole.Client, error) {
|
||||
b.lock.RLock()
|
||||
|
||||
// If we already have a client, return it
|
||||
if b.client != nil {
|
||||
b.lock.RUnlock()
|
||||
return b.client, nil
|
||||
}
|
||||
|
||||
b.lock.RUnlock()
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
entry, err := s.Get("config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, fmt.Errorf("configure the client connection with config/connection first")
|
||||
}
|
||||
|
||||
var connConfig connectionConfig
|
||||
if err := entry.DecodeJSON(&connConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
// If the client was creted during the lock switch, return it
|
||||
if b.client != nil {
|
||||
return b.client, nil
|
||||
}
|
||||
|
||||
b.client, err = rabbithole.NewClient(connConfig.URI, connConfig.Username, connConfig.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Use a default pooled transport so there would be no leaked file descriptors
|
||||
b.client.SetTransport(cleanhttp.DefaultPooledTransport())
|
||||
|
||||
return b.client, nil
|
||||
}
|
||||
|
||||
// resetClient forces a connection next time Client() is called.
|
||||
func (b *backend) resetClient() {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
b.client = nil
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result configLease
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The RabbitMQ backend dynamically generates RabbitMQ users.
|
||||
|
||||
After mounting this backend, configure it using the endpoints within
|
||||
the "config/" path.
|
||||
`
|
218
builtin/logical/rabbitmq/backend_test.go
Normal file
218
builtin/logical/rabbitmq/backend_test.go
Normal file
|
@ -0,0 +1,218 @@
|
|||
package rabbitmq
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
"github.com/michaelklishin/rabbit-hole"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
if os.Getenv(logicaltest.TestEnvVar) == "" {
|
||||
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
|
||||
return
|
||||
}
|
||||
b, _ := Factory(logical.TestBackendConfig())
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
PreCheck: func() { testAccPreCheck(t) },
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t),
|
||||
testAccStepRole(t),
|
||||
testAccStepReadCreds(t, b, "web"),
|
||||
},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
if os.Getenv(logicaltest.TestEnvVar) == "" {
|
||||
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
|
||||
return
|
||||
}
|
||||
b, _ := Factory(logical.TestBackendConfig())
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
PreCheck: func() { testAccPreCheck(t) },
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t),
|
||||
testAccStepRole(t),
|
||||
testAccStepReadRole(t, "web", "administrator", `{"/": {"configure": ".*", "write": ".*", "read": ".*"}}`),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
testAccStepReadRole(t, "web", "", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
const (
|
||||
uriEnv = "RABBITMQ_CONNECTION_URI"
|
||||
usernameEnv = "RABBITMQ_USERNAME"
|
||||
passwordEnv = "RABBITMQ_PASSWORD"
|
||||
)
|
||||
|
||||
func mustSet(name string) string {
|
||||
return fmt.Sprintf("%s must be set for acceptance tests", name)
|
||||
}
|
||||
|
||||
func testAccPreCheck(t *testing.T) {
|
||||
if uri := os.Getenv(uriEnv); uri == "" {
|
||||
t.Fatal(mustSet(uriEnv))
|
||||
}
|
||||
if username := os.Getenv(usernameEnv); username == "" {
|
||||
t.Fatal(mustSet(usernameEnv))
|
||||
}
|
||||
if password := os.Getenv(passwordEnv); password == "" {
|
||||
t.Fatal(mustSet(passwordEnv))
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepConfig(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: map[string]interface{}{
|
||||
"connection_uri": os.Getenv(uriEnv),
|
||||
"username": os.Getenv(usernameEnv),
|
||||
"password": os.Getenv(passwordEnv),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepRole(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/web",
|
||||
Data: map[string]interface{}{
|
||||
"tags": "administrator",
|
||||
"vhosts": `{"/": {"configure": ".*", "write": ".*", "read": ".*"}}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "roles/" + n,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadCreds(t *testing.T, b logical.Backend, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] Generated credentials: %v", d)
|
||||
|
||||
uri := os.Getenv(uriEnv)
|
||||
|
||||
client, err := rabbithole.NewClient(uri, d.Username, d.Password)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.ListVhosts()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to list vhosts with generated credentials: %s", err)
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Operation: logical.RevokeOperation,
|
||||
Secret: &logical.Secret{
|
||||
InternalData: map[string]interface{}{
|
||||
"secret_type": "creds",
|
||||
"username": d.Username,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("Error on resp: %#v", *resp)
|
||||
}
|
||||
}
|
||||
|
||||
client, err = rabbithole.NewClient(uri, d.Username, d.Password)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.ListVhosts()
|
||||
if err == nil {
|
||||
t.Fatalf("expected to fail listing vhosts: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadRole(t *testing.T, name, tags, rawVHosts string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
if tags == "" && rawVHosts == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
var d struct {
|
||||
Tags string `mapstructure:"tags"`
|
||||
VHosts map[string]vhostPermission `mapstructure:"vhosts"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.Tags != tags {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
var vhosts map[string]vhostPermission
|
||||
if err := json.Unmarshal([]byte(rawVHosts), &vhosts); err != nil {
|
||||
return fmt.Errorf("bad expected vhosts %#v: %s", vhosts, err)
|
||||
}
|
||||
|
||||
for host, permission := range vhosts {
|
||||
actualPermission, ok := d.VHosts[host]
|
||||
if !ok {
|
||||
return fmt.Errorf("expected vhost: %s", host)
|
||||
}
|
||||
|
||||
if actualPermission.Configure != permission.Configure {
|
||||
return fmt.Errorf("expected permission %s to be %s, got %s", "configure", permission.Configure, actualPermission.Configure)
|
||||
}
|
||||
|
||||
if actualPermission.Write != permission.Write {
|
||||
return fmt.Errorf("expected permission %s to be %s, got %s", "write", permission.Write, actualPermission.Write)
|
||||
}
|
||||
|
||||
if actualPermission.Read != permission.Read {
|
||||
return fmt.Errorf("expected permission %s to be %s, got %s", "read", permission.Read, actualPermission.Read)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
118
builtin/logical/rabbitmq/path_config_connection.go
Normal file
118
builtin/logical/rabbitmq/path_config_connection.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
package rabbitmq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"github.com/michaelklishin/rabbit-hole"
|
||||
)
|
||||
|
||||
func pathConfigConnection(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/connection",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"connection_uri": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "RabbitMQ Management URI",
|
||||
},
|
||||
"username": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Username of a RabbitMQ management administrator",
|
||||
},
|
||||
"password": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Password of the provided RabbitMQ management user",
|
||||
},
|
||||
"verify_connection": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: `If set, connection_uri is verified by actually connecting to the RabbitMQ management API`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.pathConnectionUpdate,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigConnectionHelpSyn,
|
||||
HelpDescription: pathConfigConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathConnectionUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
uri := data.Get("connection_uri").(string)
|
||||
if uri == "" {
|
||||
return logical.ErrorResponse("missing connection_uri"), nil
|
||||
}
|
||||
|
||||
username := data.Get("username").(string)
|
||||
if username == "" {
|
||||
return logical.ErrorResponse("missing username"), nil
|
||||
}
|
||||
|
||||
password := data.Get("password").(string)
|
||||
if password == "" {
|
||||
return logical.ErrorResponse("missing password"), nil
|
||||
}
|
||||
|
||||
// Don't check the connection_url if verification is disabled
|
||||
verifyConnection := data.Get("verify_connection").(bool)
|
||||
if verifyConnection {
|
||||
// Create RabbitMQ management client
|
||||
client, err := rabbithole.NewClient(uri, username, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create client: %s", err)
|
||||
}
|
||||
|
||||
// Verify that configured credentials is capable of listing
|
||||
if _, err = client.ListUsers(); err != nil {
|
||||
return nil, fmt.Errorf("failed to validate the connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
|
||||
URI: uri,
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the client connection
|
||||
b.resetClient()
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// connectionConfig contains the information required to make a connection to a RabbitMQ node
|
||||
type connectionConfig struct {
|
||||
// URI of the RabbitMQ server
|
||||
URI string `json:"connection_uri"`
|
||||
|
||||
// Username which has 'administrator' tag attached to it
|
||||
Username string `json:"username"`
|
||||
|
||||
// Password for the Username
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure the connection URI, username, and password to talk to RabbitMQ management HTTP API.
|
||||
`
|
||||
|
||||
const pathConfigConnectionHelpDesc = `
|
||||
This path configures the connection properties used to connect to RabbitMQ management HTTP API.
|
||||
The "connection_uri" parameter is a string that is used to connect to the API. The "username"
|
||||
and "password" parameters are strings that are used as credentials to the API. The "verify_connection"
|
||||
parameter is a boolean that is used to verify whether the provided connection URI, username, and password
|
||||
are valid.
|
||||
|
||||
The URI looks like:
|
||||
"http://localhost:15672"
|
||||
`
|
83
builtin/logical/rabbitmq/path_config_lease.go
Normal file
83
builtin/logical/rabbitmq/path_config_lease.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package rabbitmq
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathConfigLease(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/lease",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Default: 0,
|
||||
Description: "Duration before which the issued credentials needs renewal",
|
||||
},
|
||||
"max_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Default: 0,
|
||||
Description: `Duration after which the issued credentials should not be allowed to be renewed`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathLeaseRead,
|
||||
logical.UpdateOperation: b.pathLeaseUpdate,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
||||
HelpDescription: pathConfigLeaseHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the lease configuration parameters
|
||||
func (b *backend) pathLeaseUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
||||
TTL: time.Second * time.Duration(d.Get("ttl").(int)),
|
||||
MaxTTL: time.Second * time.Duration(d.Get("max_ttl").(int)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Returns the lease configuration parameters
|
||||
func (b *backend) pathLeaseRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
lease, err := b.Lease(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lease.TTL = lease.TTL / time.Second
|
||||
lease.MaxTTL = lease.MaxTTL / time.Second
|
||||
|
||||
return &logical.Response{
|
||||
Data: structs.New(lease).Map(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lease configuration information for the secrets issued by this backend
|
||||
type configLease struct {
|
||||
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" structs:"max_ttl" mapstructure:"max_ttl"`
|
||||
}
|
||||
|
||||
var pathConfigLeaseHelpSyn = "Configure the lease parameters for generated credentials"
|
||||
|
||||
var pathConfigLeaseHelpDesc = `
|
||||
Sets the ttl and max_ttl values for the secrets to be issued by this backend.
|
||||
Both ttl and max_ttl takes in an integer number of seconds as input as well as
|
||||
inputs like "1h".
|
||||
`
|
53
builtin/logical/rabbitmq/path_config_lease_test.go
Normal file
53
builtin/logical/rabbitmq/path_config_lease_test.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package rabbitmq
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func TestBackend_config_lease_RU(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b := Backend()
|
||||
if _, err = b.Setup(config); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configData := map[string]interface{}{
|
||||
"ttl": "10h",
|
||||
"max_ttl": "20h",
|
||||
}
|
||||
configReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/lease",
|
||||
Storage: config.StorageView,
|
||||
Data: configData,
|
||||
}
|
||||
resp, err = b.HandleRequest(configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr:%s", resp, err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected a nil response")
|
||||
}
|
||||
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("bad: resp: %#v\nerr:%s", resp, err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected a response")
|
||||
}
|
||||
|
||||
if resp.Data["ttl"].(time.Duration) != 36000 {
|
||||
t.Fatalf("bad: ttl: expected:36000 actual:%d", resp.Data["ttl"].(time.Duration))
|
||||
}
|
||||
if resp.Data["max_ttl"].(time.Duration) != 72000 {
|
||||
t.Fatalf("bad: ttl: expected:72000 actual:%d", resp.Data["ttl"].(time.Duration))
|
||||
}
|
||||
}
|
120
builtin/logical/rabbitmq/path_role_create.go
Normal file
120
builtin/logical/rabbitmq/path_role_create.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package rabbitmq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
"github.com/michaelklishin/rabbit-hole"
|
||||
)
|
||||
|
||||
func pathCreds(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathCredsRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleCreateReadHelpSyn,
|
||||
HelpDescription: pathRoleCreateReadHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// Issues the credential based on the role name
|
||||
func (b *backend) pathCredsRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse("missing name"), nil
|
||||
}
|
||||
|
||||
// Get the role
|
||||
role, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
// Ensure username is unique
|
||||
uuidVal, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s", req.DisplayName, uuidVal)
|
||||
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the client configuration
|
||||
client, err := b.Client(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if client == nil {
|
||||
return logical.ErrorResponse("failed to get the client"), nil
|
||||
}
|
||||
|
||||
// Register the generated credentials in the backend, with the RabbitMQ server
|
||||
if _, err = client.PutUser(username, rabbithole.UserSettings{
|
||||
Password: password,
|
||||
Tags: role.Tags,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("failed to create a new user with the generated credentials")
|
||||
}
|
||||
|
||||
// If the role had vhost permissions specified, assign those permissions
|
||||
// to the created username for respective vhosts.
|
||||
for vhost, permission := range role.VHosts {
|
||||
if _, err := client.UpdatePermissionsIn(vhost, username, rabbithole.Permissions{
|
||||
Configure: permission.Configure,
|
||||
Write: permission.Write,
|
||||
Read: permission.Read,
|
||||
}); err != nil {
|
||||
// Delete the user because it's in an unknown state
|
||||
if _, rmErr := client.DeleteUser(username); rmErr != nil {
|
||||
return nil, fmt.Errorf("failed to delete user:%s, err: %s. %s", username, err, rmErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to update permissions to the %s user. err:%s", username, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Return the secret
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
})
|
||||
|
||||
// Determine if we have a lease
|
||||
lease, err := b.Lease(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease != nil {
|
||||
resp.Secret.TTL = lease.TTL
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathRoleCreateReadHelpSyn = `
|
||||
Request RabbitMQ credentials for a certain role.
|
||||
`
|
||||
|
||||
const pathRoleCreateReadHelpDesc = `
|
||||
This path reads RabbitMQ credentials for a certain role. The
|
||||
RabbitMQ credentials will be generated on demand and will be automatically
|
||||
revoked when the lease is up.
|
||||
`
|
182
builtin/logical/rabbitmq/path_roles.go
Normal file
182
builtin/logical/rabbitmq/path_roles.go
Normal file
|
@ -0,0 +1,182 @@
|
|||
package rabbitmq
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathListRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/?$",
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathRoleList,
|
||||
},
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
"tags": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Comma-separated list of tags for this role.",
|
||||
},
|
||||
"vhosts": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "A map of virtual hosts to permissions.",
|
||||
},
|
||||
},
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleRead,
|
||||
logical.UpdateOperation: b.pathRoleUpdate,
|
||||
logical.DeleteOperation: b.pathRoleDelete,
|
||||
},
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// Reads the role configuration from the storage
|
||||
func (b *backend) Role(s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get("role/" + n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Deletes an existing role
|
||||
func (b *backend) pathRoleDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse("missing name"), nil
|
||||
}
|
||||
|
||||
return nil, req.Storage.Delete("role/" + name)
|
||||
}
|
||||
|
||||
// Reads an existing role
|
||||
func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse("missing name"), nil
|
||||
}
|
||||
|
||||
role, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: structs.New(role).Map(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lists all the roles registered with the backend
|
||||
func (b *backend) pathRoleList(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
roles, err := req.Storage.List("role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(roles), nil
|
||||
}
|
||||
|
||||
// Registers a new role with the backend
|
||||
func (b *backend) pathRoleUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse("missing name"), nil
|
||||
}
|
||||
|
||||
tags := d.Get("tags").(string)
|
||||
rawVHosts := d.Get("vhosts").(string)
|
||||
|
||||
if tags == "" && rawVHosts == "" {
|
||||
return logical.ErrorResponse("both tags and vhosts not specified"), nil
|
||||
}
|
||||
|
||||
var vhosts map[string]vhostPermission
|
||||
if len(rawVHosts) > 0 {
|
||||
err := json.Unmarshal([]byte(rawVHosts), &vhosts)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to unmarshal vhosts: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
|
||||
Tags: tags,
|
||||
VHosts: vhosts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Role that defines the capabilities of the credentials issued against it
|
||||
type roleEntry struct {
|
||||
Tags string `json:"tags" structs:"tags" mapstructure:"tags"`
|
||||
VHosts map[string]vhostPermission `json:"vhosts" structs:"vhosts" mapstructure:"vhosts"`
|
||||
}
|
||||
|
||||
// Structure representing the permissions of a vhost
|
||||
type vhostPermission struct {
|
||||
Configure string `json:"configure" structs:"configure" mapstructure:"configure"`
|
||||
Write string `json:"write" structs:"write" mapstructure:"write"`
|
||||
Read string `json:"read" structs:"read" mapstructure:"read"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
Manage the roles that can be created with this backend.
|
||||
`
|
||||
|
||||
const pathRoleHelpDesc = `
|
||||
This path lets you manage the roles that can be created with this backend.
|
||||
|
||||
The "tags" parameter customizes the tags used to create the role.
|
||||
This is a comma separated list of strings. The "vhosts" parameter customizes
|
||||
the virtual hosts that this user will be associated with. This is a JSON object
|
||||
passed as a string in the form:
|
||||
{
|
||||
"vhostOne": {
|
||||
"configure": ".*",
|
||||
"write": ".*",
|
||||
"read": ".*"
|
||||
},
|
||||
"vhostTwo": {
|
||||
"configure": ".*",
|
||||
"write": ".*",
|
||||
"read": ".*"
|
||||
}
|
||||
}
|
||||
`
|
1
builtin/logical/rabbitmq/path_roles_test.go
Normal file
1
builtin/logical/rabbitmq/path_roles_test.go
Normal file
|
@ -0,0 +1 @@
|
|||
package rabbitmq
|
67
builtin/logical/rabbitmq/secret_creds.go
Normal file
67
builtin/logical/rabbitmq/secret_creds.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package rabbitmq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
// SecretCredsType is the key for this backend's secrets.
|
||||
const SecretCredsType = "creds"
|
||||
|
||||
func secretCreds(b *backend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretCredsType,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"username": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "RabbitMQ username",
|
||||
},
|
||||
"password": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Password for the RabbitMQ username",
|
||||
},
|
||||
},
|
||||
Renew: b.secretCredsRenew,
|
||||
Revoke: b.secretCredsRevoke,
|
||||
}
|
||||
}
|
||||
|
||||
// Renew the previously issued secret
|
||||
func (b *backend) secretCredsRenew(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the lease information
|
||||
lease, err := b.Lease(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
lease = &configLease{}
|
||||
}
|
||||
|
||||
return framework.LeaseExtend(lease.TTL, lease.MaxTTL, b.System())(req, d)
|
||||
}
|
||||
|
||||
// Revoke the previously issued secret
|
||||
func (b *backend) secretCredsRevoke(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username := usernameRaw.(string)
|
||||
|
||||
// Get our connection
|
||||
client, err := b.Client(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = client.DeleteUser(username); err != nil {
|
||||
return nil, fmt.Errorf("could not delete user: %s", err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
|
@ -21,7 +21,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
|||
return b.Setup(conf)
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
salt, err := salt.NewSalt(conf.StorageView, &salt.Config{
|
||||
HashFunc: salt.SHA256Hash,
|
||||
})
|
||||
|
@ -35,10 +35,6 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
|||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: []string{
|
||||
"config/*",
|
||||
"keys/*",
|
||||
},
|
||||
Unauthenticated: []string{
|
||||
"verify",
|
||||
},
|
||||
|
@ -59,7 +55,7 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) {
|
|||
secretOTP(&b),
|
||||
},
|
||||
}
|
||||
return b.Backend, nil
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
|
|
|
@ -61,6 +61,120 @@ oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
|
|||
`
|
||||
)
|
||||
|
||||
func TestBackend_allowed_users(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = b.Setup(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
roleData := map[string]interface{}{
|
||||
"key_type": "otp",
|
||||
"default_user": "ubuntu",
|
||||
"cidr_list": "52.207.235.245/16",
|
||||
"allowed_users": "test",
|
||||
}
|
||||
|
||||
roleReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/role1",
|
||||
Storage: config.StorageView,
|
||||
Data: roleData,
|
||||
}
|
||||
|
||||
resp, err := b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
|
||||
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
credsData := map[string]interface{}{
|
||||
"ip": "52.207.235.245",
|
||||
"username": "ubuntu",
|
||||
}
|
||||
credsReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Storage: config.StorageView,
|
||||
Path: "creds/role1",
|
||||
Data: credsData,
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(credsReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp == nil {
|
||||
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
if resp.Data["key"] == "" ||
|
||||
resp.Data["key_type"] != "otp" ||
|
||||
resp.Data["ip"] != "52.207.235.245" ||
|
||||
resp.Data["username"] != "ubuntu" {
|
||||
t.Fatalf("failed to create credential: resp:%#v", resp)
|
||||
}
|
||||
|
||||
credsData["username"] = "test"
|
||||
resp, err = b.HandleRequest(credsReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp == nil {
|
||||
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
if resp.Data["key"] == "" ||
|
||||
resp.Data["key_type"] != "otp" ||
|
||||
resp.Data["ip"] != "52.207.235.245" ||
|
||||
resp.Data["username"] != "test" {
|
||||
t.Fatalf("failed to create credential: resp:%#v", resp)
|
||||
}
|
||||
|
||||
credsData["username"] = "random"
|
||||
resp, err = b.HandleRequest(credsReq)
|
||||
if err != nil || resp == nil || (resp != nil && !resp.IsError()) {
|
||||
t.Fatalf("expected failure: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
delete(roleData, "allowed_users")
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
|
||||
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
credsData["username"] = "ubuntu"
|
||||
resp, err = b.HandleRequest(credsReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp == nil {
|
||||
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
if resp.Data["key"] == "" ||
|
||||
resp.Data["key_type"] != "otp" ||
|
||||
resp.Data["ip"] != "52.207.235.245" ||
|
||||
resp.Data["username"] != "ubuntu" {
|
||||
t.Fatalf("failed to create credential: resp:%#v", resp)
|
||||
}
|
||||
|
||||
credsData["username"] = "test"
|
||||
resp, err = b.HandleRequest(credsReq)
|
||||
if err != nil || resp == nil || (resp != nil && !resp.IsError()) {
|
||||
t.Fatalf("expected failure: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
roleData["allowed_users"] = "*"
|
||||
resp, err = b.HandleRequest(roleReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp != nil {
|
||||
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(credsReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) || resp == nil {
|
||||
t.Fatalf("failed to create role: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
if resp.Data["key"] == "" ||
|
||||
resp.Data["key_type"] != "otp" ||
|
||||
resp.Data["ip"] != "52.207.235.245" ||
|
||||
resp.Data["username"] != "test" {
|
||||
t.Fatalf("failed to create credential: resp:%#v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func testingFactory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
_, err := vault.StartSSHHostTestServer()
|
||||
if err != nil {
|
||||
|
|
|
@ -79,8 +79,10 @@ func (b *backend) pathCredsCreateWrite(
|
|||
// is the default username in the role. If neither is true, then
|
||||
// that username is not allowed to generate a credential.
|
||||
if err != nil && username != role.DefaultUser {
|
||||
return logical.ErrorResponse("Username is not present is allowed users list."), nil
|
||||
return logical.ErrorResponse("Username is not present is allowed users list"), nil
|
||||
}
|
||||
} else if username != role.DefaultUser {
|
||||
return logical.ErrorResponse("Username has to be either in allowed users list or has to be a default username"), nil
|
||||
}
|
||||
|
||||
// Validate the IP address
|
||||
|
@ -285,12 +287,22 @@ func validateIP(ip, roleName, cidrList, excludeCidrList string, zeroAddressRoles
|
|||
// Checks if the username supplied by the user is present in the list of
|
||||
// allowed users registered which creation of role.
|
||||
func validateUsername(username, allowedUsers string) error {
|
||||
if allowedUsers == "" {
|
||||
return fmt.Errorf("username not in allowed users list")
|
||||
}
|
||||
|
||||
// Role was explicitly configured to allow any username.
|
||||
if allowedUsers == "*" {
|
||||
return nil
|
||||
}
|
||||
|
||||
userList := strings.Split(allowedUsers, ",")
|
||||
for _, user := range userList {
|
||||
if user == username {
|
||||
if strings.TrimSpace(user) == username {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("username not in allowed users list")
|
||||
}
|
||||
|
||||
|
|
|
@ -575,7 +575,7 @@ func TestPolicyFuzzing(t *testing.T) {
|
|||
}
|
||||
|
||||
func testPolicyFuzzingCommon(t *testing.T, be *backend) {
|
||||
storage := &logical.LockingInmemStorage{}
|
||||
storage := &logical.InmemStorage{}
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/hashicorp/vault/version"
|
||||
|
||||
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
|
||||
credAws "github.com/hashicorp/vault/builtin/credential/aws"
|
||||
credAwsEc2 "github.com/hashicorp/vault/builtin/credential/aws-ec2"
|
||||
credCert "github.com/hashicorp/vault/builtin/credential/cert"
|
||||
credGitHub "github.com/hashicorp/vault/builtin/credential/github"
|
||||
credLdap "github.com/hashicorp/vault/builtin/credential/ldap"
|
||||
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/hashicorp/vault/builtin/logical/mysql"
|
||||
"github.com/hashicorp/vault/builtin/logical/pki"
|
||||
"github.com/hashicorp/vault/builtin/logical/postgresql"
|
||||
"github.com/hashicorp/vault/builtin/logical/rabbitmq"
|
||||
"github.com/hashicorp/vault/builtin/logical/ssh"
|
||||
"github.com/hashicorp/vault/builtin/logical/transit"
|
||||
|
||||
|
@ -64,7 +65,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
|
|||
},
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"cert": credCert.Factory,
|
||||
"aws": credAws.Factory,
|
||||
"aws-ec2": credAwsEc2.Factory,
|
||||
"app-id": credAppId.Factory,
|
||||
"github": credGitHub.Factory,
|
||||
"userpass": credUserpass.Factory,
|
||||
|
@ -80,6 +81,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
|
|||
"mssql": mssql.Factory,
|
||||
"mysql": mysql.Factory,
|
||||
"ssh": ssh.Factory,
|
||||
"rabbitmq": rabbitmq.Factory,
|
||||
},
|
||||
ShutdownCh: command.MakeShutdownCh(),
|
||||
SighupCh: command.MakeSighupCh(),
|
||||
|
@ -171,6 +173,12 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
|
|||
}, nil
|
||||
},
|
||||
|
||||
"unwrap": func() (cli.Command, error) {
|
||||
return &command.UnwrapCommand{
|
||||
Meta: *metaPtr,
|
||||
}, nil
|
||||
},
|
||||
|
||||
"list": func() (cli.Command, error) {
|
||||
return &command.ListCommand{
|
||||
Meta: *metaPtr,
|
||||
|
|
|
@ -21,6 +21,7 @@ func HelpFunc(commands map[string]cli.CommandFactory) string {
|
|||
"write": struct{}{},
|
||||
"server": struct{}{},
|
||||
"status": struct{}{},
|
||||
"unwrap": struct{}{},
|
||||
}
|
||||
|
||||
// Determine the maximum key length, and classify based on type
|
||||
|
|
|
@ -129,12 +129,17 @@ func (t TableFormatter) OutputSecret(ui cli.Ui, secret, s *api.Secret) error {
|
|||
|
||||
input = append(input, fmt.Sprintf("Key %s Value", config.Delim))
|
||||
|
||||
input = append(input, fmt.Sprintf("--- %s -----", config.Delim))
|
||||
|
||||
if s.LeaseDuration > 0 {
|
||||
if s.LeaseID != "" {
|
||||
input = append(input, fmt.Sprintf("lease_id %s %s", config.Delim, s.LeaseID))
|
||||
input = append(input, fmt.Sprintf(
|
||||
"lease_duration %s %d", config.Delim, s.LeaseDuration))
|
||||
} else {
|
||||
input = append(input, fmt.Sprintf(
|
||||
"refresh_interval %s %d", config.Delim, s.LeaseDuration))
|
||||
}
|
||||
input = append(input, fmt.Sprintf(
|
||||
"lease_duration %s %d", config.Delim, s.LeaseDuration))
|
||||
if s.LeaseID != "" {
|
||||
input = append(input, fmt.Sprintf(
|
||||
"lease_renewable %s %s", config.Delim, strconv.FormatBool(s.Renewable)))
|
||||
|
@ -152,6 +157,12 @@ func (t TableFormatter) OutputSecret(ui cli.Ui, secret, s *api.Secret) error {
|
|||
}
|
||||
}
|
||||
|
||||
if s.WrapInfo != nil {
|
||||
input = append(input, fmt.Sprintf("wrapping_token: %s %s", config.Delim, s.WrapInfo.Token))
|
||||
input = append(input, fmt.Sprintf("wrapping_token_ttl: %s %d", config.Delim, s.WrapInfo.TTL))
|
||||
input = append(input, fmt.Sprintf("wrapping_token_creation_time: %s %s", config.Delim, s.WrapInfo.CreationTime.String()))
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(s.Data))
|
||||
for k := range s.Data {
|
||||
keys = append(keys, k)
|
||||
|
|
|
@ -3,8 +3,6 @@ package command
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
|
@ -63,23 +61,7 @@ func (c *ReadCommand) Run(args []string) int {
|
|||
|
||||
// Handle single field output
|
||||
if field != "" {
|
||||
if val, ok := secret.Data[field]; ok {
|
||||
// c.Ui.Output() prints a CR character which in this case is
|
||||
// not desired. Since Vault CLI currently only uses BasicUi,
|
||||
// which writes to standard output, os.Stdout is used here to
|
||||
// directly print the message. If mitchellh/cli exposes method
|
||||
// to print without CR, this check needs to be removed.
|
||||
if reflect.TypeOf(c.Ui).String() == "*cli.BasicUi" {
|
||||
fmt.Fprintf(os.Stdout, fmt.Sprintf("%v", val))
|
||||
} else {
|
||||
c.Ui.Output(fmt.Sprintf("%v", val))
|
||||
}
|
||||
return 0
|
||||
} else {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Field %s not present in secret", field))
|
||||
return 1
|
||||
}
|
||||
return PrintRawField(c.Ui, secret, field)
|
||||
}
|
||||
|
||||
return OutputSecret(c.Ui, format, secret)
|
||||
|
|
|
@ -44,6 +44,8 @@ type ServerCommand struct {
|
|||
|
||||
meta.Meta
|
||||
|
||||
logger *log.Logger
|
||||
|
||||
ReloadFuncs map[string][]server.ReloadFunc
|
||||
}
|
||||
|
||||
|
@ -63,11 +65,11 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
if os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") != "" {
|
||||
if os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") != "" && devRootTokenID == "" {
|
||||
devRootTokenID = os.Getenv("VAULT_DEV_ROOT_TOKEN_ID")
|
||||
}
|
||||
|
||||
if os.Getenv("VAULT_DEV_LISTEN_ADDRESS") != "" {
|
||||
if os.Getenv("VAULT_DEV_LISTEN_ADDRESS") != "" && devListenAddress == "" {
|
||||
devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS")
|
||||
}
|
||||
|
||||
|
@ -136,7 +138,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
// Create a logger. We wrap it in a gated writer so that it doesn't
|
||||
// start logging too early.
|
||||
logGate := &gatedwriter.Writer{Writer: os.Stderr}
|
||||
logger := log.New(&logutils.LevelFilter{
|
||||
c.logger = log.New(&logutils.LevelFilter{
|
||||
Levels: []logutils.LogLevel{
|
||||
"TRACE", "DEBUG", "INFO", "WARN", "ERR"},
|
||||
MinLevel: logutils.LogLevel(strings.ToUpper(logLevel)),
|
||||
|
@ -150,7 +152,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
|
||||
// Initialize the backend
|
||||
backend, err := physical.NewBackend(
|
||||
config.Backend.Type, logger, config.Backend.Config)
|
||||
config.Backend.Type, c.logger, config.Backend.Config)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error initializing backend of type %s: %s",
|
||||
|
@ -179,7 +181,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
AuditBackends: c.AuditBackends,
|
||||
CredentialBackends: c.CredentialBackends,
|
||||
LogicalBackends: c.LogicalBackends,
|
||||
Logger: logger,
|
||||
Logger: c.logger,
|
||||
DisableCache: config.DisableCache,
|
||||
DisableMlock: config.DisableMlock,
|
||||
MaxLeaseTTL: config.MaxLeaseTTL,
|
||||
|
@ -190,7 +192,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
var ok bool
|
||||
if config.HABackend != nil {
|
||||
habackend, err := physical.NewBackend(
|
||||
config.HABackend.Type, logger, config.HABackend.Config)
|
||||
config.HABackend.Type, c.logger, config.HABackend.Config)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error initializing backend of type %s: %s",
|
||||
|
@ -299,14 +301,14 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery)
|
||||
if ok {
|
||||
activeFunc := func() bool {
|
||||
if isLeader, _, err := core.Leader(); err != nil {
|
||||
if isLeader, _, err := core.Leader(); err == nil {
|
||||
return isLeader
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
sealedFunc := func() bool {
|
||||
if sealed, err := core.Sealed(); err != nil {
|
||||
if sealed, err := core.Sealed(); err == nil {
|
||||
return sealed
|
||||
}
|
||||
return true
|
||||
|
@ -322,7 +324,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
// Initialize the listeners
|
||||
lns := make([]net.Listener, 0, len(config.Listeners))
|
||||
for i, lnConfig := range config.Listeners {
|
||||
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config)
|
||||
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, logGate)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error initializing listener of type %s: %s",
|
||||
|
@ -351,6 +353,13 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
}
|
||||
}
|
||||
|
||||
// Make sure we close all listeners from this point on
|
||||
defer func() {
|
||||
for _, ln := range lns {
|
||||
ln.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
infoKeys = append(infoKeys, "version")
|
||||
info["version"] = version.GetVersion().String()
|
||||
|
||||
|
@ -368,9 +377,6 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
c.Ui.Output("")
|
||||
|
||||
if verifyOnly {
|
||||
for _, listener := range lns {
|
||||
listener.Close()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
|
@ -410,10 +416,6 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
}
|
||||
}
|
||||
|
||||
for _, listener := range lns {
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
|
|
@ -200,6 +200,7 @@ func ParseConfig(d string) (*Config, error) {
|
|||
}
|
||||
|
||||
valid := []string{
|
||||
"atlas",
|
||||
"backend",
|
||||
"ha_backend",
|
||||
"listener",
|
||||
|
@ -414,6 +415,8 @@ func parseHABackends(result *Config, list *ast.ObjectList) error {
|
|||
}
|
||||
|
||||
func parseListeners(result *Config, list *ast.ObjectList) error {
|
||||
var foundAtlas bool
|
||||
|
||||
listeners := make([]*Listener, 0, len(list.Items))
|
||||
for _, item := range list.Items {
|
||||
key := "listener"
|
||||
|
@ -423,10 +426,14 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
|
|||
|
||||
valid := []string{
|
||||
"address",
|
||||
"endpoint",
|
||||
"infrastructure",
|
||||
"node_id",
|
||||
"tls_disable",
|
||||
"tls_cert_file",
|
||||
"tls_key_file",
|
||||
"tls_min_version",
|
||||
"token",
|
||||
}
|
||||
if err := checkHCLKeys(item.Val, valid); err != nil {
|
||||
return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key))
|
||||
|
@ -437,8 +444,27 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
|
|||
return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key))
|
||||
}
|
||||
|
||||
lnType := strings.ToLower(key)
|
||||
|
||||
if lnType == "atlas" {
|
||||
if foundAtlas {
|
||||
return multierror.Prefix(fmt.Errorf("only one listener of type 'atlas' is permitted"), fmt.Sprintf("listeners.%s", key))
|
||||
} else {
|
||||
foundAtlas = true
|
||||
if m["token"] == "" {
|
||||
return multierror.Prefix(fmt.Errorf("'token' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key))
|
||||
}
|
||||
if m["infrastructure"] == "" {
|
||||
return multierror.Prefix(fmt.Errorf("'infrastructure' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key))
|
||||
}
|
||||
if m["node_id"] == "" {
|
||||
return multierror.Prefix(fmt.Errorf("'node_id' must be specified for an Atlas listener"), fmt.Sprintf("listeners.%s", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
listeners = append(listeners, &Listener{
|
||||
Type: strings.ToLower(key),
|
||||
Type: lnType,
|
||||
Config: m,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -15,6 +15,15 @@ func TestLoadConfigFile(t *testing.T) {
|
|||
|
||||
expected := &Config{
|
||||
Listeners: []*Listener{
|
||||
&Listener{
|
||||
Type: "atlas",
|
||||
Config: map[string]string{
|
||||
"token": "foobar",
|
||||
"infrastructure": "foo/bar",
|
||||
"endpoint": "https://foo.bar:1111",
|
||||
"node_id": "foo_node",
|
||||
},
|
||||
},
|
||||
&Listener{
|
||||
Type: "tcp",
|
||||
Config: map[string]string{
|
||||
|
@ -72,6 +81,15 @@ func TestLoadConfigFile_json(t *testing.T) {
|
|||
"address": "127.0.0.1:443",
|
||||
},
|
||||
},
|
||||
&Listener{
|
||||
Type: "atlas",
|
||||
Config: map[string]string{
|
||||
"token": "foobar",
|
||||
"infrastructure": "foo/bar",
|
||||
"endpoint": "https://foo.bar:1111",
|
||||
"node_id": "foo_node",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
Backend: &Backend{
|
||||
|
|
|
@ -6,17 +6,19 @@ import (
|
|||
_ "crypto/sha512"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ListenerFactory is the factory function to create a listener.
|
||||
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFunc, error)
|
||||
type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, ReloadFunc, error)
|
||||
|
||||
// BuiltinListeners is the list of built-in listener types.
|
||||
var BuiltinListeners = map[string]ListenerFactory{
|
||||
"tcp": tcpListenerFactory,
|
||||
"tcp": tcpListenerFactory,
|
||||
"atlas": atlasListenerFactory,
|
||||
}
|
||||
|
||||
// tlsLookup maps the tls_min_version configuration to the internal value
|
||||
|
@ -28,13 +30,13 @@ var tlsLookup = map[string]uint16{
|
|||
|
||||
// NewListener creates a new listener of the given type with the given
|
||||
// configuration. The type is looked up in the BuiltinListeners map.
|
||||
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
f, ok := BuiltinListeners[t]
|
||||
if !ok {
|
||||
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
|
||||
}
|
||||
|
||||
return f(config)
|
||||
return f(config, logger)
|
||||
}
|
||||
|
||||
func listenerWrapTLS(
|
||||
|
|
60
command/server/listener_atlas.go
Normal file
60
command/server/listener_atlas.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/scada-client/scada"
|
||||
"github.com/hashicorp/vault/version"
|
||||
)
|
||||
|
||||
type SCADAListener struct {
|
||||
ln net.Listener
|
||||
scadaProvider *scada.Provider
|
||||
}
|
||||
|
||||
func (s *SCADAListener) Accept() (net.Conn, error) {
|
||||
return s.ln.Accept()
|
||||
}
|
||||
|
||||
func (s *SCADAListener) Close() error {
|
||||
s.scadaProvider.Shutdown()
|
||||
return s.ln.Close()
|
||||
}
|
||||
|
||||
func (s *SCADAListener) Addr() net.Addr {
|
||||
return s.ln.Addr()
|
||||
}
|
||||
|
||||
func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
scadaConfig := &scada.Config{
|
||||
Service: "vault",
|
||||
Version: version.GetVersion().String(),
|
||||
ResourceType: "vault-cluster",
|
||||
Meta: map[string]string{
|
||||
"node_id": config["node_id"],
|
||||
},
|
||||
Atlas: scada.AtlasConfig{
|
||||
Endpoint: config["endpoint"],
|
||||
Infrastructure: config["infrastructure"],
|
||||
Token: config["token"],
|
||||
},
|
||||
}
|
||||
|
||||
provider, list, err := scada.NewHTTPProvider(scadaConfig, logger)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
ln := &SCADAListener{
|
||||
ln: list,
|
||||
scadaProvider: provider,
|
||||
}
|
||||
|
||||
props := map[string]string{
|
||||
"addr": "Atlas/SCADA",
|
||||
"infrastructure": scadaConfig.Atlas.Infrastructure,
|
||||
}
|
||||
|
||||
return listenerWrapTLS(ln, props, config)
|
||||
}
|
|
@ -1,11 +1,12 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
addr, ok := config["address"]
|
||||
if !ok {
|
||||
addr = "127.0.0.1:8200"
|
||||
|
|
|
@ -16,7 +16,7 @@ func TestTCPListener(t *testing.T) {
|
|||
ln, _, _, err := tcpListenerFactory(map[string]string{
|
||||
"address": "127.0.0.1:0",
|
||||
"tls_disable": "1",
|
||||
})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func TestTCPListener_tls(t *testing.T) {
|
|||
"address": "127.0.0.1:0",
|
||||
"tls_cert_file": wd + "reload_foo.pem",
|
||||
"tls_key_file": wd + "reload_foo.key",
|
||||
})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,13 @@ disable_mlock = true
|
|||
statsd_addr = "bar"
|
||||
statsite_addr = "foo"
|
||||
|
||||
listener "atlas" {
|
||||
token = "foobar"
|
||||
infrastructure = "foo/bar"
|
||||
endpoint = "https://foo.bar:1111"
|
||||
node_id = "foo_node"
|
||||
}
|
||||
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:443"
|
||||
}
|
||||
|
|
|
@ -1,17 +1,24 @@
|
|||
{
|
||||
"listener":{
|
||||
"tcp":{
|
||||
"address":"127.0.0.1:443"
|
||||
}
|
||||
},
|
||||
"backend":{
|
||||
"consul":{
|
||||
"foo":"bar"
|
||||
}
|
||||
},
|
||||
"telemetry":{
|
||||
"statsite_address":"baz"
|
||||
},
|
||||
"max_lease_ttl":"10h",
|
||||
"default_lease_ttl":"10h"
|
||||
"listener": [{
|
||||
"tcp": {
|
||||
"address": "127.0.0.1:443"
|
||||
}
|
||||
}, {
|
||||
"atlas": {
|
||||
"token": "foobar",
|
||||
"infrastructure": "foo/bar",
|
||||
"endpoint": "https://foo.bar:1111",
|
||||
"node_id": "foo_node"
|
||||
}
|
||||
}],
|
||||
"backend": {
|
||||
"consul": {
|
||||
"foo": "bar"
|
||||
}
|
||||
},
|
||||
"telemetry": {
|
||||
"statsite_address": "baz"
|
||||
},
|
||||
"max_lease_ttl": "10h",
|
||||
"default_lease_ttl": "10h"
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/ssh"
|
||||
|
@ -27,15 +27,17 @@ type SSHCredentialResp struct {
|
|||
Key string `mapstructure:"key"`
|
||||
Username string `mapstructure:"username"`
|
||||
IP string `mapstructure:"ip"`
|
||||
Port int `mapstructure:"port"`
|
||||
Port string `mapstructure:"port"`
|
||||
}
|
||||
|
||||
func (c *SSHCommand) Run(args []string) int {
|
||||
var role, mountPoint, format string
|
||||
var role, mountPoint, format, userKnownHostsFile, strictHostKeyChecking string
|
||||
var noExec bool
|
||||
var sshCmdArgs []string
|
||||
var sshDynamicKeyFileName string
|
||||
flags := c.Meta.FlagSet("ssh", meta.FlagSetDefault)
|
||||
flags.StringVar(&strictHostKeyChecking, "strict-host-key-checking", "", "")
|
||||
flags.StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "")
|
||||
flags.StringVar(&format, "format", "table", "")
|
||||
flags.StringVar(&role, "role", "", "")
|
||||
flags.StringVar(&mountPoint, "mount-point", "ssh", "")
|
||||
|
@ -45,6 +47,27 @@ func (c *SSHCommand) Run(args []string) int {
|
|||
if err := flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
// If the flag is already set then it takes the precedence. If the flag is not
|
||||
// set, try setting it from env var.
|
||||
if os.Getenv("VAULT_SSH_STRICT_HOST_KEY_CHECKING") != "" && strictHostKeyChecking == "" {
|
||||
strictHostKeyChecking = os.Getenv("VAULT_SSH_STRICT_HOST_KEY_CHECKING")
|
||||
}
|
||||
// Assign default value if both flag and env var are not set
|
||||
if strictHostKeyChecking == "" {
|
||||
strictHostKeyChecking = "ask"
|
||||
}
|
||||
|
||||
// If the flag is already set then it takes the precedence. If the flag is not
|
||||
// set, try setting it from env var.
|
||||
if os.Getenv("VAULT_SSH_USER_KNOWN_HOSTS_FILE") != "" && userKnownHostsFile == "" {
|
||||
userKnownHostsFile = os.Getenv("VAULT_SSH_USER_KNOWN_HOSTS_FILE")
|
||||
}
|
||||
// Assign default value if both flag and env var are not set
|
||||
if userKnownHostsFile == "" {
|
||||
userKnownHostsFile = "~/.ssh/known_hosts"
|
||||
}
|
||||
|
||||
args = flags.Args()
|
||||
if len(args) < 1 {
|
||||
c.Ui.Error("ssh expects at least one argument")
|
||||
|
@ -123,14 +146,16 @@ func (c *SSHCommand) Run(args []string) int {
|
|||
return OutputSecret(c.Ui, format, keySecret)
|
||||
}
|
||||
|
||||
// Port comes back as a json.Number which mapstructure doesn't like, so convert it
|
||||
if keySecret.Data["port"] != nil {
|
||||
keySecret.Data["port"] = keySecret.Data["port"].(json.Number).String()
|
||||
}
|
||||
var resp SSHCredentialResp
|
||||
if err := mapstructure.Decode(keySecret.Data, &resp); err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error parsing the credential response:%s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
port := strconv.Itoa(resp.Port)
|
||||
|
||||
if resp.KeyType == ssh.KeyTypeDynamic {
|
||||
if len(resp.Key) == 0 {
|
||||
c.Ui.Error(fmt.Sprintf("Invalid key"))
|
||||
|
@ -148,7 +173,7 @@ func (c *SSHCommand) Run(args []string) int {
|
|||
// Feel free to try and remove this dependency.
|
||||
sshpassPath, err := exec.LookPath("sshpass")
|
||||
if err == nil {
|
||||
sshCmdArgs = append(sshCmdArgs, []string{"-p", string(resp.Key), "ssh", "-p", port, username + "@" + ip.String()}...)
|
||||
sshCmdArgs = append(sshCmdArgs, []string{"-p", string(resp.Key), "ssh", "-o UserKnownHostsFile=" + userKnownHostsFile, "-o StrictHostKeyChecking=" + strictHostKeyChecking, "-p", resp.Port, username + "@" + ip.String()}...)
|
||||
sshCmd := exec.Command(sshpassPath, sshCmdArgs...)
|
||||
sshCmd.Stdin = os.Stdin
|
||||
sshCmd.Stdout = os.Stdout
|
||||
|
@ -161,7 +186,7 @@ func (c *SSHCommand) Run(args []string) int {
|
|||
c.Ui.Output("OTP for the session is " + resp.Key)
|
||||
c.Ui.Output("[Note: Install 'sshpass' to automate typing in OTP]")
|
||||
}
|
||||
sshCmdArgs = append(sshCmdArgs, []string{"-p", port, username + "@" + ip.String()}...)
|
||||
sshCmdArgs = append(sshCmdArgs, []string{"-o UserKnownHostsFile=" + userKnownHostsFile, "-o StrictHostKeyChecking=" + strictHostKeyChecking, "-p", resp.Port, username + "@" + ip.String()}...)
|
||||
|
||||
sshCmd := exec.Command("ssh", sshCmdArgs...)
|
||||
sshCmd.Stdin = os.Stdin
|
||||
|
@ -259,24 +284,38 @@ General Options:
|
|||
` + meta.GeneralOptionsUsage() + `
|
||||
SSH Options:
|
||||
|
||||
-role Role to be used to create the key.
|
||||
Each IP is associated with a role. To see the associated
|
||||
roles with IP, use "lookup" endpoint. If you are certain
|
||||
that there is only one role associated with the IP, you can
|
||||
skip mentioning the role. It will be chosen by default. If
|
||||
there are no roles associated with the IP, register the
|
||||
CIDR block of that IP using the "roles/" endpoint.
|
||||
-role Role to be used to create the key.
|
||||
Each IP is associated with a role. To see the associated
|
||||
roles with IP, use "lookup" endpoint. If you are certain
|
||||
that there is only one role associated with the IP, you can
|
||||
skip mentioning the role. It will be chosen by default. If
|
||||
there are no roles associated with the IP, register the
|
||||
CIDR block of that IP using the "roles/" endpoint.
|
||||
|
||||
-no-exec Shows the credentials but does not establish connection.
|
||||
-no-exec Shows the credentials but does not establish connection.
|
||||
|
||||
-mount-point Mount point of SSH backend. If the backend is mounted at
|
||||
'ssh', which is the default as well, this parameter can be
|
||||
skipped.
|
||||
-mount-point Mount point of SSH backend. If the backend is mounted at
|
||||
'ssh', which is the default as well, this parameter can be
|
||||
skipped.
|
||||
|
||||
-format If no-exec option is enabled, then the credentials will be
|
||||
printed out and SSH connection will not be established. The
|
||||
format of the output can be 'json' or 'table'. JSON output
|
||||
is useful when writing scripts. Default is 'table'.
|
||||
-format If no-exec option is enabled, then the credentials will be
|
||||
printed out and SSH connection will not be established. The
|
||||
format of the output can be 'json' or 'table'. JSON output
|
||||
is useful when writing scripts. Default is 'table'.
|
||||
|
||||
-strict-host-key-checking This option corresponds to StrictHostKeyChecking of SSH configuration.
|
||||
If 'sshpass' is employed to enable automated login, then if host key
|
||||
is not "known" to the client, 'vault ssh' command will fail. Set this
|
||||
option to "no" to bypass the host key checking. Defaults to "ask".
|
||||
Can also be specified with VAULT_SSH_STRICT_HOST_KEY_CHECKING environment
|
||||
variable.
|
||||
|
||||
-user-known-hosts-file This option corresponds to UserKnownHostsFile of SSH configuration.
|
||||
Assigns the file to use for storing the host keys. If this option is
|
||||
set to "/dev/null" along with "-strict-host-key-checking=no", both
|
||||
warnings and host key checking can be avoided while establishing the
|
||||
connection. Defaults to "~/.ssh/known_hosts". Can also be specified
|
||||
with VAULT_SSH_USER_KNOWN_HOSTS_FILE environment variable.
|
||||
`
|
||||
return strings.TrimSpace(helpText)
|
||||
}
|
||||
|
|
|
@ -123,8 +123,8 @@ Token Options:
|
|||
|
||||
-lease="1h" Deprecated; use "-ttl" instead.
|
||||
|
||||
-ttl="1h" TTL to associate with the token. This option enables
|
||||
the tokens to be renewable.
|
||||
-ttl="1h" Initial TTL to associate with the token; renewals can
|
||||
extend this value.
|
||||
|
||||
-metadata="key=value" Metadata to associate with the token. This shows
|
||||
up in the audit log. This can be specified multiple
|
||||
|
|
98
command/unwrap.go
Normal file
98
command/unwrap.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
)
|
||||
|
||||
// UnwrapCommand is a Command that behaves like ReadCommand but specifically
|
||||
// for unwrapping cubbyhole-wrapped secrets
|
||||
type UnwrapCommand struct {
|
||||
meta.Meta
|
||||
}
|
||||
|
||||
func (c *UnwrapCommand) Run(args []string) int {
|
||||
var format string
|
||||
var field string
|
||||
var err error
|
||||
var secret *api.Secret
|
||||
var flags *flag.FlagSet
|
||||
flags = c.Meta.FlagSet("unwrap", meta.FlagSetDefault)
|
||||
flags.StringVar(&format, "format", "table", "")
|
||||
flags.StringVar(&field, "field", "", "")
|
||||
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
||||
if err := flags.Parse(args); err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
args = flags.Args()
|
||||
if len(args) != 1 || len(args[0]) == 0 {
|
||||
c.Ui.Error("Unwrap expects one argument: the ID of the wrapping token")
|
||||
flags.Usage()
|
||||
return 1
|
||||
}
|
||||
|
||||
tokenID := args[0]
|
||||
_, err = uuid.ParseUUID(tokenID)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Given token could not be parsed as a UUID: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error initializing client: %s", err))
|
||||
return 2
|
||||
}
|
||||
|
||||
secret, err = client.Logical().Unwrap(tokenID)
|
||||
if err != nil {
|
||||
c.Ui.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
if secret == nil {
|
||||
c.Ui.Error("Secret returned was nil")
|
||||
return 1
|
||||
}
|
||||
|
||||
// Handle single field output
|
||||
if field != "" {
|
||||
return PrintRawField(c.Ui, secret, field)
|
||||
}
|
||||
|
||||
return OutputSecret(c.Ui, format, secret)
|
||||
}
|
||||
|
||||
func (c *UnwrapCommand) Synopsis() string {
|
||||
return "Unwrap a wrapped secret"
|
||||
}
|
||||
|
||||
func (c *UnwrapCommand) Help() string {
|
||||
helpText := `
|
||||
Usage: vault unwrap [options] <wrapping token ID>
|
||||
|
||||
Unwrap a wrapped secret.
|
||||
|
||||
Unwraps the data wrapped by the given token ID. The returned result is the
|
||||
same as a 'read' operation on a non-wrapped secret.
|
||||
|
||||
General Options:
|
||||
` + meta.GeneralOptionsUsage() + `
|
||||
Read Options:
|
||||
|
||||
-format=table The format for output. By default it is a whitespace-
|
||||
delimited table. This can also be json or yaml.
|
||||
|
||||
-field=field If included, the raw value of the specified field
|
||||
will be output raw to stdout.
|
||||
|
||||
`
|
||||
return strings.TrimSpace(helpText)
|
||||
}
|
74
command/unwrap_test.go
Normal file
74
command/unwrap_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func TestUnwrap(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := http.TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
c := &UnwrapCommand{
|
||||
Meta: meta.Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-address", addr,
|
||||
"-field", "zip",
|
||||
}
|
||||
|
||||
// Run once so the client is setup, ignore errors
|
||||
c.Run(args)
|
||||
|
||||
// Get the client so we can write data
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
wrapLookupFunc := func(method, path string) string {
|
||||
if method == "GET" && path == "secret/foo" {
|
||||
return "60s"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
client.SetWrappingLookupFunc(wrapLookupFunc)
|
||||
|
||||
data := map[string]interface{}{"zip": "zap"}
|
||||
if _, err := client.Logical().Write("secret/foo", data); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
outer, err := client.Logical().Read("secret/foo")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if outer == nil {
|
||||
t.Fatal("outer response was nil")
|
||||
}
|
||||
if outer.WrapInfo == nil {
|
||||
t.Fatal("outer wrapinfo was nil, response was %#v", *outer)
|
||||
}
|
||||
|
||||
args = append(args, outer.WrapInfo.Token)
|
||||
|
||||
// Run the read
|
||||
if code := c.Run(args); code != 0 {
|
||||
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
|
||||
}
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
if output != "zap\n" {
|
||||
t.Fatalf("unexpectd output:\n%s", output)
|
||||
}
|
||||
}
|
|
@ -1,6 +1,14 @@
|
|||
package command
|
||||
|
||||
import "github.com/hashicorp/vault/command/token"
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/command/token"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
// DefaultTokenHelper returns the token helper that is configured for Vault.
|
||||
func DefaultTokenHelper() (token.TokenHelper, error) {
|
||||
|
@ -20,3 +28,43 @@ func DefaultTokenHelper() (token.TokenHelper, error) {
|
|||
}
|
||||
return &token.ExternalTokenHelper{BinaryPath: path}, nil
|
||||
}
|
||||
|
||||
func PrintRawField(ui cli.Ui, secret *api.Secret, field string) int {
|
||||
var val interface{}
|
||||
switch field {
|
||||
case "wrapping_token":
|
||||
if secret.WrapInfo != nil {
|
||||
val = secret.WrapInfo.Token
|
||||
}
|
||||
case "wrapping_token_ttl":
|
||||
if secret.WrapInfo != nil {
|
||||
val = secret.WrapInfo.TTL
|
||||
}
|
||||
case "wrapping_token_creation_time":
|
||||
if secret.WrapInfo != nil {
|
||||
val = secret.WrapInfo.CreationTime.String()
|
||||
}
|
||||
case "refresh_interval":
|
||||
val = secret.LeaseDuration
|
||||
default:
|
||||
val = secret.Data[field]
|
||||
}
|
||||
|
||||
if val != nil {
|
||||
// c.Ui.Output() prints a CR character which in this case is
|
||||
// not desired. Since Vault CLI currently only uses BasicUi,
|
||||
// which writes to standard output, os.Stdout is used here to
|
||||
// directly print the message. If mitchellh/cli exposes method
|
||||
// to print without CR, this check needs to be removed.
|
||||
if reflect.TypeOf(ui).String() == "*cli.BasicUi" {
|
||||
fmt.Fprintf(os.Stdout, fmt.Sprintf("%v", val))
|
||||
} else {
|
||||
ui.Output(fmt.Sprintf("%v", val))
|
||||
}
|
||||
return 0
|
||||
} else {
|
||||
ui.Error(fmt.Sprintf(
|
||||
"Field %s not present in secret", field))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
|
109
command/wrapping_test.go
Normal file
109
command/wrapping_test.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func TestWrapping_Env(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := http.TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
c := &TokenLookupCommand{
|
||||
Meta: meta.Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-address", addr,
|
||||
}
|
||||
// Run it once for client
|
||||
c.Run(args)
|
||||
|
||||
// Create a new token for us to use
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{
|
||||
Lease: "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
prevWrapTTLEnv := os.Getenv(api.EnvVaultWrapTTL)
|
||||
os.Setenv(api.EnvVaultWrapTTL, "5s")
|
||||
defer func() {
|
||||
os.Setenv(api.EnvVaultWrapTTL, prevWrapTTLEnv)
|
||||
}()
|
||||
|
||||
// Now when we do a lookup-self the response should be wrapped
|
||||
args = append(args, resp.Auth.ClientToken)
|
||||
|
||||
resp, err = client.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if resp.WrapInfo == nil {
|
||||
t.Fatal("nil wrap info")
|
||||
}
|
||||
if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 {
|
||||
t.Fatal("did not get token or ttl wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapping_Flag(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := http.TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
c := &TokenLookupCommand{
|
||||
Meta: meta.Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-address", addr,
|
||||
"-wrap-ttl", "5s",
|
||||
}
|
||||
// Run it once for client
|
||||
c.Run(args)
|
||||
|
||||
// Create a new token for us to use
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{
|
||||
Lease: "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if resp.WrapInfo == nil {
|
||||
t.Fatal("nil wrap info")
|
||||
}
|
||||
if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 {
|
||||
t.Fatal("did not get token or ttl wrong")
|
||||
}
|
||||
}
|
|
@ -7,6 +7,12 @@ import (
|
|||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
// ParsePolicies parses a comma-delimited list of policies.
|
||||
// The resulting collection will have no duplicate elements.
|
||||
// If 'root' policy was present in the list of policies, then
|
||||
// all other policies will be ignored, the result will contain
|
||||
// just the 'root'. In cases where 'root' is not present, if
|
||||
// 'default' policy is not already present, it will be added.
|
||||
func ParsePolicies(policiesRaw string) []string {
|
||||
if policiesRaw == "" {
|
||||
return []string{"default"}
|
||||
|
@ -14,10 +20,18 @@ func ParsePolicies(policiesRaw string) []string {
|
|||
|
||||
policies := strings.Split(policiesRaw, ",")
|
||||
|
||||
return SanitizePolicies(policies)
|
||||
return SanitizePolicies(policies, true)
|
||||
}
|
||||
|
||||
func SanitizePolicies(policies []string) []string {
|
||||
// SanitizePolicies performs the common input validation tasks
|
||||
// which are performed on the list of policies across Vault.
|
||||
// The resulting collection will have no duplicate elements.
|
||||
// If 'root' policy was present in the list of policies, then
|
||||
// all other policies will be ignored, the result will contain
|
||||
// just the 'root'. In cases where 'root' is not present, if
|
||||
// 'default' policy is not already present, it will be added
|
||||
// if addDefault is set to true.
|
||||
func SanitizePolicies(policies []string, addDefault bool) []string {
|
||||
defaultFound := false
|
||||
for i, p := range policies {
|
||||
policies[i] = strings.ToLower(strings.TrimSpace(p))
|
||||
|
@ -38,7 +52,7 @@ func SanitizePolicies(policies []string) []string {
|
|||
}
|
||||
|
||||
// Always add 'default' except only if the policies contain 'root'.
|
||||
if len(policies) == 0 || !defaultFound {
|
||||
if addDefault && (len(policies) == 0 || !defaultFound) {
|
||||
policies = append(policies, "default")
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,21 @@ package policyutil
|
|||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizePolicies(t *testing.T) {
|
||||
expected := []string{"foo", "bar"}
|
||||
actual := SanitizePolicies([]string{"foo", "bar"}, false)
|
||||
if !EquivalentPolicies(expected, actual) {
|
||||
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
|
||||
}
|
||||
|
||||
// If 'default' is already added, do not remove it.
|
||||
expected = []string{"foo", "bar", "default"}
|
||||
actual = SanitizePolicies([]string{"foo", "bar", "default"}, false)
|
||||
if !EquivalentPolicies(expected, actual) {
|
||||
t.Fatal("bad: expected:%s\ngot:%s\n", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePolicies(t *testing.T) {
|
||||
expected := []string{"foo", "bar", "default"}
|
||||
actual := ParsePolicies("foo,bar")
|
||||
|
|
|
@ -6,15 +6,23 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
// AuthHeaderName is the name of the header containing the token.
|
||||
const AuthHeaderName = "X-Vault-Token"
|
||||
const (
|
||||
// AuthHeaderName is the name of the header containing the token.
|
||||
AuthHeaderName = "X-Vault-Token"
|
||||
|
||||
// WrapHeaderName is the name of the header containing a directive to wrap the
|
||||
// response.
|
||||
WrapTTLHeaderName = "X-Vault-Wrap-TTL"
|
||||
)
|
||||
|
||||
// Handler returns an http.Handler for the API. This can be used on
|
||||
// its own to mount the Vault API within another web server.
|
||||
|
@ -85,7 +93,7 @@ func parseRequest(r *http.Request, out interface{}) error {
|
|||
// case of an error.
|
||||
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
|
||||
resp, err := core.HandleRequest(r)
|
||||
if err == vault.ErrStandby {
|
||||
if errwrap.Contains(err, vault.ErrStandby.Error()) {
|
||||
respondStandby(core, w, rawReq.URL)
|
||||
return resp, false
|
||||
}
|
||||
|
@ -153,13 +161,44 @@ func requestAuth(r *http.Request, req *logical.Request) *logical.Request {
|
|||
return req
|
||||
}
|
||||
|
||||
// requestWrapTTL adds the WrapTTL value to the logical.Request if it
|
||||
// exists.
|
||||
func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, error) {
|
||||
// First try for the header value
|
||||
wrapTTL := r.Header.Get(WrapTTLHeaderName)
|
||||
if wrapTTL == "" {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// If it has an allowed suffix parse as a duration string
|
||||
if strings.HasSuffix(wrapTTL, "s") || strings.HasSuffix(wrapTTL, "m") || strings.HasSuffix(wrapTTL, "h") {
|
||||
dur, err := time.ParseDuration(wrapTTL)
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
req.WrapTTL = dur
|
||||
} else {
|
||||
// Parse as a straight number of seconds
|
||||
seconds, err := strconv.ParseInt(wrapTTL, 10, 64)
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
req.WrapTTL = time.Duration(seconds) * time.Second
|
||||
}
|
||||
if int64(req.WrapTTL) < 0 {
|
||||
return req, fmt.Errorf("requested wrap ttl cannot be negative")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Determines the type of the error being returned and sets the HTTP
|
||||
// status code appropriately
|
||||
func respondErrorStatus(w http.ResponseWriter, err error) {
|
||||
status := http.StatusInternalServerError
|
||||
switch {
|
||||
// Keep adding more error types here to appropriate the status codes
|
||||
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
|
||||
case err != nil && errwrap.ContainsType(err, new(vault.StatusBadRequest)):
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
respondError(w, status, err)
|
||||
|
@ -167,7 +206,7 @@ func respondErrorStatus(w http.ResponseWriter, err error) {
|
|||
|
||||
func respondError(w http.ResponseWriter, status int, err error) {
|
||||
// Adjust status code when sealed
|
||||
if err == vault.ErrSealed {
|
||||
if errwrap.Contains(err, vault.ErrSealed.Error()) {
|
||||
status = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
|
@ -194,19 +233,19 @@ func respondCommon(w http.ResponseWriter, resp *logical.Response, err error) boo
|
|||
}
|
||||
|
||||
if resp.IsError() {
|
||||
var statusCode int
|
||||
statusCode := http.StatusBadRequest
|
||||
|
||||
switch err {
|
||||
case logical.ErrPermissionDenied:
|
||||
statusCode = http.StatusForbidden
|
||||
case logical.ErrUnsupportedOperation:
|
||||
statusCode = http.StatusMethodNotAllowed
|
||||
case logical.ErrUnsupportedPath:
|
||||
statusCode = http.StatusNotFound
|
||||
case logical.ErrInvalidRequest:
|
||||
statusCode = http.StatusBadRequest
|
||||
default:
|
||||
statusCode = http.StatusBadRequest
|
||||
if err != nil {
|
||||
switch err {
|
||||
case logical.ErrPermissionDenied:
|
||||
statusCode = http.StatusForbidden
|
||||
case logical.ErrUnsupportedOperation:
|
||||
statusCode = http.StatusMethodNotAllowed
|
||||
case logical.ErrUnsupportedPath:
|
||||
statusCode = http.StatusNotFound
|
||||
case logical.ErrInvalidRequest:
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
err := fmt.Errorf("%s", resp.Data["error"].(string))
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
|
@ -64,6 +66,33 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// We use this test to verify header auth wrapping
|
||||
func TestSysMounts_headerAuth_Wrapped(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", addr+"/v1/sys/mounts", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, token)
|
||||
req.Header.Set(WrapTTLHeaderName, "60s")
|
||||
|
||||
client := cleanhttp.DefaultClient()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
testResponseStatus(t, resp, 200)
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.ReadFrom(resp.Body)
|
||||
if strings.TrimSpace(buf.String()) != "null" {
|
||||
t.Fatalf("bad: %v", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_sealed(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := TestServer(t, core)
|
||||
|
|
188
http/logical.go
188
http/logical.go
|
@ -7,77 +7,89 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
type PrepareRequestFunc func(req *logical.Request) error
|
||||
|
||||
func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) {
|
||||
// Determine the path...
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/") {
|
||||
return nil, http.StatusNotFound, nil
|
||||
}
|
||||
path := r.URL.Path[len("/v1/"):]
|
||||
if path == "" {
|
||||
return nil, http.StatusNotFound, nil
|
||||
}
|
||||
|
||||
// Determine the operation
|
||||
var op logical.Operation
|
||||
switch r.Method {
|
||||
case "DELETE":
|
||||
op = logical.DeleteOperation
|
||||
case "GET":
|
||||
op = logical.ReadOperation
|
||||
// Need to call ParseForm to get query params loaded
|
||||
queryVals := r.URL.Query()
|
||||
listStr := queryVals.Get("list")
|
||||
if listStr != "" {
|
||||
list, err := strconv.ParseBool(listStr)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, nil
|
||||
}
|
||||
if list {
|
||||
op = logical.ListOperation
|
||||
}
|
||||
}
|
||||
case "POST", "PUT":
|
||||
op = logical.UpdateOperation
|
||||
case "LIST":
|
||||
op = logical.ListOperation
|
||||
default:
|
||||
return nil, http.StatusMethodNotAllowed, nil
|
||||
}
|
||||
|
||||
// Parse the request if we can
|
||||
var data map[string]interface{}
|
||||
if op == logical.UpdateOperation {
|
||||
err := parseRequest(r, &data)
|
||||
if err == io.EOF {
|
||||
data = nil
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
req := requestAuth(r, &logical.Request{
|
||||
Operation: op,
|
||||
Path: path,
|
||||
Data: data,
|
||||
Connection: getConnection(r),
|
||||
})
|
||||
req, err = requestWrapTTL(r, req)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
|
||||
}
|
||||
|
||||
return req, 0, nil
|
||||
}
|
||||
|
||||
func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback PrepareRequestFunc) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Determine the path...
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/") {
|
||||
respondError(w, http.StatusNotFound, nil)
|
||||
return
|
||||
}
|
||||
path := r.URL.Path[len("/v1/"):]
|
||||
if path == "" {
|
||||
respondError(w, http.StatusNotFound, nil)
|
||||
req, statusCode, err := buildLogicalRequest(w, r)
|
||||
if err != nil || statusCode != 0 {
|
||||
respondError(w, statusCode, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Determine the operation
|
||||
var op logical.Operation
|
||||
switch r.Method {
|
||||
case "DELETE":
|
||||
op = logical.DeleteOperation
|
||||
case "GET":
|
||||
op = logical.ReadOperation
|
||||
// Need to call ParseForm to get query params loaded
|
||||
queryVals := r.URL.Query()
|
||||
listStr := queryVals.Get("list")
|
||||
if listStr != "" {
|
||||
list, err := strconv.ParseBool(listStr)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusBadRequest, nil)
|
||||
}
|
||||
if list {
|
||||
op = logical.ListOperation
|
||||
}
|
||||
}
|
||||
case "POST", "PUT":
|
||||
op = logical.UpdateOperation
|
||||
case "LIST":
|
||||
op = logical.ListOperation
|
||||
default:
|
||||
respondError(w, http.StatusMethodNotAllowed, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the request if we can
|
||||
var data map[string]interface{}
|
||||
if op == logical.UpdateOperation {
|
||||
err := parseRequest(r, &data)
|
||||
if err == io.EOF {
|
||||
data = nil
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req := requestAuth(r, &logical.Request{
|
||||
Operation: op,
|
||||
Path: path,
|
||||
Data: data,
|
||||
Connection: getConnection(r),
|
||||
})
|
||||
|
||||
// Certain endpoints may require changes to the request object.
|
||||
// They will have a callback registered to do the needful.
|
||||
// Invoking it before proceeding.
|
||||
// Certain endpoints may require changes to the request object. They
|
||||
// will have a callback registered to do the needed operations, so
|
||||
// invoke it before proceeding.
|
||||
if prepareRequestCallback != nil {
|
||||
if err := prepareRequestCallback(req); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
|
@ -93,7 +105,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
|
|||
return
|
||||
}
|
||||
switch {
|
||||
case op == logical.ReadOperation:
|
||||
case req.Operation == logical.ReadOperation:
|
||||
if resp == nil {
|
||||
respondError(w, http.StatusNotFound, nil)
|
||||
return
|
||||
|
@ -101,7 +113,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
|
|||
|
||||
// Basically: if we have empty "keys" or no keys at all, 404. This
|
||||
// provides consistency with GET.
|
||||
case op == logical.ListOperation:
|
||||
case req.Operation == logical.ListOperation:
|
||||
if resp == nil || len(resp.Data) == 0 {
|
||||
respondError(w, http.StatusNotFound, nil)
|
||||
return
|
||||
|
@ -123,7 +135,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
|
|||
}
|
||||
|
||||
// Build the proper response
|
||||
respondLogical(w, r, path, dataOnly, resp)
|
||||
respondLogical(w, r, req.Path, dataOnly, resp)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -148,34 +160,22 @@ func respondLogical(w http.ResponseWriter, r *http.Request, path string, dataOnl
|
|||
return
|
||||
}
|
||||
|
||||
logicalResp := &LogicalResponse{
|
||||
Data: resp.Data,
|
||||
Warnings: resp.Warnings(),
|
||||
}
|
||||
if resp.Secret != nil {
|
||||
logicalResp.LeaseID = resp.Secret.LeaseID
|
||||
logicalResp.Renewable = resp.Secret.Renewable
|
||||
logicalResp.LeaseDuration = int(resp.Secret.TTL.Seconds())
|
||||
}
|
||||
|
||||
// If we have authentication information, then
|
||||
// set up the result structure.
|
||||
if resp.Auth != nil {
|
||||
logicalResp.Auth = &Auth{
|
||||
ClientToken: resp.Auth.ClientToken,
|
||||
Accessor: resp.Auth.Accessor,
|
||||
Policies: resp.Auth.Policies,
|
||||
Metadata: resp.Auth.Metadata,
|
||||
LeaseDuration: int(resp.Auth.TTL.Seconds()),
|
||||
Renewable: resp.Auth.Renewable,
|
||||
if resp.WrapInfo != nil && resp.WrapInfo.Token != "" {
|
||||
httpResp = logical.HTTPResponse{
|
||||
WrapInfo: &logical.HTTPWrapInfo{
|
||||
Token: resp.WrapInfo.Token,
|
||||
TTL: int(resp.WrapInfo.TTL.Seconds()),
|
||||
CreationTime: resp.WrapInfo.CreationTime,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
httpResp = logical.SanitizeResponse(resp)
|
||||
}
|
||||
|
||||
httpResp = logicalResp
|
||||
}
|
||||
|
||||
// Respond
|
||||
respondOk(w, httpResp)
|
||||
return
|
||||
}
|
||||
|
||||
// respondRaw is used when the response is using HTTPContentType and HTTPRawBody
|
||||
|
@ -246,21 +246,3 @@ func getConnection(r *http.Request) (connection *logical.Connection) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
type LogicalResponse struct {
|
||||
LeaseID string `json:"lease_id"`
|
||||
Renewable bool `json:"renewable"`
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Warnings []string `json:"warnings"`
|
||||
Auth *Auth `json:"auth"`
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
ClientToken string `json:"client_token"`
|
||||
Accessor string `json:"accessor"`
|
||||
Policies []string `json:"policies"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Renewable bool `json:"renewable"`
|
||||
}
|
||||
|
|
|
@ -40,8 +40,9 @@ func TestLogical(t *testing.T) {
|
|||
"data": map[string]interface{}{
|
||||
"data": "bar",
|
||||
},
|
||||
"auth": nil,
|
||||
"warnings": nilWarnings,
|
||||
"auth": nil,
|
||||
"wrap_info": nil,
|
||||
"warnings": nilWarnings,
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseBody(t, resp, &actual)
|
||||
|
@ -140,8 +141,9 @@ func TestLogical_StandbyRedirect(t *testing.T) {
|
|||
"role": "",
|
||||
"explicit_max_ttl": float64(0),
|
||||
},
|
||||
"warnings": nilWarnings,
|
||||
"auth": nil,
|
||||
"warnings": nilWarnings,
|
||||
"wrap_info": nil,
|
||||
"auth": nil,
|
||||
}
|
||||
|
||||
testResponseStatus(t, resp, 200)
|
||||
|
@ -178,6 +180,7 @@ func TestLogical_CreateToken(t *testing.T) {
|
|||
"renewable": false,
|
||||
"lease_duration": float64(0),
|
||||
"data": nil,
|
||||
"wrap_info": nil,
|
||||
"auth": map[string]interface{}{
|
||||
"policies": []interface{}{"root"},
|
||||
"metadata": nil,
|
||||
|
|
|
@ -3,6 +3,7 @@ package http
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
|
@ -20,7 +21,7 @@ func handleSysLeader(core *vault.Core) http.Handler {
|
|||
func handleSysLeaderGet(core *vault.Core, w http.ResponseWriter, r *http.Request) {
|
||||
haEnabled := true
|
||||
isLeader, address, err := core.Leader()
|
||||
if err == vault.ErrHANotEnabled {
|
||||
if errwrap.Contains(err, vault.ErrHANotEnabled.Error()) {
|
||||
haEnabled = false
|
||||
err = nil
|
||||
}
|
||||
|
|
|
@ -17,8 +17,8 @@ func TestSysPolicies(t *testing.T) {
|
|||
|
||||
var actual map[string]interface{}
|
||||
expected := map[string]interface{}{
|
||||
"policies": []interface{}{"default", "root"},
|
||||
"keys": []interface{}{"default", "root"},
|
||||
"policies": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
|
||||
"keys": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseBody(t, resp, &actual)
|
||||
|
@ -62,14 +62,19 @@ func TestSysWritePolicy(t *testing.T) {
|
|||
|
||||
var actual map[string]interface{}
|
||||
expected := map[string]interface{}{
|
||||
"policies": []interface{}{"default", "foo", "root"},
|
||||
"keys": []interface{}{"default", "foo", "root"},
|
||||
"policies": []interface{}{"cubbyhole-response-wrapping", "default", "foo", "root"},
|
||||
"keys": []interface{}{"cubbyhole-response-wrapping", "default", "foo", "root"},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseBody(t, resp, &actual)
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad: got\n%#v\nexpected\n%#v\n", actual, expected)
|
||||
}
|
||||
|
||||
resp = testHttpPost(t, token, addr+"/v1/sys/policy/cubbyhole-response-wrapping", map[string]interface{}{
|
||||
"rules": ``,
|
||||
})
|
||||
testResponseStatus(t, resp, 400)
|
||||
}
|
||||
|
||||
func TestSysDeletePolicy(t *testing.T) {
|
||||
|
@ -86,12 +91,17 @@ func TestSysDeletePolicy(t *testing.T) {
|
|||
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/foo")
|
||||
testResponseStatus(t, resp, 204)
|
||||
|
||||
// Also attempt to delete these since they should not be allowed (ignore
|
||||
// responses, if they exist later that's sufficient)
|
||||
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/default")
|
||||
resp = testHttpDelete(t, token, addr+"/v1/sys/policy/cubbyhole-response-wrapping")
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/sys/policy")
|
||||
|
||||
var actual map[string]interface{}
|
||||
expected := map[string]interface{}{
|
||||
"policies": []interface{}{"default", "root"},
|
||||
"keys": []interface{}{"default", "root"},
|
||||
"policies": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
|
||||
"keys": []interface{}{"cubbyhole-response-wrapping", "default", "root"},
|
||||
}
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseBody(t, resp, &actual)
|
||||
|
|
|
@ -13,19 +13,21 @@ import (
|
|||
|
||||
func handleSysSeal(core *vault.Core) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case "PUT":
|
||||
case "POST":
|
||||
req, statusCode, err := buildLogicalRequest(w, r)
|
||||
if err != nil || statusCode != 0 {
|
||||
respondError(w, statusCode, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Operation {
|
||||
case logical.UpdateOperation:
|
||||
default:
|
||||
respondError(w, http.StatusMethodNotAllowed, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the auth for the request so we can access the token directly
|
||||
req := requestAuth(r, &logical.Request{})
|
||||
|
||||
// Seal with the token above
|
||||
if err := core.Seal(req.ClientToken); err != nil {
|
||||
if err := core.SealWithRequest(req); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
@ -36,19 +38,21 @@ func handleSysSeal(core *vault.Core) http.Handler {
|
|||
|
||||
func handleSysStepDown(core *vault.Core) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case "PUT":
|
||||
case "POST":
|
||||
req, statusCode, err := buildLogicalRequest(w, r)
|
||||
if err != nil || statusCode != 0 {
|
||||
respondError(w, statusCode, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Operation {
|
||||
case logical.UpdateOperation:
|
||||
default:
|
||||
respondError(w, http.StatusMethodNotAllowed, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the auth for the request so we can access the token directly
|
||||
req := requestAuth(r, &logical.Request{})
|
||||
|
||||
// Seal with the token above
|
||||
if err := core.StepDown(req.ClientToken); err != nil {
|
||||
if err := core.StepDown(req); err != nil {
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
package logical
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// LockingInmemStorage implements Storage and stores all data in memory.
|
||||
type LockingInmemStorage struct {
|
||||
sync.RWMutex
|
||||
|
||||
Data map[string]*StorageEntry
|
||||
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (s *LockingInmemStorage) List(prefix string) ([]string, error) {
|
||||
s.once.Do(s.init)
|
||||
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
var result []string
|
||||
for k, _ := range s.Data {
|
||||
if strings.HasPrefix(k, prefix) {
|
||||
result = append(result, k)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *LockingInmemStorage) Get(key string) (*StorageEntry, error) {
|
||||
s.once.Do(s.init)
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return s.Data[key], nil
|
||||
}
|
||||
|
||||
func (s *LockingInmemStorage) Put(entry *StorageEntry) error {
|
||||
s.once.Do(s.init)
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.Data[entry.Key] = entry
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LockingInmemStorage) Delete(k string) error {
|
||||
s.once.Do(s.init)
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.Data, k)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LockingInmemStorage) init() {
|
||||
s.Data = make(map[string]*StorageEntry)
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package logical
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Note: This uses the normal TestStorage, but the best way to exercise this is
|
||||
// to run transit's unit tests, which spawn 1000 goroutines to hammer the
|
||||
// backend for 10 seconds with this as the storage.
|
||||
|
||||
func TestLockingInmemStorage(t *testing.T) {
|
||||
TestStorage(t, new(LockingInmemStorage))
|
||||
}
|
|
@ -3,6 +3,7 @@ package logical
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Request is a struct that stores the parameters and context
|
||||
|
@ -52,6 +53,10 @@ type Request struct {
|
|||
// paths relative to itself. The `Path` is effectively the client
|
||||
// request path with the MountPoint trimmed off.
|
||||
MountPoint string
|
||||
|
||||
// WrapTTL contains the requested TTL of the token used to wrap the
|
||||
// response in a cubbyhole.
|
||||
WrapTTL time.Duration
|
||||
}
|
||||
|
||||
// Get returns a data field and guards for nil Data
|
||||
|
|
|
@ -3,6 +3,7 @@ package logical
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/copystructure"
|
||||
)
|
||||
|
@ -26,6 +27,19 @@ const (
|
|||
HTTPStatusCode = "http_status_code"
|
||||
)
|
||||
|
||||
type WrapInfo struct {
|
||||
// Setting to non-zero specifies that the response should be wrapped.
|
||||
// Specifies the desired TTL of the wrapping token.
|
||||
TTL time.Duration
|
||||
|
||||
// The token containing the wrapped response
|
||||
Token string
|
||||
|
||||
// The creation time. This can be used with the TTL to figure out an
|
||||
// expected expiration.
|
||||
CreationTime time.Time
|
||||
}
|
||||
|
||||
// Response is a struct that stores the response of a request.
|
||||
// It is used to abstract the details of the higher level request protocol.
|
||||
type Response struct {
|
||||
|
@ -54,6 +68,9 @@ type Response struct {
|
|||
// Vault (backend, core, etc.) to add warnings without accidentally
|
||||
// replacing what exists.
|
||||
warnings []string
|
||||
|
||||
// Information for wrapping the response in a cubbyhole
|
||||
WrapInfo *WrapInfo
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -74,7 +91,7 @@ func init() {
|
|||
if input.Auth != nil {
|
||||
retAuth, err := copystructure.Copy(input.Auth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error copying Secret: %v", err)
|
||||
return nil, fmt.Errorf("error copying Auth: %v", err)
|
||||
}
|
||||
ret.Auth = retAuth.(*Auth)
|
||||
}
|
||||
|
@ -82,7 +99,7 @@ func init() {
|
|||
if input.Data != nil {
|
||||
retData, err := copystructure.Copy(&input.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error copying Secret: %v", err)
|
||||
return nil, fmt.Errorf("error copying Data: %v", err)
|
||||
}
|
||||
ret.Data = retData.(map[string]interface{})
|
||||
}
|
||||
|
@ -93,6 +110,14 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
if input.WrapInfo != nil {
|
||||
retWrapInfo, err := copystructure.Copy(input.WrapInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error copying WrapInfo: %v", err)
|
||||
}
|
||||
ret.WrapInfo = retWrapInfo.(*WrapInfo)
|
||||
}
|
||||
|
||||
return &ret, nil
|
||||
}
|
||||
}
|
||||
|
@ -115,6 +140,11 @@ func (r *Response) ClearWarnings() {
|
|||
r.warnings = make([]string, 0, 1)
|
||||
}
|
||||
|
||||
// Copies the warnings from the other response to this one
|
||||
func (r *Response) CloneWarnings(other *Response) {
|
||||
r.warnings = other.warnings
|
||||
}
|
||||
|
||||
// IsError returns true if this response seems to indicate an error.
|
||||
func (r *Response) IsError() bool {
|
||||
return r != nil && len(r.Data) == 1 && r.Data["error"] != nil
|
||||
|
|
60
logical/sanitize.go
Normal file
60
logical/sanitize.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package logical
|
||||
|
||||
import "time"
|
||||
|
||||
// This logic was pulled from the http package so that it can be used for
|
||||
// encoding wrapped responses as well. It simply translates the logical request
|
||||
// to an http response, with the values we want and omitting the values we
|
||||
// don't.
|
||||
func SanitizeResponse(input *Response) *HTTPResponse {
|
||||
logicalResp := &HTTPResponse{
|
||||
Data: input.Data,
|
||||
Warnings: input.Warnings(),
|
||||
}
|
||||
|
||||
if input.Secret != nil {
|
||||
logicalResp.LeaseID = input.Secret.LeaseID
|
||||
logicalResp.Renewable = input.Secret.Renewable
|
||||
logicalResp.LeaseDuration = int(input.Secret.TTL.Seconds())
|
||||
}
|
||||
|
||||
// If we have authentication information, then
|
||||
// set up the result structure.
|
||||
if input.Auth != nil {
|
||||
logicalResp.Auth = &HTTPAuth{
|
||||
ClientToken: input.Auth.ClientToken,
|
||||
Accessor: input.Auth.Accessor,
|
||||
Policies: input.Auth.Policies,
|
||||
Metadata: input.Auth.Metadata,
|
||||
LeaseDuration: int(input.Auth.TTL.Seconds()),
|
||||
Renewable: input.Auth.Renewable,
|
||||
}
|
||||
}
|
||||
|
||||
return logicalResp
|
||||
}
|
||||
|
||||
type HTTPResponse struct {
|
||||
LeaseID string `json:"lease_id"`
|
||||
Renewable bool `json:"renewable"`
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
WrapInfo *HTTPWrapInfo `json:"wrap_info"`
|
||||
Warnings []string `json:"warnings"`
|
||||
Auth *HTTPAuth `json:"auth"`
|
||||
}
|
||||
|
||||
type HTTPAuth struct {
|
||||
ClientToken string `json:"client_token"`
|
||||
Accessor string `json:"accessor"`
|
||||
Policies []string `json:"policies"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Renewable bool `json:"renewable"`
|
||||
}
|
||||
|
||||
type HTTPWrapInfo struct {
|
||||
Token string `json:"token"`
|
||||
TTL int `json:"ttl"`
|
||||
CreationTime time.Time `json:"creation_time"`
|
||||
}
|
|
@ -1,13 +1,14 @@
|
|||
package logical
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/physical"
|
||||
)
|
||||
|
||||
// InmemStorage implements Storage and stores all data in memory.
|
||||
type InmemStorage struct {
|
||||
Data map[string]*StorageEntry
|
||||
phys *physical.InmemBackend
|
||||
|
||||
once sync.Once
|
||||
}
|
||||
|
@ -15,33 +16,38 @@ type InmemStorage struct {
|
|||
func (s *InmemStorage) List(prefix string) ([]string, error) {
|
||||
s.once.Do(s.init)
|
||||
|
||||
var result []string
|
||||
for k, _ := range s.Data {
|
||||
if strings.HasPrefix(k, prefix) {
|
||||
result = append(result, k)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return s.phys.List(prefix)
|
||||
}
|
||||
|
||||
func (s *InmemStorage) Get(key string) (*StorageEntry, error) {
|
||||
s.once.Do(s.init)
|
||||
return s.Data[key], nil
|
||||
entry, err := s.phys.Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &StorageEntry{
|
||||
Key: entry.Key,
|
||||
Value: entry.Value,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *InmemStorage) Put(entry *StorageEntry) error {
|
||||
s.once.Do(s.init)
|
||||
s.Data[entry.Key] = entry
|
||||
return nil
|
||||
physEntry := &physical.Entry{
|
||||
Key: entry.Key,
|
||||
Value: entry.Value,
|
||||
}
|
||||
return s.phys.Put(physEntry)
|
||||
}
|
||||
|
||||
func (s *InmemStorage) Delete(k string) error {
|
||||
s.once.Do(s.init)
|
||||
delete(s.Data, k)
|
||||
return nil
|
||||
return s.phys.Delete(k)
|
||||
}
|
||||
|
||||
func (s *InmemStorage) init() {
|
||||
s.Data = make(map[string]*StorageEntry)
|
||||
s.phys = physical.NewInmem(nil)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -302,8 +303,10 @@ func Test(t TestT, c TestCase) {
|
|||
if err == nil && resp.IsError() {
|
||||
err = fmt.Errorf("Erroneous response:\n\n%#v", resp)
|
||||
}
|
||||
if err != nil && err != logical.ErrUnsupportedOperation {
|
||||
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
|
||||
if err != nil {
|
||||
if !errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()) {
|
||||
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
// If we have any failed revokes, log it.
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue