Merge branch 'master' into ui-tier-icons

This commit is contained in:
Joshua Ogle 2018-08-16 10:11:04 -06:00 committed by GitHub
commit b01c94caa8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
332 changed files with 22600 additions and 3435 deletions

2
.gitignore vendored
View File

@ -43,6 +43,8 @@ Vagrantfile
# Configs
*.hcl
!command/agent/config/test-fixtures/config.hcl
!command/agent/config/test-fixtures/config-embedded-type.hcl
.DS_Store
.idea

View File

@ -1,4 +1,48 @@
## 0.10.4 (Unreleased)
## Next (Unreleased)
DEPRECATIONS/CHANGES:
* Request Timeouts: A default request timeout of 90s is now enforced. This
setting can be overwritten in the config file. If you anticipate requests
taking longer than 90s this setting should be updated before upgrading.
* `sys/` Top Level Injection: For the last two years for backwards
compatibility data for various `sys/` routes has been injected into both the
Secret's Data map and into the top level of the JSON response object.
However, this has some subtle issues that pop up from time to time and is
becoming increasingly complicated to maintain, so it's finally being
removed.
FEATURES:
* **AliCloud OSS Storage**: AliCloud OSS can now be used for Vault storage.
* **HA Support for MySQL Storage**: MySQL storage now supports HA.
* **ACL Templating**: ACL policies can now be templated using identity Entity,
Groups, and Metadata.
IMPROVEMENTS:
* agent: Add `exit_after_auth` to be able to use the Agent for a single
authentication [GH-5013]
* cli: Add support for passing parameters to `vault read` operations [GH-5093]
* storage/mysql: Support special characters in database and table names.
BUG FIXES:
* identity: Properly populate `mount_path` and `mount_type` on group lookup
[GH-5074]
* identity: Fix carryover issue from previously fixed race condition that
could cause Vault not to start up due to two entities referencing the same
alias. These entities are now merged. [GH-5000]
* secrets/database: Fix inability to update custom SQL statements on
database roles. [GH-5080]
## 0.10.4 (July 25th, 2018)
SECURITY:
* Control Groups: The associated Identity entity with a request was not being
properly persisted. As a result, the same authorizer could provide more than
one authorization.
DEPRECATIONS/CHANGES:
@ -11,6 +55,10 @@ DEPRECATIONS/CHANGES:
* CLI Retries: The CLI will no longer retry commands on 5xx errors. This was a
source of confusion to users as to why Vault would "hang" before returning a
5xx error. The Go API client still defaults to two retries.
* Identity Entity Alias metadata: You can no longer manually set metadata on
entity aliases. All alias data (except the canonical entity ID it refers to)
is intended to be managed by the plugin providing the alias information, so
allowing it to be set manually didn't make sense.
FEATURES:
@ -24,22 +72,33 @@ FEATURES:
* **UI Control Group Workflow (enterprise)**: The UI will now detect control
group responses and provides a workflow to view the status of the request
and to authorize requests.
* **Vault Agent (Beta)**: Vault Agent is a daemon that can automatically
authenticate for you across a variety of authentication methods, provide
tokens to clients, and keep the tokens renewed, reauthenticating as
necessary.
IMPROVEMENTS:
* auth/azure: Add support for virtual machine scale sets
* auth/gcp: Support multiple bindings for region, zone, and instance group
* cli: Add subcommands for interacting with the plugin catalog [GH-4911]
* cli: Add a `-description` flag to secrets and auth tune subcommands to allow
updating an existing secret engine's or auth method's description. This
change also allows the description to be unset by providing an empty string.
* core: Add config flag to disable non-printable character check [GH-4917]
* core: A `max_request_size` parameter can now be set per-listener to adjust
the maximum allowed size per request [GH-4824]
* core: Add control group request endpoint to default policy [GH-4904]
* identity: Identity metadata is now passed through to plugins [GH-4967]
* replication: Add additional saftey checks and logging when replication is
in a bad state
* secrets/kv: Add support for using `-field=data` to KVv2 when using `vault
kv` [GH-4895]
* secrets/pki: Add the ability to tidy revoked but unexpired certificates
[GH-4916]
* secrets/ssh: Allow Vault to work with single-argument SSH flags [GH-4825]
* secrets/ssh: SSH executable path can now be configured in the CLI [GH-4937]
* storage/swift: Add additional configuration options [GH-4901]
* ui: Choose which auth methods to show to unauthenticated users via
`listing_visibility` in the auth method edit forms [GH-4854]
* ui: Authenticate users automatically by passing a wrapped token to the UI via
@ -47,15 +106,35 @@ IMPROVEMENTS:
BUG FIXES:
* api: Fix response body being cleared too early [GH-4987]
* auth/approle: Fix issue with tidy endpoint that would unnecessarily remove
secret accessors [GH-4981]
* auth/aws: Fix updating `max_retries` [GH-4980]
* auth/kubernetes: Trim trailing whitespace when sending JWT
* cli: Fix parsing of environment variables for integer flags [GH-4925]
* core: Fix returning 500 instead of 503 if a rekey is attempted when Vault is
sealed [GH-4874]
* core: Fix issue releasing the leader lock in some circumstances [GH-4915]
* core: Fix a panic that could happen if the server was shut down while still
starting up
* core: Fix deadlock that would occur if a leadership loss occurs at the same
time as a seal operation [GH-4932]
* core: Fix issue with auth mounts failing to renew tokens due to policies
changing [GH-4960]
* auth/radius: Fix issue where some radius logins were being canceled too early
[GH-4941]
* core: Fix accidental seal of vault of we lose leadership during startup
[GH-4924]
* core: Fix standby not being able to forward requests larger than 4MB
[GH-4844]
* core: Avoid panic while processing group memberships [GH-4841]
* identity: Fix a race condition creating aliases [GH-4965]
* plugins: Fix being unable to send very large payloads to or from plugins
[GH-4958]
* physical/azure: Long list responses would sometimes be truncaated [GH-4983]
* physical/azure: Long list responses would sometimes be truncated [GH-4983]
* replication: Allow replication status requests to be processed while in
merkle sync
* replication: Ensure merkle reindex flushes all changes to storage immediately
* replication: Fix a case where a network interruption could cause a secondary
to be unable to reconnect to a primary
* secrets/pki: Fix permitted DNS domains performing improper validation

View File

@ -145,6 +145,12 @@ func testPostgresDB(t testing.TB) (string, func()) {
t.Fatalf("postgresdb: could not start container: %s", err)
}
cleanup := func() {
if err := pool.Purge(resource); err != nil {
t.Fatalf("failed to cleanup local container: %s", err)
}
}
addr := fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp"))
if err := pool.Retry(func() error {
@ -155,12 +161,9 @@ func testPostgresDB(t testing.TB) (string, func()) {
defer db.Close()
return db.Ping()
}); err != nil {
cleanup()
t.Fatalf("postgresdb: could not connect: %s", err)
}
return addr, func() {
if err := pool.Purge(resource); err != nil {
t.Fatalf("postgresdb: failed to cleanup container: %s", err)
}
}
return addr, cleanup
}

View File

@ -1,5 +1,7 @@
package api
import "context"
// TokenAuth is used to perform token backend operations on Vault
type TokenAuth struct {
c *Client
@ -16,7 +18,9 @@ func (c *TokenAuth) Create(opts *TokenCreateRequest) (*Secret, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -31,7 +35,9 @@ func (c *TokenAuth) CreateOrphan(opts *TokenCreateRequest) (*Secret, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -46,7 +52,9 @@ func (c *TokenAuth) CreateWithRole(opts *TokenCreateRequest, roleName string) (*
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -63,7 +71,9 @@ func (c *TokenAuth) Lookup(token string) (*Secret, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -79,7 +89,10 @@ func (c *TokenAuth) LookupAccessor(accessor string) (*Secret, error) {
}); err != nil {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -91,7 +104,9 @@ func (c *TokenAuth) LookupAccessor(accessor string) (*Secret, error) {
func (c *TokenAuth) LookupSelf() (*Secret, error) {
r := c.c.NewRequest("GET", "/v1/auth/token/lookup-self")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -109,7 +124,9 @@ func (c *TokenAuth) Renew(token string, increment int) (*Secret, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -126,7 +143,9 @@ func (c *TokenAuth) RenewSelf(increment int) (*Secret, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -146,7 +165,9 @@ func (c *TokenAuth) RenewTokenAsSelf(token string, increment int) (*Secret, erro
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -164,7 +185,10 @@ func (c *TokenAuth) RevokeAccessor(accessor string) error {
}); err != nil {
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
@ -183,7 +207,9 @@ func (c *TokenAuth) RevokeOrphan(token string) error {
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
@ -197,7 +223,10 @@ func (c *TokenAuth) RevokeOrphan(token string) error {
// an effect.
func (c *TokenAuth) RevokeSelf(token string) error {
r := c.c.NewRequest("PUT", "/v1/auth/token/revoke-self")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
@ -217,7 +246,9 @@ func (c *TokenAuth) RevokeTree(token string) error {
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}

View File

@ -464,6 +464,19 @@ func (c *Client) SetMFACreds(creds []string) {
c.mfaCreds = creds
}
// SetNamespace sets the namespace supplied either via the environment
// variable or via the command line.
func (c *Client) SetNamespace(namespace string) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
if c.headers == nil {
c.headers = make(http.Header)
}
c.headers.Set("X-Vault-Namespace", namespace)
}
// Token returns the access token being used by this client. It will
// return the empty string if there is no token set.
func (c *Client) Token() string {
@ -490,6 +503,26 @@ func (c *Client) ClearToken() {
c.token = ""
}
// Headers gets the current set of headers used for requests. This returns a
// copy; to modify it make modifications locally and use SetHeaders.
func (c *Client) Headers() http.Header {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
if c.headers == nil {
return nil
}
ret := make(http.Header)
for k, v := range c.headers {
for _, val := range v {
ret[k] = append(ret[k], val)
}
}
return ret
}
// SetHeaders sets the headers to be used for future requests.
func (c *Client) SetHeaders(headers http.Header) {
c.modifyLock.Lock()
@ -608,6 +641,13 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequest(r *Request) (*Response, error) {
return c.RawRequestWithContext(context.Background(), r)
}
// RawRequestWithContext performs the raw request given. This request may be against
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Response, error) {
c.modifyLock.RLock()
token := c.token
@ -622,7 +662,7 @@ func (c *Client) RawRequest(r *Request) (*Response, error) {
c.modifyLock.RUnlock()
if limiter != nil {
limiter.Wait(context.Background())
limiter.Wait(ctx)
}
// Sanity check the token before potentially erroring from the API
@ -643,13 +683,10 @@ START:
return nil, fmt.Errorf("nil request created")
}
// Set the timeout, if any
var cancelFunc context.CancelFunc
if timeout != 0 {
var ctx context.Context
ctx, cancelFunc = context.WithTimeout(context.Background(), timeout)
req.Request = req.Request.WithContext(ctx)
ctx, _ = context.WithTimeout(ctx, timeout)
}
req.Request = req.Request.WithContext(ctx)
if backoff == nil {
backoff = retryablehttp.LinearJitterBackoff
@ -667,9 +704,6 @@ START:
var result *Response
resp, err := client.Do(req)
if cancelFunc != nil {
cancelFunc()
}
if resp != nil {
result = &Response{Response: resp}
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"fmt"
)
@ -8,7 +9,10 @@ import (
func (c *Client) Help(path string) (*Help, error) {
r := c.NewRequest("GET", fmt.Sprintf("/v1/%s", path))
r.Params.Add("help", "1")
resp, err := c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}

View File

@ -2,8 +2,10 @@ package api
import (
"bytes"
"context"
"fmt"
"io"
"net/url"
"os"
"github.com/hashicorp/errwrap"
@ -45,8 +47,29 @@ func (c *Client) Logical() *Logical {
}
func (c *Logical) Read(path string) (*Secret, error) {
return c.ReadWithData(path, nil)
}
func (c *Logical) ReadWithData(path string, data map[string][]string) (*Secret, error) {
r := c.c.NewRequest("GET", "/v1/"+path)
resp, err := c.c.RawRequest(r)
var values url.Values
for k, v := range data {
if values == nil {
values = make(url.Values)
}
for _, val := range v {
values.Add(k, val)
}
}
if values != nil {
r.Params = values
}
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
@ -77,7 +100,10 @@ func (c *Logical) List(path string) (*Secret, error) {
// handle the wrapping lookup function
r.Method = "GET"
r.Params.Set("list", "true")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
@ -108,7 +134,9 @@ func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, erro
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
@ -134,7 +162,10 @@ func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, erro
func (c *Logical) Delete(path string) (*Secret, error) {
r := c.c.NewRequest("DELETE", "/v1/"+path)
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
@ -175,7 +206,9 @@ func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}

View File

@ -1,6 +1,9 @@
package api
import "fmt"
import (
"context"
"fmt"
)
// SSH is used to return a client to invoke operations on SSH backend.
type SSH struct {
@ -28,7 +31,9 @@ func (c *SSH) Credential(role string, data map[string]interface{}) (*Secret, err
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -45,7 +50,9 @@ func (c *SSH) SignKey(role string, data map[string]interface{}) (*Secret, error)
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
@ -207,7 +208,9 @@ func (c *SSHHelper) Verify(otp string) (*SSHVerifyResponse, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,8 @@
package api
import (
"context"
"errors"
"fmt"
"github.com/mitchellh/mapstructure"
@ -16,56 +18,58 @@ func (c *Sys) AuditHash(path string, input string) (string, error) {
return "", err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return "", err
}
defer resp.Body.Close()
type d struct {
Hash string `json:"hash"`
}
var result d
err = resp.DecodeJSON(&result)
secret, err := ParseSecret(resp.Body)
if err != nil {
return "", err
}
if secret == nil || secret.Data == nil {
return "", errors.New("data from server response is empty")
}
return result.Hash, err
hash, ok := secret.Data["hash"]
if !ok {
return "", errors.New("hash not found in response data")
}
hashStr, ok := hash.(string)
if !ok {
return "", errors.New("could not parse hash in response data")
}
return hashStr, nil
}
func (c *Sys) ListAudit() (map[string]*Audit, error) {
r := c.c.NewRequest("GET", "/v1/sys/audit")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result map[string]interface{}
err = resp.DecodeJSON(&result)
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
mounts := map[string]*Audit{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res Audit
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
err = mapstructure.Decode(secret.Data, &mounts)
if err != nil {
return nil, err
}
return mounts, nil
@ -87,7 +91,10 @@ func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) e
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
@ -98,7 +105,11 @@ func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) e
func (c *Sys) DisableAudit(path string) error {
r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/audit/%s", path))
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -110,16 +121,16 @@ func (c *Sys) DisableAudit(path string) error {
// documentation. Please refer to that documentation for more details.
type EnableAuditOptions struct {
Type string `json:"type"`
Description string `json:"description"`
Options map[string]string `json:"options"`
Local bool `json:"local"`
Type string `json:"type" mapstructure:"type"`
Description string `json:"description" mapstructure:"description"`
Options map[string]string `json:"options" mapstructure:"options"`
Local bool `json:"local" mapstructure:"local"`
}
type Audit struct {
Path string
Type string
Description string
Options map[string]string
Local bool
Type string `json:"type" mapstructure:"type"`
Description string `json:"description" mapstructure:"description"`
Options map[string]string `json:"options" mapstructure:"options"`
Local bool `json:"local" mapstructure:"local"`
Path string `json:"path" mapstructure:"path"`
}

View File

@ -1,6 +1,8 @@
package api
import (
"context"
"errors"
"fmt"
"github.com/mitchellh/mapstructure"
@ -8,35 +10,27 @@ import (
func (c *Sys) ListAuth() (map[string]*AuthMount, error) {
r := c.c.NewRequest("GET", "/v1/sys/auth")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result map[string]interface{}
err = resp.DecodeJSON(&result)
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
mounts := map[string]*AuthMount{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res AuthMount
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
err = mapstructure.Decode(secret.Data, &mounts)
if err != nil {
return nil, err
}
return mounts, nil
@ -56,7 +50,9 @@ func (c *Sys) EnableAuthWithOptions(path string, options *EnableAuthOptions) err
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
@ -67,7 +63,10 @@ func (c *Sys) EnableAuthWithOptions(path string, options *EnableAuthOptions) err
func (c *Sys) DisableAuth(path string) error {
r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/auth/%s", path))
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}

View File

@ -1,6 +1,12 @@
package api
import "fmt"
import (
"context"
"errors"
"fmt"
"github.com/mitchellh/mapstructure"
)
func (c *Sys) CapabilitiesSelf(path string) ([]string, error) {
return c.Capabilities(c.c.Token(), path)
@ -22,28 +28,27 @@ func (c *Sys) Capabilities(token, path string) ([]string, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result map[string]interface{}
err = resp.DecodeJSON(&result)
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var res []string
err = mapstructure.Decode(secret.Data[path], &res)
if err != nil {
return nil, err
}
if result["capabilities"] == nil {
return nil, nil
}
var capabilities []string
capabilitiesRaw, ok := result["capabilities"].([]interface{})
if !ok {
return nil, fmt.Errorf("error interpreting returned capabilities")
}
for _, capability := range capabilitiesRaw {
capabilities = append(capabilities, capability.(string))
}
return capabilities, nil
return res, nil
}

View File

@ -1,15 +1,37 @@
package api
import (
"context"
"errors"
"github.com/mitchellh/mapstructure"
)
func (c *Sys) CORSStatus() (*CORSResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/config/cors")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var result CORSResponse
err = resp.DecodeJSON(&result)
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
}
return &result, err
}
@ -19,38 +41,65 @@ func (c *Sys) ConfigureCORS(req *CORSRequest) (*CORSResponse, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var result CORSResponse
err = resp.DecodeJSON(&result)
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
}
return &result, err
}
func (c *Sys) DisableCORS() (*CORSResponse, error) {
r := c.c.NewRequest("DELETE", "/v1/sys/config/cors")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result CORSResponse
err = resp.DecodeJSON(&result)
return &result, err
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var result CORSResponse
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
}
return &result, err
}
type CORSRequest struct {
AllowedOrigins string `json:"allowed_origins"`
Enabled bool `json:"enabled"`
AllowedOrigins string `json:"allowed_origins" mapstructure:"allowed_origins"`
Enabled bool `json:"enabled" mapstructure:"enabled"`
}
type CORSResponse struct {
AllowedOrigins string `json:"allowed_origins"`
Enabled bool `json:"enabled"`
AllowedOrigins string `json:"allowed_origins" mapstructure:"allowed_origins"`
Enabled bool `json:"enabled" mapstructure:"enabled"`
}

View File

@ -1,5 +1,7 @@
package api
import "context"
func (c *Sys) GenerateRootStatus() (*GenerateRootStatusResponse, error) {
return c.generateRootStatusCommon("/v1/sys/generate-root/attempt")
}
@ -10,7 +12,10 @@ func (c *Sys) GenerateDROperationTokenStatus() (*GenerateRootStatusResponse, err
func (c *Sys) generateRootStatusCommon(path string) (*GenerateRootStatusResponse, error) {
r := c.c.NewRequest("GET", path)
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -40,7 +45,9 @@ func (c *Sys) generateRootInitCommon(path, otp, pgpKey string) (*GenerateRootSta
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -61,7 +68,10 @@ func (c *Sys) GenerateDROperationTokenCancel() error {
func (c *Sys) generateRootCancelCommon(path string) error {
r := c.c.NewRequest("DELETE", path)
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -87,7 +97,9 @@ func (c *Sys) generateRootUpdateCommon(path, shard, nonce string) (*GenerateRoot
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}

View File

@ -1,5 +1,7 @@
package api
import "context"
func (c *Sys) Health() (*HealthResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/health")
// If the code is 400 or above it will automatically turn into an error,
@ -9,7 +11,10 @@ func (c *Sys) Health() (*HealthResponse, error) {
r.Params.Add("sealedcode", "299")
r.Params.Add("standbycode", "299")
r.Params.Add("drsecondarycode", "299")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}

View File

@ -1,8 +1,13 @@
package api
import "context"
func (c *Sys) InitStatus() (bool, error) {
r := c.c.NewRequest("GET", "/v1/sys/init")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return false, err
}
@ -19,7 +24,9 @@ func (c *Sys) Init(opts *InitRequest) (*InitResponse, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}

View File

@ -1,8 +1,13 @@
package api
import "context"
func (c *Sys) Leader() (*LeaderResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/leader")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,9 @@
package api
import "errors"
import (
"context"
"errors"
)
func (c *Sys) Renew(id string, increment int) (*Secret, error) {
r := c.c.NewRequest("PUT", "/v1/sys/leases/renew")
@ -13,7 +16,9 @@ func (c *Sys) Renew(id string, increment int) (*Secret, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -24,7 +29,10 @@ func (c *Sys) Renew(id string, increment int) (*Secret, error) {
func (c *Sys) Revoke(id string) error {
r := c.c.NewRequest("PUT", "/v1/sys/leases/revoke/"+id)
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -33,7 +41,10 @@ func (c *Sys) Revoke(id string) error {
func (c *Sys) RevokePrefix(id string) error {
r := c.c.NewRequest("PUT", "/v1/sys/leases/revoke-prefix/"+id)
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -42,7 +53,10 @@ func (c *Sys) RevokePrefix(id string) error {
func (c *Sys) RevokeForce(id string) error {
r := c.c.NewRequest("PUT", "/v1/sys/leases/revoke-force/"+id)
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -74,7 +88,9 @@ func (c *Sys) RevokeWithOptions(opts *RevokeOptions) error {
}
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}

View File

@ -1,6 +1,8 @@
package api
import (
"context"
"errors"
"fmt"
"github.com/mitchellh/mapstructure"
@ -8,35 +10,27 @@ import (
func (c *Sys) ListMounts() (map[string]*MountOutput, error) {
r := c.c.NewRequest("GET", "/v1/sys/mounts")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result map[string]interface{}
err = resp.DecodeJSON(&result)
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
mounts := map[string]*MountOutput{}
for k, v := range result {
switch v.(type) {
case map[string]interface{}:
default:
continue
}
var res MountOutput
err = mapstructure.Decode(v, &res)
if err != nil {
return nil, err
}
// Not a mount, some other api.Secret data
if res.Type == "" {
continue
}
mounts[k] = &res
err = mapstructure.Decode(secret.Data, &mounts)
if err != nil {
return nil, err
}
return mounts, nil
@ -48,7 +42,9 @@ func (c *Sys) Mount(path string, mountInfo *MountInput) error {
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
@ -59,7 +55,10 @@ func (c *Sys) Mount(path string, mountInfo *MountInput) error {
func (c *Sys) Unmount(path string) error {
r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/mounts/%s", path))
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -77,7 +76,9 @@ func (c *Sys) Remount(from, to string) error {
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -90,7 +91,9 @@ func (c *Sys) TuneMount(path string, config MountConfigInput) error {
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -100,14 +103,24 @@ func (c *Sys) TuneMount(path string, config MountConfigInput) error {
func (c *Sys) MountConfig(path string) (*MountConfigOutput, error) {
r := c.c.NewRequest("GET", fmt.Sprintf("/v1/sys/mounts/%s/tune", path))
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var result MountConfigOutput
err = resp.DecodeJSON(&result)
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
}

97
api/sys_namespaces.go Normal file
View File

@ -0,0 +1,97 @@
package api
import (
"fmt"
"net/http"
)
// ListNamespacesResponse is the response from the ListNamespaces call.
type ListNamespacesResponse struct {
// NamespacePaths is the list of child namespace paths
NamespacePaths []string `json:"namespace_paths"`
}
type GetNamespaceResponse struct {
Path string `json:"path"`
}
// ListNamespaces lists any existing namespace relative to the namespace
// provided in the client's namespace header.
func (c *Sys) ListNamespaces() (*ListNamespacesResponse, error) {
r := c.c.NewRequest("LIST", "/v1/sys/namespaces")
resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result struct {
Data struct {
Keys []string `json:"keys"`
} `json:"data"`
}
err = resp.DecodeJSON(&result)
if err != nil {
return nil, err
}
return &ListNamespacesResponse{NamespacePaths: result.Data.Keys}, nil
}
// GetNamespace returns namespace information
func (c *Sys) GetNamespace(path string) (*GetNamespaceResponse, error) {
r := c.c.NewRequest("GET", fmt.Sprintf("/v1/sys/namespaces/%s", path))
resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
ret := &GetNamespaceResponse{}
result := map[string]interface{}{
"data": map[string]interface{}{},
}
if err := resp.DecodeJSON(&result); err != nil {
return nil, err
}
if data, ok := result["data"]; ok {
if pathOk, ok := data.(map[string]interface{})["path"]; ok {
if pathRaw, ok := pathOk.(string); ok {
ret.Path = pathRaw
}
}
}
return ret, nil
}
// CreateNamespace creates a new namespace relative to the namespace provided
// in the client's namespace header.
func (c *Sys) CreateNamespace(path string) error {
r := c.c.NewRequest("POST", fmt.Sprintf("/v1/sys/namespaces/%s", path))
resp, err := c.c.RawRequest(r)
if err != nil {
return err
}
defer resp.Body.Close()
return nil
}
// DeleteNamespace delete an existing namespace relative to the namespace
// provided in the client's namespace header.
func (c *Sys) DeleteNamespace(path string) error {
r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/namespaces/%s", path))
resp, err := c.c.RawRequest(r)
if err != nil {
return err
}
defer resp.Body.Close()
return nil
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"fmt"
"net/http"
)
@ -11,7 +12,7 @@ type ListPluginsInput struct{}
// ListPluginsResponse is the response from the ListPlugins call.
type ListPluginsResponse struct {
// Names is the list of names of the plugins.
Names []string
Names []string `json:"names"`
}
// ListPlugins lists all plugins in the catalog and returns their names as a
@ -19,7 +20,10 @@ type ListPluginsResponse struct {
func (c *Sys) ListPlugins(i *ListPluginsInput) (*ListPluginsResponse, error) {
path := "/v1/sys/plugins/catalog"
req := c.c.NewRequest("LIST", path)
resp, err := c.c.RawRequest(req)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, req)
if err != nil {
return nil, err
}
@ -54,7 +58,10 @@ type GetPluginResponse struct {
func (c *Sys) GetPlugin(i *GetPluginInput) (*GetPluginResponse, error) {
path := fmt.Sprintf("/v1/sys/plugins/catalog/%s", i.Name)
req := c.c.NewRequest(http.MethodGet, path)
resp, err := c.c.RawRequest(req)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, req)
if err != nil {
return nil, err
}
@ -93,7 +100,9 @@ func (c *Sys) RegisterPlugin(i *RegisterPluginInput) error {
return err
}
resp, err := c.c.RawRequest(req)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, req)
if err == nil {
defer resp.Body.Close()
}
@ -111,7 +120,10 @@ type DeregisterPluginInput struct {
func (c *Sys) DeregisterPlugin(i *DeregisterPluginInput) error {
path := fmt.Sprintf("/v1/sys/plugins/catalog/%s", i.Name)
req := c.c.NewRequest(http.MethodDelete, path)
resp, err := c.c.RawRequest(req)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, req)
if err == nil {
defer resp.Body.Close()
}

View File

@ -1,39 +1,47 @@
package api
import "fmt"
import (
"context"
"errors"
"fmt"
"github.com/mitchellh/mapstructure"
)
func (c *Sys) ListPolicies() ([]string, error) {
r := c.c.NewRequest("GET", "/v1/sys/policy")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var result map[string]interface{}
err = resp.DecodeJSON(&result)
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var result []string
err = mapstructure.Decode(secret.Data["policies"], &result)
if err != nil {
return nil, err
}
var ok bool
if _, ok = result["policies"]; !ok {
return nil, fmt.Errorf("policies not found in response")
}
listRaw := result["policies"].([]interface{})
var policies []string
for _, val := range listRaw {
policies = append(policies, val.(string))
}
return policies, err
return result, err
}
func (c *Sys) GetPolicy(name string) (string, error) {
r := c.c.NewRequest("GET", fmt.Sprintf("/v1/sys/policy/%s", name))
resp, err := c.c.RawRequest(r)
r := c.c.NewRequest("GET", fmt.Sprintf("/v1/sys/policies/acl/%s", name))
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
if resp.StatusCode == 404 {
@ -44,16 +52,15 @@ func (c *Sys) GetPolicy(name string) (string, error) {
return "", err
}
var result map[string]interface{}
err = resp.DecodeJSON(&result)
secret, err := ParseSecret(resp.Body)
if err != nil {
return "", err
}
if rulesRaw, ok := result["rules"]; ok {
return rulesRaw.(string), nil
if secret == nil || secret.Data == nil {
return "", errors.New("data from server response is empty")
}
if policyRaw, ok := result["policy"]; ok {
if policyRaw, ok := secret.Data["policy"]; ok {
return policyRaw.(string), nil
}
@ -70,7 +77,9 @@ func (c *Sys) PutPolicy(name, rules string) error {
return err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
@ -81,7 +90,10 @@ func (c *Sys) PutPolicy(name, rules string) error {
func (c *Sys) DeletePolicy(name string) error {
r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/policy/%s", name))
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}

View File

@ -1,8 +1,18 @@
package api
import (
"context"
"errors"
"github.com/mitchellh/mapstructure"
)
func (c *Sys) RekeyStatus() (*RekeyStatusResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/rekey/init")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -15,7 +25,10 @@ func (c *Sys) RekeyStatus() (*RekeyStatusResponse, error) {
func (c *Sys) RekeyRecoveryKeyStatus() (*RekeyStatusResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/rekey-recovery-key/init")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -28,7 +41,10 @@ func (c *Sys) RekeyRecoveryKeyStatus() (*RekeyStatusResponse, error) {
func (c *Sys) RekeyVerificationStatus() (*RekeyVerificationStatusResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/rekey/verify")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -41,7 +57,10 @@ func (c *Sys) RekeyVerificationStatus() (*RekeyVerificationStatusResponse, error
func (c *Sys) RekeyRecoveryKeyVerificationStatus() (*RekeyVerificationStatusResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/rekey-recovery-key/verify")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -58,7 +77,9 @@ func (c *Sys) RekeyInit(config *RekeyInitRequest) (*RekeyStatusResponse, error)
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -75,7 +96,9 @@ func (c *Sys) RekeyRecoveryKeyInit(config *RekeyInitRequest) (*RekeyStatusRespon
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -88,7 +111,10 @@ func (c *Sys) RekeyRecoveryKeyInit(config *RekeyInitRequest) (*RekeyStatusRespon
func (c *Sys) RekeyCancel() error {
r := c.c.NewRequest("DELETE", "/v1/sys/rekey/init")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -97,7 +123,10 @@ func (c *Sys) RekeyCancel() error {
func (c *Sys) RekeyRecoveryKeyCancel() error {
r := c.c.NewRequest("DELETE", "/v1/sys/rekey-recovery-key/init")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -106,7 +135,10 @@ func (c *Sys) RekeyRecoveryKeyCancel() error {
func (c *Sys) RekeyVerificationCancel() error {
r := c.c.NewRequest("DELETE", "/v1/sys/rekey/verify")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -115,7 +147,10 @@ func (c *Sys) RekeyVerificationCancel() error {
func (c *Sys) RekeyRecoveryKeyVerificationCancel() error {
r := c.c.NewRequest("DELETE", "/v1/sys/rekey-recovery-key/verify")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -133,7 +168,9 @@ func (c *Sys) RekeyUpdate(shard, nonce string) (*RekeyUpdateResponse, error) {
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -155,7 +192,9 @@ func (c *Sys) RekeyRecoveryKeyUpdate(shard, nonce string) (*RekeyUpdateResponse,
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -168,33 +207,66 @@ func (c *Sys) RekeyRecoveryKeyUpdate(shard, nonce string) (*RekeyUpdateResponse,
func (c *Sys) RekeyRetrieveBackup() (*RekeyRetrieveResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/rekey/backup")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var result RekeyRetrieveResponse
err = resp.DecodeJSON(&result)
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
}
return &result, err
}
func (c *Sys) RekeyRetrieveRecoveryBackup() (*RekeyRetrieveResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/rekey/recovery-backup")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var result RekeyRetrieveResponse
err = resp.DecodeJSON(&result)
err = mapstructure.Decode(secret.Data, &result)
if err != nil {
return nil, err
}
return &result, err
}
func (c *Sys) RekeyDeleteBackup() error {
r := c.c.NewRequest("DELETE", "/v1/sys/rekey/backup")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -204,7 +276,10 @@ func (c *Sys) RekeyDeleteBackup() error {
func (c *Sys) RekeyDeleteRecoveryBackup() error {
r := c.c.NewRequest("DELETE", "/v1/sys/rekey/recovery-backup")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -223,7 +298,9 @@ func (c *Sys) RekeyVerificationUpdate(shard, nonce string) (*RekeyVerificationUp
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -245,7 +322,9 @@ func (c *Sys) RekeyRecoveryKeyVerificationUpdate(shard, nonce string) (*RekeyVer
return nil, err
}
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
@ -290,9 +369,9 @@ type RekeyUpdateResponse struct {
}
type RekeyRetrieveResponse struct {
Nonce string `json:"nonce"`
Keys map[string][]string `json:"keys"`
KeysB64 map[string][]string `json:"keys_base64"`
Nonce string `json:"nonce" mapstructure:"nonce"`
Keys map[string][]string `json:"keys" mapstructure:"keys"`
KeysB64 map[string][]string `json:"keys_base64" mapstructure:"keys_base64"`
}
type RekeyVerificationStatusResponse struct {

View File

@ -1,10 +1,18 @@
package api
import "time"
import (
"context"
"encoding/json"
"errors"
"time"
)
func (c *Sys) Rotate() error {
r := c.c.NewRequest("POST", "/v1/sys/rotate")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -13,15 +21,54 @@ func (c *Sys) Rotate() error {
func (c *Sys) KeyStatus() (*KeyStatus, error) {
r := c.c.NewRequest("GET", "/v1/sys/key-status")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
defer resp.Body.Close()
result := new(KeyStatus)
err = resp.DecodeJSON(result)
return result, err
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, err
}
if secret == nil || secret.Data == nil {
return nil, errors.New("data from server response is empty")
}
var result KeyStatus
termRaw, ok := secret.Data["term"]
if !ok {
return nil, errors.New("term not found in response")
}
term, ok := termRaw.(json.Number)
if !ok {
return nil, errors.New("could not convert term to a number")
}
term64, err := term.Int64()
if err != nil {
return nil, err
}
result.Term = int(term64)
installTimeRaw, ok := secret.Data["install_time"]
if !ok {
return nil, errors.New("install_time not found in response")
}
installTimeStr, ok := installTimeRaw.(string)
if !ok {
return nil, errors.New("could not convert install_time to a string")
}
installTime, err := time.Parse(time.RFC3339Nano, installTimeStr)
if err != nil {
return nil, err
}
result.InstallTime = installTime
return &result, err
}
type KeyStatus struct {

View File

@ -1,5 +1,7 @@
package api
import "context"
func (c *Sys) SealStatus() (*SealStatusResponse, error) {
r := c.c.NewRequest("GET", "/v1/sys/seal-status")
return sealStatusRequest(c, r)
@ -7,7 +9,10 @@ func (c *Sys) SealStatus() (*SealStatusResponse, error) {
func (c *Sys) Seal() error {
r := c.c.NewRequest("PUT", "/v1/sys/seal")
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err == nil {
defer resp.Body.Close()
}
@ -37,7 +42,9 @@ func (c *Sys) Unseal(shard string) (*SealStatusResponse, error) {
}
func sealStatusRequest(c *Sys, r *Request) (*SealStatusResponse, error) {
resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}

View File

@ -1,10 +1,15 @@
package api
import "context"
func (c *Sys) StepDown() error {
r := c.c.NewRequest("PUT", "/v1/sys/step-down")
resp, err := c.c.RawRequest(r)
if err == nil {
defer resp.Body.Close()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
return err
}

View File

@ -108,7 +108,7 @@ func TestAppRole_TidyDanglingAccessors_RaceTest(t *testing.T) {
}
count := 1
wg := sync.WaitGroup{}
wg := &sync.WaitGroup{}
now := time.Now()
started := false
for {
@ -132,7 +132,7 @@ func TestAppRole_TidyDanglingAccessors_RaceTest(t *testing.T) {
Path: "role/role1/secret-id",
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), roleSecretIDReq)
resp, err := b.HandleRequest(context.Background(), roleSecretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}

View File

@ -568,7 +568,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
if !strings.HasSuffix(principalARN, "*") {
principalID, err := b.resolveArnToUniqueIDFunc(ctx, req.Storage, principalARN)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("unable to resolve ARN %#v to internal ID: %#v", principalARN, err)), nil
return logical.ErrorResponse(fmt.Sprintf("unable to resolve ARN %#v to internal ID: %s", principalARN, err.Error())), nil
}
roleEntry.BoundIamPrincipalIDs = append(roleEntry.BoundIamPrincipalIDs, principalID)
}

View File

@ -3,6 +3,7 @@ package aws
import (
"context"
"strings"
"sync"
"time"
"github.com/hashicorp/vault/logical"
@ -34,10 +35,9 @@ func Backend() *backend {
Paths: []*framework.Path{
pathConfigRoot(),
pathConfigLease(&b),
pathRoles(),
pathRoles(&b),
pathListRoles(&b),
pathUser(&b),
pathSTS(&b),
},
Secrets: []*framework.Secret{
@ -54,6 +54,9 @@ func Backend() *backend {
type backend struct {
*framework.Backend
// Mutex to protect access to reading and writing policies
roleMutex sync.RWMutex
}
const backendHelp = `

View File

@ -1,18 +1,19 @@
package aws
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"os"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
@ -34,8 +35,8 @@ func TestBackend_basic(t *testing.T) {
Backend: getBackend(t),
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWritePolicy(t, "test", testPolicy),
testAccStepReadUser(t, "test"),
testAccStepWritePolicy(t, "test", testDynamoPolicy),
testAccStepRead(t, "creds", "test", []credentialTestFunc{listDynamoTablesTest}),
},
})
}
@ -56,12 +57,12 @@ func TestBackend_basicSTS(t *testing.T) {
Backend: getBackend(t),
Steps: []logicaltest.TestStep{
testAccStepConfigWithCreds(t, accessKey),
testAccStepWritePolicy(t, "test", testPolicy),
testAccStepReadSTS(t, "test"),
testAccStepWriteArnPolicyRef(t, "test", testPolicyArn),
testAccStepWritePolicy(t, "test", testDynamoPolicy),
testAccStepRead(t, "sts", "test", []credentialTestFunc{listDynamoTablesTest}),
testAccStepWriteArnPolicyRef(t, "test", ec2PolicyArn),
testAccStepReadSTSWithArnPolicy(t, "test"),
testAccStepWriteArnRoleRef(t, testRoleName),
testAccStepReadSTS(t, testRoleName),
testAccStepRead(t, "sts", testRoleName, []credentialTestFunc{describeInstancesTest}),
},
Teardown: func() error {
return teardown(accessKey)
@ -70,8 +71,8 @@ func TestBackend_basicSTS(t *testing.T) {
}
func TestBackend_policyCrud(t *testing.T) {
var compacted bytes.Buffer
if err := json.Compact(&compacted, []byte(testPolicy)); err != nil {
compacted, err := compactJSON(testDynamoPolicy)
if err != nil {
t.Fatalf("bad: %s", err)
}
@ -80,8 +81,8 @@ func TestBackend_policyCrud(t *testing.T) {
Backend: getBackend(t),
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWritePolicy(t, "test", testPolicy),
testAccStepReadPolicy(t, "test", compacted.String()),
testAccStepWritePolicy(t, "test", testDynamoPolicy),
testAccStepReadPolicy(t, "test", compacted),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", ""),
},
@ -162,7 +163,7 @@ func createRole(t *testing.T) {
}
attachment := &iam.AttachRolePolicyInput{
PolicyArn: aws.String(testPolicyArn),
PolicyArn: aws.String(ec2PolicyArn),
RoleName: aws.String(testRoleName), // Required
}
_, err = svc.AttachRolePolicy(attachment)
@ -254,7 +255,7 @@ func createUser(t *testing.T, accessKey *awsAccessKey) {
accessKey.SecretAccessKey = *genAccessKey.SecretAccessKey
}
func teardown(accessKey *awsAccessKey) error {
func deleteTestRole() error {
awsConfig := &aws.Config{
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
@ -262,7 +263,7 @@ func teardown(accessKey *awsAccessKey) error {
svc := iam.New(session.New(awsConfig))
attachment := &iam.DetachRolePolicyInput{
PolicyArn: aws.String(testPolicyArn),
PolicyArn: aws.String(ec2PolicyArn),
RoleName: aws.String(testRoleName), // Required
}
_, err := svc.DetachRolePolicy(attachment)
@ -282,12 +283,25 @@ func teardown(accessKey *awsAccessKey) error {
log.Printf("[WARN] AWS DeleteRole failed: %v", err)
return err
}
return nil
}
func teardown(accessKey *awsAccessKey) error {
if err := deleteTestRole(); err != nil {
return err
}
awsConfig := &aws.Config{
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := iam.New(session.New(awsConfig))
userDetachment := &iam.DetachUserPolicyInput{
PolicyArn: aws.String("arn:aws:iam::aws:policy/AdministratorAccess"),
UserName: aws.String(testUserName),
}
_, err = svc.DetachUserPolicy(userDetachment)
_, err := svc.DetachUserPolicy(userDetachment)
if err != nil {
log.Printf("[WARN] AWS DetachUserPolicy failed: %v", err)
return err
@ -354,51 +368,10 @@ func testAccStepConfigWithCreds(t *testing.T, accessKey *awsAccessKey) logicalte
}
}
func testAccStepReadUser(t *testing.T, name string) logicaltest.TestStep {
func testAccStepRead(t *testing.T, path, name string, credentialTests []credentialTestFunc) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "creds/" + name,
Check: func(resp *logical.Response) error {
var d struct {
AccessKey string `mapstructure:"access_key"`
SecretKey string `mapstructure:"secret_key"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)
// Build a client and verify that the credentials work
creds := credentials.NewStaticCredentials(d.AccessKey, d.SecretKey, "")
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := ec2.New(session.New(awsConfig))
log.Printf("[WARN] Verifying that the generated credentials work...")
retryCount := 0
success := false
var err error
for !success && retryCount < 10 {
_, err = client.DescribeInstances(&ec2.DescribeInstancesInput{})
if err == nil {
return nil
}
time.Sleep(time.Second)
retryCount++
}
return err
},
}
}
func testAccStepReadSTS(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "sts/" + name,
Path: path + "/" + name,
Check: func(resp *logical.Response) error {
var d struct {
AccessKey string `mapstructure:"access_key"`
@ -409,27 +382,101 @@ func testAccStepReadSTS(t *testing.T, name string) logicaltest.TestStep {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)
// Build a client and verify that the credentials work
creds := credentials.NewStaticCredentials(d.AccessKey, d.SecretKey, d.STSToken)
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
for _, test := range credentialTests {
err := test(d.AccessKey, d.SecretKey, d.STSToken)
if err != nil {
return err
}
}
client := ec2.New(session.New(awsConfig))
log.Printf("[WARN] Verifying that the generated credentials work...")
_, err := client.DescribeInstances(&ec2.DescribeInstancesInput{})
if err != nil {
return err
}
return nil
},
}
}
func describeInstancesTest(accessKey, secretKey, token string) error {
creds := credentials.NewStaticCredentials(accessKey, secretKey, token)
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := ec2.New(session.New(awsConfig))
log.Printf("[WARN] Verifying that the generated credentials work with ec2:DescribeInstances...")
return retryUntilSuccess(func() error {
_, err := client.DescribeInstances(&ec2.DescribeInstancesInput{})
return err
})
}
func describeAzsTestUnauthorized(accessKey, secretKey, token string) error {
creds := credentials.NewStaticCredentials(accessKey, secretKey, token)
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := ec2.New(session.New(awsConfig))
log.Printf("[WARN] Verifying that the generated credentials don't work with ec2:DescribeAvailabilityZones...")
return retryUntilSuccess(func() error {
_, err := client.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{})
// Need to make sure AWS authenticates the generated credentials but does not authorize the operation
if err == nil {
return fmt.Errorf("operation succeeded when expected failure")
}
if aerr, ok := err.(awserr.Error); ok {
if aerr.Code() == "UnauthorizedOperation" {
return nil
}
}
return err
})
}
func listIamUsersTest(accessKey, secretKey, token string) error {
creds := credentials.NewStaticCredentials(accessKey, secretKey, token)
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := iam.New(session.New(awsConfig))
log.Printf("[WARN] Verifying that the generated credentials work with iam:ListUsers...")
return retryUntilSuccess(func() error {
_, err := client.ListUsers(&iam.ListUsersInput{})
return err
})
}
func listDynamoTablesTest(accessKey, secretKey, token string) error {
creds := credentials.NewStaticCredentials(accessKey, secretKey, token)
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := dynamodb.New(session.New(awsConfig))
log.Printf("[WARN] Verifying that the generated credentials work with dynamodb:ListTables...")
return retryUntilSuccess(func() error {
_, err := client.ListTables(&dynamodb.ListTablesInput{})
return err
})
}
func retryUntilSuccess(op func() error) error {
retryCount := 0
success := false
var err error
for !success && retryCount < 10 {
err = op()
if err == nil {
return nil
}
time.Sleep(time.Second)
retryCount++
}
return err
}
func testAccStepReadSTSWithArnPolicy(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
@ -437,7 +484,7 @@ func testAccStepReadSTSWithArnPolicy(t *testing.T, name string) logicaltest.Test
ErrorOk: true,
Check: func(resp *logical.Response) error {
if resp.Data["error"] !=
"Can't generate STS credentials for a managed policy; use a role to assume or an inline policy instead" {
"attempted to retrieve iam_user credentials through the sts path; this is not allowed for legacy roles" {
t.Fatalf("bad: %v", resp)
}
return nil
@ -450,7 +497,7 @@ func testAccStepWritePolicy(t *testing.T, name string, policy string) logicaltes
Operation: logical.UpdateOperation,
Path: "roles/" + name,
Data: map[string]interface{}{
"policy": testPolicy,
"policy": policy,
},
}
}
@ -475,31 +522,28 @@ func testAccStepReadPolicy(t *testing.T, name string, value string) logicaltest.
return fmt.Errorf("bad: %#v", resp)
}
var d struct {
Policy string `mapstructure:"policy"`
expected := map[string]interface{}{
"policy_arns": []string(nil),
"role_arns": []string(nil),
"policy_document": value,
"credential_types": []string{iamUserCred, federationTokenCred},
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
if !reflect.DeepEqual(resp.Data, expected) {
return fmt.Errorf("bad: got: %#v\nexpected: %#v", resp.Data, expected)
}
if d.Policy != value {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
const testPolicy = `
{
const testDynamoPolicy = `{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "Stmt1426528957000",
"Effect": "Allow",
"Action": [
"ec2:*"
"dynamodb:List*"
],
"Resource": [
"*"
@ -509,14 +553,42 @@ const testPolicy = `
}
`
const testPolicyArn = "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess"
const ec2PolicyArn = "arn:aws:iam::aws:policy/AmazonEC2ReadOnlyAccess"
const iamPolicyArn = "arn:aws:iam::aws:policy/IAMReadOnlyAccess"
func testAccStepWriteRole(t *testing.T, name string, data map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/" + name,
Data: data,
}
}
func testAccStepReadRole(t *testing.T, name string, expected map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if expected == nil {
return nil
}
return fmt.Errorf("bad: nil response")
}
if !reflect.DeepEqual(resp.Data, expected) {
return fmt.Errorf("bad: got %#v\nexpected: %#v", resp.Data, expected)
}
return nil
},
}
}
func testAccStepWriteArnPolicyRef(t *testing.T, name string, arn string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/" + name,
Data: map[string]interface{}{
"arn": testPolicyArn,
"arn": ec2PolicyArn,
},
}
}
@ -528,20 +600,92 @@ func TestBackend_basicPolicyArnRef(t *testing.T) {
Backend: getBackend(t),
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWriteArnPolicyRef(t, "test", testPolicyArn),
testAccStepReadUser(t, "test"),
testAccStepWriteArnPolicyRef(t, "test", ec2PolicyArn),
testAccStepRead(t, "creds", "test", []credentialTestFunc{describeInstancesTest}),
},
})
}
func TestBackend_iamUserManagedInlinePolicies(t *testing.T) {
compacted, err := compactJSON(testDynamoPolicy)
if err != nil {
t.Fatalf("bad: %#v", err)
}
roleData := map[string]interface{}{
"policy_document": testDynamoPolicy,
"policy_arns": []string{ec2PolicyArn, iamPolicyArn},
"credential_type": iamUserCred,
}
expectedRoleData := map[string]interface{}{
"policy_document": compacted,
"policy_arns": []string{ec2PolicyArn, iamPolicyArn},
"credential_types": []string{iamUserCred},
"role_arns": []string(nil),
}
logicaltest.Test(t, logicaltest.TestCase{
AcceptanceTest: true,
PreCheck: func() { testAccPreCheck(t) },
Backend: getBackend(t),
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWriteRole(t, "test", roleData),
testAccStepReadRole(t, "test", expectedRoleData),
testAccStepRead(t, "creds", "test", []credentialTestFunc{describeInstancesTest, listIamUsersTest, listDynamoTablesTest}),
testAccStepRead(t, "sts", "test", []credentialTestFunc{describeInstancesTest, listIamUsersTest, listDynamoTablesTest}),
},
})
}
func TestBackend_AssumedRoleWithPolicyDoc(t *testing.T) {
// This looks a bit curious. The policy document and the role document act
// as a logical intersection of policies. The role allows ec2:Describe*
// (among other permissions). This policy allows everything BUT
// ec2:DescribeAvailabilityZones. Thus, the logical intersection of the two
// is all ec2:Describe* EXCEPT ec2:DescribeAvailabilityZones, and so the
// describeAZs call should fail
allowAllButDescribeAzs := `
{
"Version": "2012-10-17",
"Statement": [{
"Effect": "Allow",
"NotAction": "ec2:DescribeAvailabilityZones",
"Resource": "*"
}]
}
`
roleData := map[string]interface{}{
"policy_document": allowAllButDescribeAzs,
"role_arns": []string{fmt.Sprintf("arn:aws:iam::%s:role/%s", os.Getenv("AWS_ACCOUNT_ID"), testRoleName)},
"credential_type": assumedRoleCred,
}
logicaltest.Test(t, logicaltest.TestCase{
AcceptanceTest: true,
PreCheck: func() {
testAccPreCheck(t)
createRole(t)
// Sleep sometime because AWS is eventually consistent
log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...")
time.Sleep(10 * time.Second)
},
Backend: getBackend(t),
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWriteRole(t, "test", roleData),
testAccStepRead(t, "sts", "test", []credentialTestFunc{describeInstancesTest, describeAzsTestUnauthorized}),
testAccStepRead(t, "creds", "test", []credentialTestFunc{describeInstancesTest, describeAzsTestUnauthorized}),
},
Teardown: deleteTestRole,
})
}
func TestBackend_policyArnCrud(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
AcceptanceTest: true,
Backend: getBackend(t),
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWriteArnPolicyRef(t, "test", testPolicyArn),
testAccStepReadArnPolicy(t, "test", testPolicyArn),
testAccStepWriteArnPolicyRef(t, "test", ec2PolicyArn),
testAccStepReadArnPolicy(t, "test", ec2PolicyArn),
testAccStepDeletePolicy(t, "test"),
testAccStepReadArnPolicy(t, "test", ""),
},
@ -561,15 +705,14 @@ func testAccStepReadArnPolicy(t *testing.T, name string, value string) logicalte
return fmt.Errorf("bad: %#v", resp)
}
var d struct {
Policy string `mapstructure:"arn"`
expected := map[string]interface{}{
"policy_arns": []string{value},
"role_arns": []string(nil),
"policy_document": "",
"credential_types": []string{iamUserCred},
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.Policy != value {
return fmt.Errorf("bad: %#v", resp)
if !reflect.DeepEqual(resp.Data, expected) {
return fmt.Errorf("bad: got: %#v\nexpected: %#v", resp.Data, expected)
}
return nil
@ -591,3 +734,5 @@ type awsAccessKey struct {
AccessKeyId string
SecretAccessKey string
}
type credentialTestFunc func(string, string, string) error

View File

@ -4,11 +4,13 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"errors"
"fmt"
"strings"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -26,7 +28,7 @@ func pathListRoles(b *backend) *framework.Path {
}
}
func pathRoles() *framework.Path {
func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
@ -35,21 +37,46 @@ func pathRoles() *framework.Path {
Description: "Name of the policy",
},
"arn": &framework.FieldSchema{
"credential_type": &framework.FieldSchema{
Type: framework.TypeString,
Description: "ARN Reference to a managed policy",
Description: fmt.Sprintf("Type of credential to retrieve. Must be one of %s, %s, or %s", assumedRoleCred, iamUserCred, federationTokenCred),
},
"role_arns": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "ARNs of AWS roles allowed to be assumed. Only valid when credential_type is " + assumedRoleCred,
},
"policy_arns": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "ARNs of AWS policies to attach to IAM users. Only valid when credential_type is " + iamUserCred,
},
"policy_document": &framework.FieldSchema{
Type: framework.TypeString,
Description: `JSON-encoded IAM policy document. Behavior varies by credential_type. When credential_type is
iam_user, then it will attach the contents of the policy_document to the IAM
user generated. When credential_type is assumed_role or federation_token, this
will be passed in as the Policy parameter to the AssumeRole or
GetFederationToken API call, acting as a filter on permissions available.`,
},
"arn": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Deprecated; use role_arns or policy_arns instead. ARN Reference to a managed policy
or IAM role to assume`,
},
"policy": &framework.FieldSchema{
Type: framework.TypeString,
Description: "IAM policy document",
Description: "Deprecated; use policy_document instead. IAM policy document",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.DeleteOperation: pathRolesDelete,
logical.ReadOperation: pathRolesRead,
logical.UpdateOperation: pathRolesWrite,
logical.DeleteOperation: b.pathRolesDelete,
logical.ReadOperation: b.pathRolesRead,
logical.UpdateOperation: b.pathRolesWrite,
},
HelpSynopsis: pathRolesHelpSyn,
@ -58,24 +85,33 @@ func pathRoles() *framework.Path {
}
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List(ctx, "policy/")
b.roleMutex.RLock()
defer b.roleMutex.RUnlock()
entries, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete(ctx, "policy/"+d.Get("name").(string))
legacyEntries, err := req.Storage.List(ctx, "policy/")
if err != nil {
return nil, err
}
return logical.ListResponse(append(entries, legacyEntries...)), nil
}
func (b *backend) pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
for _, prefix := range []string{"policy/", "role/"} {
err := req.Storage.Delete(ctx, prefix+d.Get("name").(string))
if err != nil {
return nil, err
}
}
return nil, nil
}
func pathRolesRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get(ctx, "policy/"+d.Get("name").(string))
func (b *backend) pathRolesRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entry, err := b.roleRead(ctx, req.Storage, d.Get("name").(string), true)
if err != nil {
return nil, err
}
@ -83,69 +119,295 @@ func pathRolesRead(ctx context.Context, req *logical.Request, d *framework.Field
return nil, nil
}
val := string(entry.Value)
if strings.HasPrefix(val, "arn:") {
return &logical.Response{
Data: map[string]interface{}{
"arn": val,
},
}, nil
}
return &logical.Response{
Data: map[string]interface{}{
"policy": val,
},
Data: entry.toResponseData(),
}, nil
}
func useInlinePolicy(d *framework.FieldData) (bool, error) {
bp := d.Get("policy").(string) != ""
ba := d.Get("arn").(string) != ""
func legacyRoleData(d *framework.FieldData) (string, error) {
policy := d.Get("policy").(string)
arn := d.Get("arn").(string)
if !bp && !ba {
return false, errors.New("either policy or arn must be provided")
switch {
case policy == "" && arn == "":
return "", nil
case policy != "" && arn != "":
return "", errors.New("only one of policy or arn should be provided")
case policy != "":
return policy, nil
default:
return arn, nil
}
if bp && ba {
return false, errors.New("only one of policy or arn should be provided")
}
return bp, nil
}
func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
var buf bytes.Buffer
func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
var resp logical.Response
uip, err := useInlinePolicy(d)
roleName := d.Get("name").(string)
if roleName == "" {
return logical.ErrorResponse("missing role name"), nil
}
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
roleEntry, err := b.roleRead(ctx, req.Storage, roleName, false)
if err != nil {
return nil, err
}
if roleEntry == nil {
roleEntry = &awsRoleEntry{}
} else if roleEntry.InvalidData != "" {
resp.AddWarning(fmt.Sprintf("Invalid data of %q cleared out of role", roleEntry.InvalidData))
roleEntry.InvalidData = ""
}
legacyRole, err := legacyRoleData(d)
if err != nil {
return nil, err
}
if uip {
if err := json.Compact(&buf, []byte(d.Get("policy").(string))); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error compacting policy: %s", err)), nil
if credentialTypeRaw, ok := d.GetOk("credential_type"); ok {
if legacyRole != "" {
return logical.ErrorResponse("cannot supply deprecated role or policy parameters with an explicit credential_type"), nil
}
// Write the policy into storage
err := req.Storage.Put(ctx, &logical.StorageEntry{
Key: "policy/" + d.Get("name").(string),
Value: buf.Bytes(),
})
credentialType := credentialTypeRaw.(string)
allowedCredentialTypes := []string{iamUserCred, assumedRoleCred, federationTokenCred}
if !strutil.StrListContains(allowedCredentialTypes, credentialType) {
return logical.ErrorResponse(fmt.Sprintf("unrecognized credential_type: %q, not one of %#v", credentialType, allowedCredentialTypes)), nil
}
roleEntry.CredentialTypes = []string{credentialType}
}
if roleArnsRaw, ok := d.GetOk("role_arns"); ok {
if legacyRole != "" {
return logical.ErrorResponse("cannot supply deprecated role or policy parameters with role_arns"), nil
}
roleEntry.RoleArns = roleArnsRaw.([]string)
}
if policyArnsRaw, ok := d.GetOk("policy_arns"); ok {
if legacyRole != "" {
return logical.ErrorResponse("cannot supply deprecated role or policy parameters with policy_arns"), nil
}
roleEntry.PolicyArns = policyArnsRaw.([]string)
}
if policyDocumentRaw, ok := d.GetOk("policy_document"); ok {
if legacyRole != "" {
return logical.ErrorResponse("cannot supply deprecated role or policy parameters with policy_document"), nil
}
compacted := policyDocumentRaw.(string)
if len(compacted) > 0 {
compacted, err = compactJSON(policyDocumentRaw.(string))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("cannot parse policy document: %q", policyDocumentRaw.(string))), nil
}
}
roleEntry.PolicyDocument = compacted
}
if legacyRole != "" {
roleEntry = upgradeLegacyPolicyEntry(legacyRole)
if roleEntry.InvalidData != "" {
return logical.ErrorResponse(fmt.Sprintf("unable to parse supplied data: %q", roleEntry.InvalidData)), nil
}
resp.AddWarning("Detected use of legacy role or policy paramemter. Please upgrade to use the new parameters.")
} else {
roleEntry.ProhibitFlexibleCredPath = false
}
if len(roleEntry.CredentialTypes) == 0 {
return logical.ErrorResponse("did not supply credential_type"), nil
}
if len(roleEntry.RoleArns) > 0 && !strutil.StrListContains(roleEntry.CredentialTypes, assumedRoleCred) {
return logical.ErrorResponse(fmt.Sprintf("cannot supply role_arns when credential_type isn't %s", assumedRoleCred)), nil
}
if len(roleEntry.PolicyArns) > 0 && !strutil.StrListContains(roleEntry.CredentialTypes, iamUserCred) {
return logical.ErrorResponse(fmt.Sprintf("cannot supply policy_arns when credential_type isn't %s", iamUserCred)), nil
}
err = setAwsRole(ctx, req.Storage, roleName, roleEntry)
if err != nil {
return nil, err
}
return &resp, nil
}
func (b *backend) roleRead(ctx context.Context, s logical.Storage, roleName string, shouldLock bool) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}
if shouldLock {
b.roleMutex.RLock()
}
entry, err := s.Get(ctx, "role/"+roleName)
if shouldLock {
b.roleMutex.RUnlock()
}
if err != nil {
return nil, err
}
var roleEntry awsRoleEntry
if entry != nil {
if err := entry.DecodeJSON(&roleEntry); err != nil {
return nil, err
}
return &roleEntry, nil
}
if shouldLock {
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
}
entry, err = s.Get(ctx, "role/"+roleName)
if err != nil {
return nil, err
}
if entry != nil {
if err := entry.DecodeJSON(&roleEntry); err != nil {
return nil, err
}
return &roleEntry, nil
}
legacyEntry, err := s.Get(ctx, "policy/"+roleName)
if err != nil {
return nil, err
}
if legacyEntry == nil {
return nil, nil
}
newRoleEntry := upgradeLegacyPolicyEntry(string(legacyEntry.Value))
if b.System().LocalMount() || !b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary) {
err = setAwsRole(ctx, s, roleName, newRoleEntry)
if err != nil {
return nil, err
}
} else {
// Write the arn ref into storage
err := req.Storage.Put(ctx, &logical.StorageEntry{
Key: "policy/" + d.Get("name").(string),
Value: []byte(d.Get("arn").(string)),
})
// This can leave legacy data around in the policy/ path if it fails for some reason,
// but should be pretty rare for this to fail but prior writes to succeed, so not worrying
// about cleaning it up in case of error
err = s.Delete(ctx, "policy/"+roleName)
if err != nil {
return nil, err
}
}
return nil, nil
return newRoleEntry, nil
}
func upgradeLegacyPolicyEntry(entry string) *awsRoleEntry {
var newRoleEntry *awsRoleEntry
if strings.HasPrefix(entry, "arn:") {
parsedArn, err := arn.Parse(entry)
if err != nil {
newRoleEntry = &awsRoleEntry{
InvalidData: entry,
Version: 1,
}
return newRoleEntry
}
resourceParts := strings.Split(parsedArn.Resource, "/")
resourceType := resourceParts[0]
switch resourceType {
case "role":
newRoleEntry = &awsRoleEntry{
CredentialTypes: []string{assumedRoleCred},
RoleArns: []string{entry},
ProhibitFlexibleCredPath: true,
Version: 1,
}
case "policy":
newRoleEntry = &awsRoleEntry{
CredentialTypes: []string{iamUserCred},
PolicyArns: []string{entry},
ProhibitFlexibleCredPath: true,
Version: 1,
}
default:
newRoleEntry = &awsRoleEntry{
InvalidData: entry,
Version: 1,
}
}
} else {
compacted, err := compactJSON(entry)
if err != nil {
newRoleEntry = &awsRoleEntry{
InvalidData: entry,
Version: 1,
}
} else {
// unfortunately, this is ambiguous between the cred types, so allow both
newRoleEntry = &awsRoleEntry{
CredentialTypes: []string{iamUserCred, federationTokenCred},
PolicyDocument: compacted,
ProhibitFlexibleCredPath: true,
Version: 1,
}
}
}
return newRoleEntry
}
func setAwsRole(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("empty role name")
}
if roleEntry == nil {
return fmt.Errorf("nil roleEntry")
}
entry, err := logical.StorageEntryJSON("role/"+roleName, roleEntry)
if err != nil {
return err
}
if entry == nil {
return fmt.Errorf("nil result when writing to storage")
}
if err := s.Put(ctx, entry); err != nil {
return err
}
return nil
}
type awsRoleEntry struct {
CredentialTypes []string `json:"credential_types"` // Entries must all be in the set of ("iam_user", "assumed_role", "federation_token")
PolicyArns []string `json:"policy_arns"` // ARNs of managed policies to attach to an IAM user
RoleArns []string `json:"role_arns"` // ARNs of roles to assume for AssumedRole credentials
PolicyDocument string `json:"policy_document"` // JSON-serialized inline policy to attach to IAM users and/or to specify as the Policy parameter in AssumeRole calls
InvalidData string `json:"invalid_data,omitempty"` // Invalid role data. Exists to support converting the legacy role data into the new format
ProhibitFlexibleCredPath bool `json:"prohibit_flexible_cred_path,omitempty"` // Disallow accessing STS credentials via the creds path and vice verse
Version int `json:"version"` // Version number of the role format
}
func (r *awsRoleEntry) toResponseData() map[string]interface{} {
respData := map[string]interface{}{
"credential_types": r.CredentialTypes,
"policy_arns": r.PolicyArns,
"role_arns": r.RoleArns,
"policy_document": r.PolicyDocument,
}
if r.InvalidData != "" {
respData["invalid_data"] = r.InvalidData
}
return respData
}
func compactJSON(input string) (string, error) {
var compacted bytes.Buffer
err := json.Compact(&compacted, []byte(input))
return compacted.String(), err
}
const (
assumedRoleCred = "assumed_role"
iamUserCred = "iam_user"
federationTokenCred = "federation_token"
)
const pathListRolesHelpSyn = `List the existing roles in this backend`
const pathListRolesHelpDesc = `Roles will be listed by the role name.`

View File

@ -2,6 +2,7 @@ package aws
import (
"context"
"reflect"
"strconv"
"testing"
@ -20,7 +21,8 @@ func TestBackend_PathListRoles(t *testing.T) {
}
roleData := map[string]interface{}{
"arn": "testarn",
"role_arns": []string{"arn:aws:iam::123456789012:role/path/RoleName"},
"credential_type": assumedRoleCred,
}
roleReq := &logical.Request{
@ -63,3 +65,90 @@ func TestBackend_PathListRoles(t *testing.T) {
t.Fatalf("failed to list all 10 roles")
}
}
func TestUpgradeLegacyPolicyEntry(t *testing.T) {
var input string
var expected awsRoleEntry
var output *awsRoleEntry
input = "arn:aws:iam::123456789012:role/path/RoleName"
expected = awsRoleEntry{
CredentialTypes: []string{assumedRoleCred},
RoleArns: []string{input},
ProhibitFlexibleCredPath: true,
Version: 1,
}
output = upgradeLegacyPolicyEntry(input)
if output.InvalidData != "" {
t.Fatalf("bad: error processing upgrade of %q: got invalid data of %v", input, output.InvalidData)
}
if !reflect.DeepEqual(*output, expected) {
t.Fatalf("bad: expected %#v; received %#v", expected, *output)
}
input = "arn:aws:iam::123456789012:policy/MyPolicy"
expected = awsRoleEntry{
CredentialTypes: []string{iamUserCred},
PolicyArns: []string{input},
ProhibitFlexibleCredPath: true,
Version: 1,
}
output = upgradeLegacyPolicyEntry(input)
if output.InvalidData != "" {
t.Fatalf("bad: error processing upgrade of %q: got invalid data of %v", input, output.InvalidData)
}
if !reflect.DeepEqual(*output, expected) {
t.Fatalf("bad: expected %#v; received %#v", expected, *output)
}
input = "arn:aws:iam::aws:policy/AWSManagedPolicy"
expected.PolicyArns = []string{input}
output = upgradeLegacyPolicyEntry(input)
if output.InvalidData != "" {
t.Fatalf("bad: error processing upgrade of %q: got invalid data of %v", input, output.InvalidData)
}
if !reflect.DeepEqual(*output, expected) {
t.Fatalf("bad: expected %#v; received %#v", expected, *output)
}
input = `
{
"Version": "2012-10-07",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:Describe*",
"Resource": "*"
}
]
}`
compacted, err := compactJSON(input)
if err != nil {
t.Fatalf("error parsing JSON: %v", err)
}
expected = awsRoleEntry{
CredentialTypes: []string{iamUserCred, federationTokenCred},
PolicyDocument: compacted,
ProhibitFlexibleCredPath: true,
Version: 1,
}
output = upgradeLegacyPolicyEntry(input)
if output.InvalidData != "" {
t.Fatalf("bad: error processing upgrade of %q: got invalid data of %v", input, output.InvalidData)
}
if !reflect.DeepEqual(*output, expected) {
t.Fatalf("bad: expected %#v; received %#v", expected, *output)
}
// Due to lack of prior input validation, this could exist in the storage, and we need
// to be able to read it out in some fashion, so have to handle this in a poor fashion
input = "arn:gobbledygook"
expected = awsRoleEntry{
InvalidData: input,
Version: 1,
}
output = upgradeLegacyPolicyEntry(input)
if !reflect.DeepEqual(*output, expected) {
t.Fatalf("bad: expected %#v; received %#v", expected, *output)
}
}

View File

@ -1,95 +0,0 @@
package aws
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathSTS(b *backend) *framework.Path {
return &framework.Path{
Pattern: "sts/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role",
},
"ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Lifetime of the token in seconds.
AWS documentation excerpt: The duration, in seconds, that the credentials
should remain valid. Acceptable durations for IAM user sessions range from 900
seconds (15 minutes) to 129600 seconds (36 hours), with 43200 seconds (12
hours) as the default. Sessions for AWS account owners are restricted to a
maximum of 3600 seconds (one hour). If the duration is longer than one hour,
the session for AWS account owners defaults to one hour.`,
Default: 3600,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathSTSRead,
logical.UpdateOperation: b.pathSTSRead,
},
HelpSynopsis: pathSTSHelpSyn,
HelpDescription: pathSTSHelpDesc,
}
}
func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
policyName := d.Get("name").(string)
ttl := int64(d.Get("ttl").(int))
// Read the policy
policy, err := req.Storage.Get(ctx, "policy/"+policyName)
if err != nil {
return nil, errwrap.Wrapf("error retrieving role: {{err}}", err)
}
if policy == nil {
return logical.ErrorResponse(fmt.Sprintf(
"Role '%s' not found", policyName)), nil
}
policyValue := string(policy.Value)
if strings.HasPrefix(policyValue, "arn:") {
if strings.Contains(policyValue, ":role/") {
return b.assumeRole(
ctx,
req.Storage,
req.DisplayName, policyName, policyValue,
ttl,
)
} else {
return logical.ErrorResponse(
"Can't generate STS credentials for a managed policy; use a role to assume or an inline policy instead"),
logical.ErrInvalidRequest
}
}
// Use the helper to create the secret
return b.secretTokenCreate(
ctx,
req.Storage,
req.DisplayName, policyName, policyValue,
ttl,
)
}
const pathSTSHelpSyn = `
Generate an access key pair + security token for a specific role.
`
const pathSTSHelpDesc = `
This path will generate a new, never before used key pair + security token for
accessing AWS. The IAM policy used to back this key pair will be
the "name" parameter. For example, if this backend is mounted at "aws",
then "aws/sts/deploy" would generate access keys for the "deploy" role.
Note, these credentials are instantiated using the AWS STS backend.
The access keys will have a lease associated with them, but revoking the lease
does not revoke the access keys.
`

View File

@ -3,10 +3,12 @@ package aws
import (
"context"
"fmt"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/mitchellh/mapstructure"
@ -14,16 +16,26 @@ import (
func pathUser(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Pattern: "(creds|sts)/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role",
},
"role_arn": &framework.FieldSchema{
Type: framework.TypeString,
Description: "ARN of role to assume when credential_type is " + assumedRoleCred,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: "Lifetime of the returned credentials in seconds",
Default: 3600,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathUserRead,
logical.ReadOperation: b.pathCredsRead,
logical.UpdateOperation: b.pathCredsRead,
},
HelpSynopsis: pathUserHelpSyn,
@ -31,22 +43,72 @@ func pathUser(b *backend) *framework.Path {
}
}
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
policyName := d.Get("name").(string)
func (b *backend) pathCredsRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
roleName := d.Get("name").(string)
// Read the policy
policy, err := req.Storage.Get(ctx, "policy/"+policyName)
role, err := b.roleRead(ctx, req.Storage, roleName, true)
if err != nil {
return nil, errwrap.Wrapf("error retrieving role: {{err}}", err)
}
if policy == nil {
if role == nil {
return logical.ErrorResponse(fmt.Sprintf(
"Role '%s' not found", policyName)), nil
"Role '%s' not found", roleName)), nil
}
// Use the helper to create the secret
return b.secretAccessKeysCreate(
ctx, req.Storage, req.DisplayName, policyName, string(policy.Value))
ttl := int64(d.Get("ttl").(int))
roleArn := d.Get("role_arn").(string)
var credentialType string
switch {
case len(role.CredentialTypes) == 1:
credentialType = role.CredentialTypes[0]
// There is only one way for the CredentialTypes to contain more than one entry, and that's an upgrade path
// where it contains iamUserCred and federationTokenCred
// This ambiguity can be resolved based on req.Path, so resolve it assuming CredentialTypes only has those values
case len(role.CredentialTypes) > 1:
if strings.HasPrefix(req.Path, "creds") {
credentialType = iamUserCred
} else {
credentialType = federationTokenCred
}
// sanity check on the assumption above
if !strutil.StrListContains(role.CredentialTypes, credentialType) {
return logical.ErrorResponse(fmt.Sprintf("requested credential type %q not in allowed credential types %#v", credentialType, role.CredentialTypes)), nil
}
}
// creds requested through the sts path shouldn't be allowed to get iamUserCred type creds
// when the role is created from legacy data because they might have more privileges in AWS.
// See https://github.com/hashicorp/vault/issues/4229#issuecomment-380316788 for details.
if role.ProhibitFlexibleCredPath {
if credentialType == iamUserCred && strings.HasPrefix(req.Path, "sts") {
return logical.ErrorResponse(fmt.Sprintf("attempted to retrieve %s credentials through the sts path; this is not allowed for legacy roles", iamUserCred)), nil
}
if credentialType != iamUserCred && strings.HasPrefix(req.Path, "creds") {
return logical.ErrorResponse(fmt.Sprintf("attempted to retrieve %s credentials through the creds path; this is not allowed for legacy roles", credentialType)), nil
}
}
switch credentialType {
case iamUserCred:
return b.secretAccessKeysCreate(ctx, req.Storage, req.DisplayName, roleName, role)
case assumedRoleCred:
switch {
case roleArn == "":
if len(role.RoleArns) != 1 {
return logical.ErrorResponse("did not supply a role_arn parameter and unable to determine one"), nil
}
roleArn = role.RoleArns[0]
case !strutil.StrListContains(role.RoleArns, roleArn):
return logical.ErrorResponse(fmt.Sprintf("role_arn %q not in allowed role arns for Vault role %q", roleArn, roleName)), nil
}
return b.assumeRole(ctx, req.Storage, req.DisplayName, roleName, roleArn, role.PolicyDocument, ttl)
case federationTokenCred:
return b.secretTokenCreate(ctx, req.Storage, req.DisplayName, roleName, role.PolicyDocument, ttl)
default:
return logical.ErrorResponse(fmt.Sprintf("unknown credential_type: %q", credentialType)), nil
}
}
func pathUserRollback(ctx context.Context, req *logical.Request, _kind string, data interface{}) error {
@ -161,15 +223,17 @@ type walUser struct {
}
const pathUserHelpSyn = `
Generate an access key pair for a specific role.
Generate AWS credentials from a specific Vault role.
`
const pathUserHelpDesc = `
This path will generate a new, never before used key pair for
This path will generate new, never before used AWS credentials for
accessing AWS. The IAM policy used to back this key pair will be
the "name" parameter. For example, if this backend is mounted at "aws",
then "aws/creds/deploy" would generate access keys for the "deploy" role.
The access keys will have a lease associated with them. The access keys
can be revoked by using the lease ID.
can be revoked by using the lease ID when using the iam_user credential type.
When using AWS STS credential types (assumed_role or federation_token),
revoking the lease does not revoke the access keys.
`

View File

@ -7,8 +7,6 @@ import (
"regexp"
"time"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
@ -112,21 +110,24 @@ func (b *backend) secretTokenCreate(ctx context.Context, s logical.Storage,
}
func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
displayName, policyName, policy string,
displayName, roleName, roleArn, policy string,
lifeTimeInSeconds int64) (*logical.Response, error) {
STSClient, err := clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
username, usernameWarning := genUsername(displayName, policyName, "iam_user")
username, usernameWarning := genUsername(displayName, roleName, "iam_user")
tokenResp, err := STSClient.AssumeRole(
&sts.AssumeRoleInput{
RoleSessionName: aws.String(username),
RoleArn: aws.String(policy),
DurationSeconds: &lifeTimeInSeconds,
})
assumeRoleInput := &sts.AssumeRoleInput{
RoleSessionName: aws.String(username),
RoleArn: aws.String(roleArn),
DurationSeconds: &lifeTimeInSeconds,
}
if policy != "" {
assumeRoleInput.SetPolicy(policy)
}
tokenResp, err := STSClient.AssumeRole(assumeRoleInput)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
@ -139,7 +140,7 @@ func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
"security_token": *tokenResp.Credentials.SessionToken,
}, map[string]interface{}{
"username": username,
"policy": policy,
"policy": roleArn,
"is_sts": true,
})
@ -159,7 +160,7 @@ func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
func (b *backend) secretAccessKeysCreate(
ctx context.Context,
s logical.Storage,
displayName, policyName string, policy string) (*logical.Response, error) {
displayName, policyName string, role *awsRoleEntry) (*logical.Response, error) {
client, err := clientIAM(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
@ -187,23 +188,24 @@ func (b *backend) secretAccessKeysCreate(
"Error creating IAM user: %s", err)), nil
}
if strings.HasPrefix(policy, "arn:") {
for _, arn := range role.PolicyArns {
// Attach existing policy against user
_, err = client.AttachUserPolicy(&iam.AttachUserPolicyInput{
UserName: aws.String(username),
PolicyArn: aws.String(policy),
PolicyArn: aws.String(arn),
})
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error attaching user policy: %s", err)), nil
}
} else {
}
if role.PolicyDocument != "" {
// Add new inline user policy against user
_, err = client.PutUserPolicy(&iam.PutUserPolicyInput{
UserName: aws.String(username),
PolicyName: aws.String(policyName),
PolicyDocument: aws.String(policy),
PolicyDocument: aws.String(role.PolicyDocument),
})
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
@ -234,7 +236,7 @@ func (b *backend) secretAccessKeysCreate(
"security_token": nil,
}, map[string]interface{}{
"username": username,
"policy": policy,
"policy": role,
"is_sts": false,
})

View File

@ -77,6 +77,7 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac
return nil
}); err != nil {
cleanup()
t.Fatalf("Could not connect to PostgreSQL docker container: %s", err)
}
@ -899,13 +900,11 @@ func TestBackend_roleCrud(t *testing.T) {
}
}
// Test role modification
// Test role modification of TTL
{
data = map[string]interface{}{
"name": "plugin-role-test",
"rollback_statements": testRole,
"renew_statements": defaultRevocationSQL,
"max_ttl": "7m",
"name": "plugin-role-test",
"max_ttl": "7m",
}
req = &logical.Request{
Operation: logical.UpdateOperation,
@ -944,9 +943,7 @@ func TestBackend_roleCrud(t *testing.T) {
expected := dbplugin.Statements{
Creation: []string{strings.TrimSpace(testRole)},
Rollback: []string{strings.TrimSpace(testRole)},
Revocation: []string{strings.TrimSpace(defaultRevocationSQL)},
Renewal: []string{strings.TrimSpace(defaultRevocationSQL)},
}
actual := dbplugin.Statements{
@ -972,6 +969,79 @@ func TestBackend_roleCrud(t *testing.T) {
}
// Test role modification of statements
{
data = map[string]interface{}{
"name": "plugin-role-test",
"creation_statements": []string{testRole, testRole},
"revocation_statements": []string{defaultRevocationSQL, defaultRevocationSQL},
"rollback_statements": testRole,
"renew_statements": defaultRevocationSQL,
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}
exists, err := b.pathRoleExistenceCheck()(context.Background(), req, &framework.FieldData{
Raw: data,
Schema: pathRoles(b).Fields,
})
if err != nil {
t.Fatal(err)
}
if !exists {
t.Fatal("expected exists")
}
// Read the role
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
expected := dbplugin.Statements{
Creation: []string{strings.TrimSpace(testRole), strings.TrimSpace(testRole)},
Rollback: []string{strings.TrimSpace(testRole)},
Revocation: []string{strings.TrimSpace(defaultRevocationSQL), strings.TrimSpace(defaultRevocationSQL)},
Renewal: []string{strings.TrimSpace(defaultRevocationSQL)},
}
actual := dbplugin.Statements{
Creation: resp.Data["creation_statements"].([]string),
Revocation: resp.Data["revocation_statements"].([]string),
Rollback: resp.Data["rollback_statements"].([]string),
Renewal: resp.Data["renew_statements"].([]string),
}
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("Statements did not match, expected %#v, got %#v", expected, actual)
}
if diff := deep.Equal(resp.Data["db_name"], "plugin-test"); diff != nil {
t.Fatal(diff)
}
if diff := deep.Equal(resp.Data["default_ttl"], float64(300)); diff != nil {
t.Fatal(diff)
}
if diff := deep.Equal(resp.Data["max_ttl"], float64(420)); diff != nil {
t.Fatal(diff)
}
}
// Delete the role
data = map[string]interface{}{}
req = &logical.Request{

View File

@ -210,6 +210,12 @@ func (b *databaseBackend) pathRoleCreateUpdate() framework.OperationFunc {
} else if req.Operation == logical.CreateOperation {
role.Statements.Renewal = data.Get("renew_statements").([]string)
}
// Do not persist deprecated statements that are populated on role read
role.Statements.CreationStatements = ""
role.Statements.RevocationStatements = ""
role.Statements.RenewStatements = ""
role.Statements.RollbackStatements = ""
}
// Store it

View File

@ -9,71 +9,57 @@ import (
"os"
"path"
"reflect"
"sync"
"testing"
"time"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/lib/pq"
"github.com/mitchellh/mapstructure"
dockertest "gopkg.in/ory-am/dockertest.v2"
"github.com/ory/dockertest"
)
var (
testImagePull sync.Once
)
func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) {
func prepareTestContainer(t *testing.T) (cleanup func(), retURL string) {
if os.Getenv("PG_URL") != "" {
return "", os.Getenv("PG_URL")
return func() {}, os.Getenv("PG_URL")
}
// Without this the checks for whether the container has started seem to
// never actually pass. There's really no reason to expose the test
// containers, so don't.
dockertest.BindDockerToLocalhost = "yep"
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
}
testImagePull.Do(func() {
dockertest.Pull("postgres")
})
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"})
if err != nil {
t.Fatalf("Could not start local PostgreSQL docker container: %s", err)
}
cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool {
// This will cause a validation to run
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Storage: s,
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: map[string]interface{}{
"connection_url": connURL,
},
})
if err != nil || (resp != nil && resp.IsError()) {
// It's likely not up and running yet, so return false and try again
return false
cleanup = func() {
err := pool.Purge(resource)
if err != nil {
t.Fatalf("Failed to cleanup local container: %s", err)
}
if resp == nil {
t.Fatal("expected warning")
}
retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp"))
// exponential backoff-retry
if err = pool.Retry(func() error {
var err error
var db *sql.DB
db, err = sql.Open("postgres", retURL)
if err != nil {
return err
}
retURL = connURL
return true
})
if connErr != nil {
t.Fatalf("could not connect to database: %v", connErr)
defer db.Close()
return db.Ping()
}); err != nil {
cleanup()
t.Fatalf("Could not connect to PostgreSQL docker container: %s", err)
}
return
}
func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
err := cid.KillRemove()
if err != nil {
t.Fatal(err)
}
}
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
@ -123,14 +109,12 @@ func TestBackend_basic(t *testing.T) {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
cleanup, connURL := prepareTestContainer(t)
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
@ -149,14 +133,12 @@ func TestBackend_roleCrud(t *testing.T) {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
cleanup, connURL := prepareTestContainer(t)
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
@ -177,14 +159,12 @@ func TestBackend_BlockStatements(t *testing.T) {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
cleanup, connURL := prepareTestContainer(t)
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice)
if err != nil {
t.Fatal(err)
@ -209,14 +189,12 @@ func TestBackend_roleReadOnly(t *testing.T) {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
cleanup, connURL := prepareTestContainer(t)
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
@ -242,14 +220,12 @@ func TestBackend_roleReadOnly_revocationSQL(t *testing.T) {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
cleanup, connURL := prepareTestContainer(t)
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{

390
command/agent.go Normal file
View File

@ -0,0 +1,390 @@
package command
import (
"context"
"fmt"
"io"
"os"
"sort"
"strings"
"sync"
"github.com/kr/pretty"
"github.com/mitchellh/cli"
"github.com/posener/complete"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/command/agent/auth"
"github.com/hashicorp/vault/command/agent/auth/aws"
"github.com/hashicorp/vault/command/agent/auth/azure"
"github.com/hashicorp/vault/command/agent/auth/gcp"
"github.com/hashicorp/vault/command/agent/auth/jwt"
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
"github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/helper/gated-writer"
"github.com/hashicorp/vault/helper/logging"
"github.com/hashicorp/vault/version"
)
var _ cli.Command = (*AgentCommand)(nil)
var _ cli.CommandAutocomplete = (*AgentCommand)(nil)
type AgentCommand struct {
*BaseCommand
ShutdownCh chan struct{}
SighupCh chan struct{}
logWriter io.Writer
logGate *gatedwriter.Writer
logger log.Logger
cleanupGuard sync.Once
startedCh chan (struct{}) // for tests
flagConfigs []string
flagLogLevel string
flagTestVerifyOnly bool
flagCombineLogs bool
}
func (c *AgentCommand) Synopsis() string {
return "Start a Vault agent"
}
func (c *AgentCommand) Help() string {
helpText := `
Usage: vault agent [options]
This command starts a Vault agent that can perform automatic authentication
in certain environments.
Start an agent with a configuration file:
$ vault agent -config=/etc/vault/config.hcl
For a full list of examples, please see the documentation.
` + c.Flags().Help()
return strings.TrimSpace(helpText)
}
func (c *AgentCommand) Flags() *FlagSets {
set := c.flagSet(FlagSetHTTP)
f := set.NewFlagSet("Command Options")
f.StringSliceVar(&StringSliceVar{
Name: "config",
Target: &c.flagConfigs,
Completion: complete.PredictOr(
complete.PredictFiles("*.hcl"),
complete.PredictFiles("*.json"),
),
Usage: "Path to a configuration file. This configuration file should " +
"contain only agent directives.",
})
f.StringVar(&StringVar{
Name: "log-level",
Target: &c.flagLogLevel,
Default: "info",
EnvVar: "VAULT_LOG_LEVEL",
Completion: complete.PredictSet("trace", "debug", "info", "warn", "err"),
Usage: "Log verbosity level. Supported values (in order of detail) are " +
"\"trace\", \"debug\", \"info\", \"warn\", and \"err\".",
})
// Internal-only flags to follow.
//
// Why hello there little source code reader! Welcome to the Vault source
// code. The remaining options are intentionally undocumented and come with
// no warranty or backwards-compatability promise. Do not use these flags
// in production. Do not build automation using these flags. Unless you are
// developing against Vault, you should not need any of these flags.
// TODO: should the below flags be public?
f.BoolVar(&BoolVar{
Name: "combine-logs",
Target: &c.flagCombineLogs,
Default: false,
Hidden: true,
})
f.BoolVar(&BoolVar{
Name: "test-verify-only",
Target: &c.flagTestVerifyOnly,
Default: false,
Hidden: true,
})
// End internal-only flags.
return set
}
func (c *AgentCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictNothing
}
func (c *AgentCommand) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}
func (c *AgentCommand) Run(args []string) int {
f := c.Flags()
if err := f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}
// Create a logger. We wrap it in a gated writer so that it doesn't
// start logging too early.
c.logGate = &gatedwriter.Writer{Writer: os.Stderr}
c.logWriter = c.logGate
if c.flagCombineLogs {
c.logWriter = os.Stdout
}
var level log.Level
c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel))
switch c.flagLogLevel {
case "trace":
level = log.Trace
case "debug":
level = log.Debug
case "notice", "info", "":
level = log.Info
case "warn", "warning":
level = log.Warn
case "err", "error":
level = log.Error
default:
c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel))
return 1
}
if c.logger == nil {
c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level)
}
// Validation
if len(c.flagConfigs) != 1 {
c.UI.Error("Must specify exactly one config path using -config")
return 1
}
// Load the configuration
config, err := config.LoadConfig(c.flagConfigs[0], c.logger)
if err != nil {
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
return 1
}
// Ensure at least one config was found.
if config == nil {
c.UI.Output(wrapAtLength(
"No configuration read. Please provide the configuration with the " +
"-config flag."))
return 1
}
if config.AutoAuth == nil {
c.UI.Error("No auto_auth block found in config file")
return 1
}
infoKeys := make([]string, 0, 10)
info := make(map[string]string)
info["log level"] = c.flagLogLevel
infoKeys = append(infoKeys, "log level")
infoKeys = append(infoKeys, "version")
verInfo := version.GetVersion()
info["version"] = verInfo.FullVersionNumber(false)
if verInfo.Revision != "" {
info["version sha"] = strings.Trim(verInfo.Revision, "'")
infoKeys = append(infoKeys, "version sha")
}
infoKeys = append(infoKeys, "cgo")
info["cgo"] = "disabled"
if version.CgoEnabled {
info["cgo"] = "enabled"
}
// Server configuration output
padding := 24
sort.Strings(infoKeys)
c.UI.Output("==> Vault agent configuration:\n")
for _, k := range infoKeys {
c.UI.Output(fmt.Sprintf(
"%s%s: %s",
strings.Repeat(" ", padding-len(k)),
strings.Title(k),
info[k]))
}
c.UI.Output("")
// Tests might not want to start a vault server and just want to verify
// the configuration.
if c.flagTestVerifyOnly {
if os.Getenv("VAULT_TEST_VERIFY_ONLY_DUMP_CONFIG") != "" {
c.UI.Output(fmt.Sprintf(
"\nConfiguration:\n%s\n",
pretty.Sprint(*config)))
}
return 0
}
client, err := c.Client()
if err != nil {
c.UI.Error(fmt.Sprintf(
"Error fetching client: %v",
err))
return 1
}
ctx, cancelFunc := context.WithCancel(context.Background())
var sinks []*sink.SinkConfig
for _, sc := range config.AutoAuth.Sinks {
switch sc.Type {
case "file":
config := &sink.SinkConfig{
Logger: c.logger.Named("sink.file"),
Config: sc.Config,
Client: client,
WrapTTL: sc.WrapTTL,
DHType: sc.DHType,
DHPath: sc.DHPath,
AAD: sc.AAD,
}
s, err := file.NewFileSink(config)
if err != nil {
c.UI.Error(errwrap.Wrapf("Error creating file sink: {{err}}", err).Error())
return 1
}
config.Sink = s
sinks = append(sinks, config)
default:
c.UI.Error(fmt.Sprintf("Unknown sink type %q", sc.Type))
return 1
}
}
var method auth.AuthMethod
authConfig := &auth.AuthConfig{
Logger: c.logger.Named(fmt.Sprintf("auth.%s", config.AutoAuth.Method.Type)),
MountPath: config.AutoAuth.Method.MountPath,
WrapTTL: config.AutoAuth.Method.WrapTTL,
Config: config.AutoAuth.Method.Config,
}
switch config.AutoAuth.Method.Type {
case "aws":
method, err = aws.NewAWSAuthMethod(authConfig)
case "azure":
method, err = azure.NewAzureAuthMethod(authConfig)
case "gcp":
method, err = gcp.NewGCPAuthMethod(authConfig)
case "jwt":
method, err = jwt.NewJWTAuthMethod(authConfig)
case "kubernetes":
method, err = kubernetes.NewKubernetesAuthMethod(authConfig)
default:
c.UI.Error(fmt.Sprintf("Unknown auth method %q", config.AutoAuth.Method.Type))
return 1
}
if err != nil {
c.UI.Error(errwrap.Wrapf(fmt.Sprintf("Error creating %s auth method: {{err}}", config.AutoAuth.Method.Type), err).Error())
return 1
}
// Output the header that the server has started
if !c.flagCombineLogs {
c.UI.Output("==> Vault server started! Log data will stream in below:\n")
}
// Inform any tests that the server is ready
select {
case c.startedCh <- struct{}{}:
default:
}
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: c.logger.Named("sink.server"),
Client: client,
ExitAfterAuth: config.ExitAfterAuth,
})
ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{
Logger: c.logger.Named("auth.handler"),
Client: c.client,
})
// Start things running
go ah.Run(ctx, method)
go ss.Run(ctx, ah.OutputCh, sinks)
// Release the log gate.
c.logGate.Flush()
// Write out the PID to the file now that server has successfully started
if err := c.storePidFile(config.PidFile); err != nil {
c.UI.Error(fmt.Sprintf("Error storing PID: %s", err))
return 1
}
defer func() {
if err := c.removePidFile(config.PidFile); err != nil {
c.UI.Error(fmt.Sprintf("Error deleting the PID file: %s", err))
}
}()
select {
case <-ss.DoneCh:
// This will happen if we exit-on-auth
c.logger.Info("sinks finished, exiting")
case <-c.ShutdownCh:
c.UI.Output("==> Vault agent shutdown triggered")
cancelFunc()
<-ah.DoneCh
<-ss.DoneCh
}
return 0
}
// storePidFile is used to write out our PID to a file if necessary
func (c *AgentCommand) storePidFile(pidPath string) error {
// Quit fast if no pidfile
if pidPath == "" {
return nil
}
// Open the PID file
pidFile, err := os.OpenFile(pidPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return errwrap.Wrapf("could not open pid file: {{err}}", err)
}
defer pidFile.Close()
// Write out the PID
pid := os.Getpid()
_, err = pidFile.WriteString(fmt.Sprintf("%d", pid))
if err != nil {
return errwrap.Wrapf("could not write to pid file: {{err}}", err)
}
return nil
}
// removePidFile is used to cleanup the PID file if necessary
func (c *AgentCommand) removePidFile(pidPath string) error {
if pidPath == "" {
return nil
}
return os.Remove(pidPath)
}

215
command/agent/auth/auth.go Normal file
View File

@ -0,0 +1,215 @@
package auth
import (
"context"
"math/rand"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/jsonutil"
)
type AuthMethod interface {
Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error)
NewCreds() chan struct{}
CredSuccess()
Shutdown()
}
type AuthConfig struct {
Logger hclog.Logger
MountPath string
WrapTTL time.Duration
Config map[string]interface{}
}
// AuthHandler is responsible for keeping a token alive and renewed and passing
// new tokens to the sink server
type AuthHandler struct {
DoneCh chan struct{}
OutputCh chan string
logger hclog.Logger
client *api.Client
random *rand.Rand
wrapTTL time.Duration
}
type AuthHandlerConfig struct {
Logger hclog.Logger
Client *api.Client
WrapTTL time.Duration
}
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
ah := &AuthHandler{
DoneCh: make(chan struct{}),
OutputCh: make(chan string),
logger: conf.Logger,
client: conf.Client,
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
wrapTTL: conf.WrapTTL,
}
return ah
}
func backoffOrQuit(ctx context.Context, backoff time.Duration) {
select {
case <-time.After(backoff):
case <-ctx.Done():
}
}
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
if am == nil {
panic("nil auth method")
}
ah.logger.Info("starting auth handler")
defer func() {
ah.logger.Info("auth handler stopped")
close(ah.DoneCh)
}()
credCh := am.NewCreds()
if credCh == nil {
credCh = make(chan struct{})
}
var renewer *api.Renewer
for {
select {
case <-ctx.Done():
am.Shutdown()
return
default:
}
// Create a fresh backoff value
backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second))
ah.logger.Info("authenticating")
path, data, err := am.Authenticate(ctx, ah.client)
if err != nil {
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
clientToUse := ah.client
if ah.wrapTTL > 0 {
wrapClient, err := ah.client.Clone()
if err != nil {
ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
wrapClient.SetWrappingLookupFunc(func(string, string) string {
return ah.wrapTTL.String()
})
clientToUse = wrapClient
}
secret, err := clientToUse.Logical().Write(path, data)
// Check errors/sanity
if err != nil {
ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
switch {
case ah.wrapTTL > 0:
if secret.WrapInfo == nil {
ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
if secret.WrapInfo.Token == "" {
ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
if err != nil {
ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
ah.OutputCh <- string(wrappedResp)
am.CredSuccess()
select {
case <-ctx.Done():
ah.logger.Info("shutdown triggered")
return
case <-credCh:
ah.logger.Info("auth method found new credentials, re-authenticating")
continue
}
default:
if secret.Auth == nil {
ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
if secret.Auth.ClientToken == "" {
ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
ah.logger.Info("authentication successful, sending token to sinks")
ah.OutputCh <- secret.Auth.ClientToken
am.CredSuccess()
}
if renewer != nil {
renewer.Stop()
}
renewer, err = ah.client.NewRenewer(&api.RenewerInput{
Secret: secret,
})
if err != nil {
ah.logger.Error("error creating renewer, backing off and retrying", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
// Start the renewal process
ah.logger.Info("starting renewal process")
go renewer.Renew()
RenewerLoop:
for {
select {
case <-ctx.Done():
ah.logger.Info("shutdown triggered, stopping renewer")
renewer.Stop()
break RenewerLoop
case err := <-renewer.DoneCh():
ah.logger.Info("renewer done channel triggered")
if err != nil {
ah.logger.Error("error renewing token", "error", err)
}
break RenewerLoop
case <-renewer.RenewCh():
ah.logger.Info("renewed auth token")
case <-credCh:
ah.logger.Info("auth method found new credentials, re-authenticating")
break RenewerLoop
}
}
}
}

View File

@ -0,0 +1,100 @@
package auth
import (
"context"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/helper/logging"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
type userpassTestMethod struct{}
func newUserpassTestMethod(t *testing.T, client *api.Client) AuthMethod {
err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
Type: "userpass",
Config: api.AuthConfigInput{
DefaultLeaseTTL: "1s",
MaxLeaseTTL: "3s",
},
})
if err != nil {
t.Fatal(err)
}
return &userpassTestMethod{}
}
func (u *userpassTestMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) {
_, err := client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
"password": "bar",
})
if err != nil {
return "", nil, err
}
return "auth/userpass/login/foo", map[string]interface{}{
"password": "bar",
}, nil
}
func (u *userpassTestMethod) NewCreds() chan struct{} {
return nil
}
func (u *userpassTestMethod) CredSuccess() {
}
func (u *userpassTestMethod) Shutdown() {
}
func TestAuthHandler(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"userpass": userpass.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
ctx, cancelFunc := context.WithCancel(context.Background())
ah := NewAuthHandler(&AuthHandlerConfig{
Logger: logger.Named("auth.handler"),
Client: client,
})
am := newUserpassTestMethod(t, client)
go ah.Run(ctx, am)
// Consume tokens so we don't block
stopTime := time.Now().Add(5 * time.Second)
closed := false
consumption:
for {
select {
case <-ah.OutputCh:
// Nothing
case <-time.After(stopTime.Sub(time.Now())):
if !closed {
cancelFunc()
closed = true
}
case <-ah.DoneCh:
break consumption
}
}
}

View File

@ -0,0 +1,207 @@
package aws
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"net/http"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
awsauth "github.com/hashicorp/vault/builtin/credential/aws"
"github.com/hashicorp/vault/command/agent/auth"
)
const (
typeEC2 = "ec2"
typeIAM = "iam"
identityEndpoint = "http://169.254.169.254/latest/dynamic/instance-identity"
)
type awsMethod struct {
logger hclog.Logger
authType string
nonce string
mountPath string
role string
headerValue string
accessKey string
secretKey string
sessionToken string
}
func NewAWSAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if conf == nil {
return nil, errors.New("empty config")
}
if conf.Config == nil {
return nil, errors.New("empty config data")
}
a := &awsMethod{
logger: conf.Logger,
mountPath: conf.MountPath,
}
typeRaw, ok := conf.Config["type"]
if !ok {
return nil, errors.New("missing 'type' value")
}
a.authType, ok = typeRaw.(string)
if !ok {
return nil, errors.New("could not convert 'type' config value to string")
}
roleRaw, ok := conf.Config["role"]
if !ok {
return nil, errors.New("missing 'role' value")
}
a.role, ok = roleRaw.(string)
if !ok {
return nil, errors.New("could not convert 'role' config value to string")
}
switch {
case a.role == "":
return nil, errors.New("'role' value is empty")
case a.authType == "":
return nil, errors.New("'type' value is empty")
case a.authType != typeEC2 && a.authType != typeIAM:
return nil, errors.New("'type' value is invalid")
}
accessKeyRaw, ok := conf.Config["access_key"]
if ok {
a.accessKey, ok = accessKeyRaw.(string)
if !ok {
return nil, errors.New("could not convert 'access_key' value into string")
}
}
secretKeyRaw, ok := conf.Config["secret_key"]
if ok {
a.secretKey, ok = secretKeyRaw.(string)
if !ok {
return nil, errors.New("could not convert 'secret_key' value into string")
}
}
sessionTokenRaw, ok := conf.Config["session_token"]
if ok {
a.sessionToken, ok = sessionTokenRaw.(string)
if !ok {
return nil, errors.New("could not convert 'session_token' value into string")
}
}
headerValueRaw, ok := conf.Config["header_value"]
if ok {
a.headerValue, ok = headerValueRaw.(string)
if !ok {
return nil, errors.New("could not convert 'header_value' value into string")
}
}
return a, nil
}
func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retToken string, retData map[string]interface{}, retErr error) {
a.logger.Trace("beginning authentication")
data := make(map[string]interface{})
switch a.authType {
case typeEC2:
client := cleanhttp.DefaultClient()
// Fetch document
{
req, err := http.NewRequest("GET", fmt.Sprintf("%s/document", identityEndpoint), nil)
if err != nil {
retErr = errwrap.Wrapf("error creating request: {{err}}", err)
return
}
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
retErr = errwrap.Wrapf("error fetching instance document: {{err}}", err)
return
}
if resp == nil {
retErr = errors.New("empty response fetching instance document")
return
}
defer resp.Body.Close()
doc, err := ioutil.ReadAll(resp.Body)
if err != nil {
retErr = errwrap.Wrapf("error reading instance document response body: {{err}}", err)
return
}
data["identity"] = base64.StdEncoding.EncodeToString(doc)
}
// Fetch signature
{
req, err := http.NewRequest("GET", fmt.Sprintf("%s/signature", identityEndpoint), nil)
if err != nil {
retErr = errwrap.Wrapf("error creating request: {{err}}", err)
return
}
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
retErr = errwrap.Wrapf("error fetching instance document signature: {{err}}", err)
return
}
if resp == nil {
retErr = errors.New("empty response fetching instance document signature")
return
}
defer resp.Body.Close()
sig, err := ioutil.ReadAll(resp.Body)
if err != nil {
retErr = errwrap.Wrapf("error reading instance document signature response body: {{err}}", err)
return
}
data["signature"] = string(sig)
}
// Add the reauthentication value, if we have one
if a.nonce == "" {
uuid, err := uuid.GenerateUUID()
if err != nil {
retErr = errwrap.Wrapf("error generating uuid for reauthentication value: {{err}}", err)
return
}
a.nonce = uuid
}
data["nonce"] = a.nonce
default:
var err error
data, err = awsauth.GenerateLoginData(a.accessKey, a.secretKey, a.sessionToken, a.headerValue)
if err != nil {
retErr = errwrap.Wrapf("error creating login value: {{err}}", err)
return
}
}
data["role"] = a.role
return fmt.Sprintf("%s/login", a.mountPath), data, nil
}
func (a *awsMethod) NewCreds() chan struct{} {
return nil
}
func (a *awsMethod) CredSuccess() {
}
func (a *awsMethod) Shutdown() {
}

View File

@ -0,0 +1,180 @@
package azure
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/auth"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/useragent"
)
const (
instanceEndpoint = "http://169.254.169.254/metadata/instance"
identityEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
// minimum version 2018-02-01 needed for identity metadata
// regional availability: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service
apiVersion = "2018-02-01"
)
type azureMethod struct {
logger hclog.Logger
mountPath string
role string
resource string
}
func NewAzureAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if conf == nil {
return nil, errors.New("empty config")
}
if conf.Config == nil {
return nil, errors.New("empty config data")
}
a := &azureMethod{
logger: conf.Logger,
mountPath: conf.MountPath,
}
roleRaw, ok := conf.Config["role"]
if !ok {
return nil, errors.New("missing 'role' value")
}
a.role, ok = roleRaw.(string)
if !ok {
return nil, errors.New("could not convert 'role' config value to string")
}
resourceRaw, ok := conf.Config["resource"]
if !ok {
return nil, errors.New("missing 'resource' value")
}
a.resource, ok = resourceRaw.(string)
if !ok {
return nil, errors.New("could not convert 'resource' config value to string")
}
switch {
case a.role == "":
return nil, errors.New("'role' value is empty")
case a.resource == "":
return nil, errors.New("'resource' value is empty")
}
return a, nil
}
func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, retData map[string]interface{}, retErr error) {
a.logger.Trace("beginning authentication")
// Fetch instance data
var instance struct {
Compute struct {
Name string
ResourceGroupName string
SubscriptionID string
VMScaleSetName string
}
}
body, err := getMetadataInfo(ctx, instanceEndpoint, "")
if err != nil {
retErr = err
return
}
err = jsonutil.DecodeJSON(body, &instance)
if err != nil {
retErr = errwrap.Wrapf("error parsing instance metadata response: {{err}}", err)
return
}
// Fetch JWT
var identity struct {
AccessToken string `json:"access_token"`
}
body, err = getMetadataInfo(ctx, identityEndpoint, a.resource)
if err != nil {
retErr = err
return
}
err = jsonutil.DecodeJSON(body, &identity)
if err != nil {
retErr = errwrap.Wrapf("error parsing identity metadata response: {{err}}", err)
return
}
// Attempt login
data := map[string]interface{}{
"role": a.role,
"vm_name": instance.Compute.Name,
"vmss_name": instance.Compute.VMScaleSetName,
"resource_group_name": instance.Compute.ResourceGroupName,
"subscription_id": instance.Compute.SubscriptionID,
"jwt": identity.AccessToken,
}
return fmt.Sprintf("%s/login", a.mountPath), data, nil
}
func (a *azureMethod) NewCreds() chan struct{} {
return nil
}
func (a *azureMethod) CredSuccess() {
}
func (a *azureMethod) Shutdown() {
}
func getMetadataInfo(ctx context.Context, endpoint, resource string) ([]byte, error) {
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, err
}
q := req.URL.Query()
q.Add("api-version", apiVersion)
if resource != "" {
q.Add("resource", resource)
}
req.URL.RawQuery = q.Encode()
req.Header.Set("Metadata", "true")
req.Header.Set("User-Agent", useragent.String())
req = req.WithContext(ctx)
client := cleanhttp.DefaultClient()
resp, err := client.Do(req)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error fetching metadata from %s: {{err}}", endpoint), err)
}
if resp == nil {
return nil, fmt.Errorf("empty response fetching metadata from %s", endpoint)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error reading metadata from %s: {{err}}", endpoint), err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("error response in metadata from %s: %s", endpoint, body)
}
return body, nil
}

View File

@ -0,0 +1,241 @@
package gcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"time"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-gcp-common/gcputil"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/auth"
"github.com/hashicorp/vault/helper/parseutil"
"golang.org/x/oauth2"
iam "google.golang.org/api/iam/v1"
)
const (
typeGCE = "gce"
typeIAM = "iam"
identityEndpoint = "http://metadata/computeMetadata/v1/instance/service-accounts/%s/identity"
defaultIamMaxJwtExpMinutes = 15
)
type gcpMethod struct {
logger hclog.Logger
authType string
mountPath string
role string
credentials string
serviceAccount string
project string
jwtExp int64
}
func NewGCPAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if conf == nil {
return nil, errors.New("empty config")
}
if conf.Config == nil {
return nil, errors.New("empty config data")
}
var err error
g := &gcpMethod{
logger: conf.Logger,
mountPath: conf.MountPath,
serviceAccount: "default",
}
typeRaw, ok := conf.Config["type"]
if !ok {
return nil, errors.New("missing 'type' value")
}
g.authType, ok = typeRaw.(string)
if !ok {
return nil, errors.New("could not convert 'type' config value to string")
}
roleRaw, ok := conf.Config["role"]
if !ok {
return nil, errors.New("missing 'role' value")
}
g.role, ok = roleRaw.(string)
if !ok {
return nil, errors.New("could not convert 'role' config value to string")
}
switch {
case g.role == "":
return nil, errors.New("'role' value is empty")
case g.authType == "":
return nil, errors.New("'type' value is empty")
case g.authType != typeGCE && g.authType != typeIAM:
return nil, errors.New("'type' value is invalid")
}
credentialsRaw, ok := conf.Config["credentials"]
if ok {
g.credentials, ok = credentialsRaw.(string)
if !ok {
return nil, errors.New("could not convert 'credentials' value into string")
}
}
serviceAccountRaw, ok := conf.Config["service_account"]
if ok {
g.serviceAccount, ok = serviceAccountRaw.(string)
if !ok {
return nil, errors.New("could not convert 'service_account' value into string")
}
}
projectRaw, ok := conf.Config["project"]
if ok {
g.project, ok = projectRaw.(string)
if !ok {
return nil, errors.New("could not convert 'project' value into string")
}
}
jwtExpRaw, ok := conf.Config["jwt_exp"]
if ok {
g.jwtExp, err = parseutil.ParseInt(jwtExpRaw)
if err != nil {
return nil, errwrap.Wrapf("error parsing 'jwt_raw' into integer: {{err}}", err)
}
}
return g, nil
}
func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, retData map[string]interface{}, retErr error) {
g.logger.Trace("beginning authentication")
data := make(map[string]interface{})
var jwt string
switch g.authType {
case typeGCE:
httpClient := cleanhttp.DefaultClient()
// Fetch token
{
req, err := http.NewRequest("GET", fmt.Sprintf(identityEndpoint, g.serviceAccount), nil)
if err != nil {
retErr = errwrap.Wrapf("error creating request: {{err}}", err)
return
}
req = req.WithContext(ctx)
req.Header.Add("Metadata-Flavor", "Google")
q := req.URL.Query()
q.Add("audience", fmt.Sprintf("%s/vault/%s", client.Address(), g.role))
q.Add("format", "full")
req.URL.RawQuery = q.Encode()
resp, err := httpClient.Do(req)
if err != nil {
retErr = errwrap.Wrapf("error fetching instance token: {{err}}", err)
return
}
if resp == nil {
retErr = errors.New("empty response fetching instance toke")
return
}
defer resp.Body.Close()
jwtBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
retErr = errwrap.Wrapf("error reading instance token response body: {{err}}", err)
return
}
jwt = string(jwtBytes)
}
default:
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, cleanhttp.DefaultClient())
credentials, tokenSource, err := gcputil.FindCredentials(g.credentials, ctx, iam.CloudPlatformScope)
if err != nil {
retErr = errwrap.Wrapf("could not obtain credentials: {{err}}", err)
return
}
httpClient := oauth2.NewClient(ctx, tokenSource)
var serviceAccount string
if g.serviceAccount == "" && credentials != nil {
serviceAccount = credentials.ClientEmail
} else {
serviceAccount = g.serviceAccount
}
if serviceAccount == "" {
retErr = errors.New("could not obtain service account from credentials (possibly Application Default Credentials are being used); a service account to authenticate as must be provided")
return
}
project := "-"
if g.project != "" {
project = g.project
} else if credentials != nil {
project = credentials.ProjectId
}
ttlMin := int64(defaultIamMaxJwtExpMinutes)
if g.jwtExp != 0 {
ttlMin = g.jwtExp
}
ttl := time.Minute * time.Duration(ttlMin)
jwtPayload := map[string]interface{}{
"aud": fmt.Sprintf("http://vault/%s", g.role),
"sub": serviceAccount,
"exp": time.Now().Add(ttl).Unix(),
}
payloadBytes, err := json.Marshal(jwtPayload)
if err != nil {
retErr = errwrap.Wrapf("could not convert JWT payload to JSON string: {{err}}", err)
return
}
jwtReq := &iam.SignJwtRequest{
Payload: string(payloadBytes),
}
iamClient, err := iam.New(httpClient)
if err != nil {
retErr = errwrap.Wrapf("could not create IAM client: {{err}}", err)
return
}
resourceName := fmt.Sprintf("projects/%s/serviceAccounts/%s", project, serviceAccount)
resp, err := iamClient.Projects.ServiceAccounts.SignJwt(resourceName, jwtReq).Do()
if err != nil {
retErr = errwrap.Wrapf(fmt.Sprintf("unable to sign JWT for %s using given Vault credentials: {{err}}", resourceName), err)
return
}
jwt = resp.SignedJwt
}
data["role"] = g.role
data["jwt"] = jwt
return fmt.Sprintf("%s/login", g.mountPath), data, nil
}
func (g *gcpMethod) NewCreds() chan struct{} {
return nil
}
func (g *gcpMethod) CredSuccess() {
}
func (g *gcpMethod) Shutdown() {
}

View File

@ -0,0 +1,184 @@
package jwt
import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/auth"
)
type jwtMethod struct {
logger hclog.Logger
path string
mountPath string
role string
credsFound chan struct{}
watchCh chan string
stopCh chan struct{}
doneCh chan struct{}
credSuccessGate chan struct{}
ticker *time.Ticker
once *sync.Once
latestToken *atomic.Value
}
func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if conf == nil {
return nil, errors.New("empty config")
}
if conf.Config == nil {
return nil, errors.New("empty config data")
}
j := &jwtMethod{
logger: conf.Logger,
mountPath: conf.MountPath,
credsFound: make(chan struct{}),
watchCh: make(chan string),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
credSuccessGate: make(chan struct{}),
once: new(sync.Once),
latestToken: new(atomic.Value),
}
j.latestToken.Store("")
pathRaw, ok := conf.Config["path"]
if !ok {
return nil, errors.New("missing 'path' value")
}
j.path, ok = pathRaw.(string)
if !ok {
return nil, errors.New("could not convert 'path' config value to string")
}
roleRaw, ok := conf.Config["role"]
if !ok {
return nil, errors.New("missing 'role' value")
}
j.role, ok = roleRaw.(string)
if !ok {
return nil, errors.New("could not convert 'role' config value to string")
}
switch {
case j.path == "":
return nil, errors.New("'path' value is empty")
case j.role == "":
return nil, errors.New("'role' value is empty")
}
j.ticker = time.NewTicker(500 * time.Millisecond)
go j.runWatcher()
j.logger.Info("jwt auth method created", "path", j.path)
return j, nil
}
func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) {
j.logger.Trace("beginning authentication")
j.ingressToken()
latestToken := j.latestToken.Load().(string)
if latestToken == "" {
return "", nil, errors.New("latest known jwt is empty, cannot authenticate")
}
return fmt.Sprintf("%s/login", j.mountPath), map[string]interface{}{
"role": j.role,
"jwt": latestToken,
}, nil
}
func (j *jwtMethod) NewCreds() chan struct{} {
return j.credsFound
}
func (j *jwtMethod) CredSuccess() {
j.once.Do(func() {
close(j.credSuccessGate)
})
}
func (j *jwtMethod) Shutdown() {
j.ticker.Stop()
close(j.stopCh)
<-j.doneCh
}
func (j *jwtMethod) runWatcher() {
defer close(j.doneCh)
select {
case <-j.stopCh:
return
case <-j.credSuccessGate:
// We only start the next loop once we're initially successful,
// since at startup Authenticate will be called and we don't want
// to end up immediately reauthenticating by having found a new
// value
}
for {
select {
case <-j.stopCh:
return
case <-j.ticker.C:
latestToken := j.latestToken.Load().(string)
j.ingressToken()
newToken := j.latestToken.Load().(string)
if newToken != latestToken {
j.credsFound <- struct{}{}
}
}
}
}
func (j *jwtMethod) ingressToken() {
fi, err := os.Lstat(j.path)
if err != nil {
if os.IsNotExist(err) {
return
}
j.logger.Error("error encountered stat'ing jwt file", "error", err)
return
}
j.logger.Debug("new jwt file found")
if !fi.Mode().IsRegular() {
j.logger.Error("jwt file is not a regular file")
return
}
token, err := ioutil.ReadFile(j.path)
if err != nil {
j.logger.Error("failed to read jwt file", "error", err)
return
}
switch len(token) {
case 0:
j.logger.Warn("empty jwt file read")
default:
j.latestToken.Store(string(token))
}
if err := os.Remove(j.path); err != nil {
j.logger.Error("error removing jwt file", "error", err)
}
}

View File

@ -0,0 +1,76 @@
package kubernetes
import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"strings"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/auth"
)
const (
serviceAccountFile = "/var/run/secrets/kubernetes.io/serviceaccount/token"
)
type kubernetesMethod struct {
logger hclog.Logger
mountPath string
role string
}
func NewKubernetesAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if conf == nil {
return nil, errors.New("empty config")
}
if conf.Config == nil {
return nil, errors.New("empty config data")
}
k := &kubernetesMethod{
logger: conf.Logger,
mountPath: conf.MountPath,
}
roleRaw, ok := conf.Config["role"]
if !ok {
return nil, errors.New("missing 'role' value")
}
k.role, ok = roleRaw.(string)
if !ok {
return nil, errors.New("could not convert 'role' config value to string")
}
if k.role == "" {
return nil, errors.New("'role' value is empty")
}
return k, nil
}
func (k *kubernetesMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) {
k.logger.Trace("beginning authentication")
content, err := ioutil.ReadFile(serviceAccountFile)
if err != nil {
log.Fatal(err)
}
return fmt.Sprintf("%s/login", k.mountPath), map[string]interface{}{
"role": k.role,
"jwt": strings.TrimSpace(string(content)),
}, nil
}
func (k *kubernetesMethod) NewCreds() chan struct{} {
return nil
}
func (k *kubernetesMethod) CredSuccess() {
}
func (k *kubernetesMethod) Shutdown() {
}

View File

@ -0,0 +1,239 @@
package config
import (
"errors"
"fmt"
"io/ioutil"
"os"
"strings"
"time"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/ast"
)
// Config is the configuration for the vault server.
type Config struct {
AutoAuth *AutoAuth `hcl:"auto_auth"`
ExitAfterAuth bool `hcl:"exit_after_auth"`
PidFile string `hcl:"pid_file"`
}
type AutoAuth struct {
Method *Method `hcl:"-"`
Sinks []*Sink `hcl:"sinks"`
}
type Method struct {
Type string
MountPath string `hcl:"mount_path"`
WrapTTLRaw interface{} `hcl:"wrap_ttl"`
WrapTTL time.Duration `hcl:"-"`
Config map[string]interface{}
}
type Sink struct {
Type string
WrapTTLRaw interface{} `hcl:"wrap_ttl"`
WrapTTL time.Duration `hcl:"-"`
DHType string `hcl:"dh_type"`
DHPath string `hcl:"dh_path"`
AAD string `hcl:"aad"`
AADEnvVar string `hcl:"aad_env_var"`
Config map[string]interface{}
}
// LoadConfig loads the configuration at the given path, regardless if
// its a file or directory.
func LoadConfig(path string, logger log.Logger) (*Config, error) {
fi, err := os.Stat(path)
if err != nil {
return nil, err
}
if fi.IsDir() {
return nil, fmt.Errorf("location is a directory, not a file")
}
// Read the file
d, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
// Parse!
obj, err := hcl.Parse(string(d))
if err != nil {
return nil, err
}
// Start building the result
var result Config
if err := hcl.DecodeObject(&result, obj); err != nil {
return nil, err
}
list, ok := obj.Node.(*ast.ObjectList)
if !ok {
return nil, fmt.Errorf("error parsing: file doesn't contain a root object")
}
if err := parseAutoAuth(&result, list); err != nil {
return nil, errwrap.Wrapf("error parsing 'auto_auth': {{err}}", err)
}
return &result, nil
}
func parseAutoAuth(result *Config, list *ast.ObjectList) error {
name := "auto_auth"
autoAuthList := list.Filter(name)
if len(autoAuthList.Items) != 1 {
return fmt.Errorf("one and only one %q block is required", name)
}
// Get our item
item := autoAuthList.Items[0]
var a AutoAuth
if err := hcl.DecodeObject(&a, item.Val); err != nil {
return err
}
result.AutoAuth = &a
subs, ok := item.Val.(*ast.ObjectType)
if !ok {
return fmt.Errorf("could not parse %q as an object", name)
}
subList := subs.List
if err := parseMethod(result, subList); err != nil {
return errwrap.Wrapf("error parsing 'method': {{err}}", err)
}
if err := parseSinks(result, subList); err != nil {
return errwrap.Wrapf("error parsing 'sink' stanzas: {{err}}", err)
}
switch {
case a.Method == nil:
return fmt.Errorf("no 'method' block found")
case len(a.Sinks) == 0:
return fmt.Errorf("at least one 'sink' block must be provided")
}
return nil
}
func parseMethod(result *Config, list *ast.ObjectList) error {
name := "method"
methodList := list.Filter(name)
if len(methodList.Items) != 1 {
return fmt.Errorf("one and only one %q block is required", name)
}
// Get our item
item := methodList.Items[0]
var m Method
if err := hcl.DecodeObject(&m, item.Val); err != nil {
return err
}
if m.Type == "" {
if len(item.Keys) == 1 {
m.Type = strings.ToLower(item.Keys[0].Token.Value().(string))
}
if m.Type == "" {
return errors.New("method type must be specified")
}
}
// Default to Vault's default
if m.MountPath == "" {
m.MountPath = fmt.Sprintf("auth/%s", m.Type)
}
// Standardize on no trailing slash
m.MountPath = strings.TrimSuffix(m.MountPath, "/")
if m.WrapTTLRaw != nil {
var err error
if m.WrapTTL, err = parseutil.ParseDurationSecond(m.WrapTTLRaw); err != nil {
return err
}
m.WrapTTLRaw = nil
}
result.AutoAuth.Method = &m
return nil
}
func parseSinks(result *Config, list *ast.ObjectList) error {
name := "sink"
sinkList := list.Filter(name)
if len(sinkList.Items) < 1 {
return fmt.Errorf("at least one %q block is required", name)
}
var ts []*Sink
for _, item := range sinkList.Items {
var s Sink
if err := hcl.DecodeObject(&s, item.Val); err != nil {
return err
}
if s.Type == "" {
if len(item.Keys) == 1 {
s.Type = strings.ToLower(item.Keys[0].Token.Value().(string))
}
if s.Type == "" {
return errors.New("sink type must be specified")
}
}
if s.WrapTTLRaw != nil {
var err error
if s.WrapTTL, err = parseutil.ParseDurationSecond(s.WrapTTLRaw); err != nil {
return multierror.Prefix(err, fmt.Sprintf("sink.%s", s.Type))
}
s.WrapTTLRaw = nil
}
switch s.DHType {
case "":
case "curve25519":
default:
return multierror.Prefix(errors.New("invalid value for 'dh_type'"), fmt.Sprintf("sink.%s", s.Type))
}
if s.AADEnvVar != "" {
s.AAD = os.Getenv(s.AADEnvVar)
s.AADEnvVar = ""
}
switch {
case s.DHPath == "" && s.DHType == "":
if s.AAD != "" {
return multierror.Prefix(errors.New("specifying AAD data without 'dh_type' does not make sense"), fmt.Sprintf("sink.%s", s.Type))
}
case s.DHPath != "" && s.DHType != "":
default:
return multierror.Prefix(errors.New("'dh_type' and 'dh_path' must be specified together"), fmt.Sprintf("sink.%s", s.Type))
}
ts = append(ts, &s)
}
result.AutoAuth.Sinks = ts
return nil
}

View File

@ -0,0 +1,71 @@
package config
import (
"os"
"testing"
"time"
"github.com/go-test/deep"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/logging"
)
func TestLoadConfigFile(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)
os.Setenv("TEST_AAD_ENV", "aad")
defer os.Unsetenv("TEST_AAD_ENV")
config, err := LoadConfig("./test-fixtures/config.hcl", logger)
if err != nil {
t.Fatalf("err: %s", err)
}
expected := &Config{
AutoAuth: &AutoAuth{
Method: &Method{
Type: "aws",
WrapTTL: 300 * time.Second,
MountPath: "auth/aws",
Config: map[string]interface{}{
"role": "foobar",
},
},
Sinks: []*Sink{
&Sink{
Type: "file",
DHType: "curve25519",
DHPath: "/tmp/file-foo-dhpath",
AAD: "foobar",
Config: map[string]interface{}{
"path": "/tmp/file-foo",
},
},
&Sink{
Type: "file",
WrapTTL: 5 * time.Minute,
DHType: "curve25519",
DHPath: "/tmp/file-foo-dhpath2",
AAD: "aad",
Config: map[string]interface{}{
"path": "/tmp/file-bar",
},
},
},
},
PidFile: "./pidfile",
}
if diff := deep.Equal(config, expected); diff != nil {
t.Fatal(diff)
}
config, err = LoadConfig("./test-fixtures/config-embedded-type.hcl", logger)
if err != nil {
t.Fatalf("err: %s", err)
}
if diff := deep.Equal(config, expected); diff != nil {
t.Fatal(diff)
}
}

View File

@ -0,0 +1,30 @@
pid_file = "./pidfile"
auto_auth {
method "aws" {
mount_path = "auth/aws"
wrap_ttl = 300
config = {
role = "foobar"
}
}
sink "file" {
config = {
path = "/tmp/file-foo"
}
aad = "foobar"
dh_type = "curve25519"
dh_path = "/tmp/file-foo-dhpath"
}
sink "file" {
wrap_ttl = "5m"
aad_env_var = "TEST_AAD_ENV"
dh_type = "curve25519"
dh_path = "/tmp/file-foo-dhpath2"
config = {
path = "/tmp/file-bar"
}
}
}

View File

@ -0,0 +1,32 @@
pid_file = "./pidfile"
auto_auth {
method {
type = "aws"
wrap_ttl = 300
config = {
role = "foobar"
}
}
sink {
type = "file"
config = {
path = "/tmp/file-foo"
}
aad = "foobar"
dh_type = "curve25519"
dh_path = "/tmp/file-foo-dhpath"
}
sink {
type = "file"
wrap_ttl = "5m"
aad_env_var = "TEST_AAD_ENV"
dh_type = "curve25519"
dh_path = "/tmp/file-foo-dhpath2"
config = {
path = "/tmp/file-bar"
}
}
}

View File

@ -0,0 +1,351 @@
package agent
import (
"context"
"encoding/json"
"io/ioutil"
"os"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/auth"
agentjwt "github.com/hashicorp/vault/command/agent/auth/jwt"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/helper/dhutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logging"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
func TestJWTEndToEnd(t *testing.T) {
testJWTEndToEnd(t, false)
testJWTEndToEnd(t, true)
}
func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"jwt": vaultjwt.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
// Setup Vault
err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
Type: "jwt",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
"bound_issuer": "https://team-vault.auth0.com/",
"jwt_validation_pubkeys": TestECDSAPubKey,
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
"bound_audiences": "https://vault.plugin.auth.jwt.test",
"user_claim": "https://vault/user",
"groups_claim": "https://vault/groups",
"policies": "test",
"period": "3s",
})
if err != nil {
t.Fatal(err)
}
// Generate encryption params
pub, pri, err := dhutil.GeneratePublicPrivateKey()
if err != nil {
t.Fatal(err)
}
// We close these right away because we're just basically testing
// permissions and finding a usable file name
inf, err := ioutil.TempFile("", "auth.jwt.test.")
if err != nil {
t.Fatal(err)
}
in := inf.Name()
inf.Close()
os.Remove(in)
t.Logf("input: %s", in)
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
if err != nil {
t.Fatal(err)
}
out := ouf.Name()
ouf.Close()
os.Remove(out)
t.Logf("output: %s", out)
dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.")
if err != nil {
t.Fatal(err)
}
dhpath := dhpathf.Name()
dhpathf.Close()
os.Remove(dhpath)
// Write DH public key to file
mPubKey, err := jsonutil.EncodeJSON(&dhutil.PublicKeyInfo{
Curve25519PublicKey: pub,
})
if err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(dhpath, mPubKey, 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote dh param file", "path", dhpath)
}
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
am, err := agentjwt.NewJWTAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.jwt"),
MountPath: "auth/jwt",
Config: map[string]interface{}{
"path": in,
"role": "test",
},
})
if err != nil {
t.Fatal(err)
}
ahConfig := &auth.AuthHandlerConfig{
Logger: logger.Named("auth.handler"),
Client: client,
}
if ahWrapping {
ahConfig.WrapTTL = 10 * time.Second
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
defer func() {
<-ah.DoneCh
}()
config := &sink.SinkConfig{
Logger: logger.Named("sink.file"),
AAD: "foobar",
DHType: "curve25519",
DHPath: dhpath,
Config: map[string]interface{}{
"path": out,
},
}
if !ahWrapping {
config.WrapTTL = 10 * time.Second
}
fs, err := file.NewFileSink(config)
if err != nil {
t.Fatal(err)
}
config.Sink = fs
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
// Check that no jwt file exists
_, err = os.Lstat(in)
if err == nil {
t.Fatal("expected err")
}
if !os.IsNotExist(err) {
t.Fatal("expected notexist err")
}
_, err = os.Lstat(out)
if err == nil {
t.Fatal("expected err")
}
if !os.IsNotExist(err) {
t.Fatal("expected notexist err")
}
cloned, err := client.Clone()
if err != nil {
t.Fatal(err)
}
// Get a token
jwtToken, _ := GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test jwt", "path", in)
}
checkToken := func() string {
timeout := time.Now().Add(5 * time.Second)
for {
if time.Now().After(timeout) {
t.Fatal("did not find a written token after timeout")
}
val, err := ioutil.ReadFile(out)
if err == nil {
os.Remove(out)
if len(val) == 0 {
t.Fatal("written token was empty")
}
// First decrypt it
resp := new(dhutil.Envelope)
if err := jsonutil.DecodeJSON(val, resp); err != nil {
continue
}
aesKey, err := dhutil.GenerateSharedKey(pri, resp.Curve25519PublicKey)
if err != nil {
t.Fatal(err)
}
if len(aesKey) == 0 {
t.Fatal("got empty aes key")
}
val, err = dhutil.DecryptAES(aesKey, resp.EncryptedPayload, resp.Nonce, []byte("foobar"))
if err != nil {
t.Fatalf("error: %v\nresp: %v", err, string(val))
}
// Now unwrap it
wrapInfo := new(api.SecretWrapInfo)
if err := jsonutil.DecodeJSON(val, wrapInfo); err != nil {
t.Fatal(err)
}
switch {
case wrapInfo.TTL != 10:
t.Fatalf("bad wrap info: %v", wrapInfo.TTL)
case !ahWrapping && wrapInfo.CreationPath != "sys/wrapping/wrap":
t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
case ahWrapping && wrapInfo.CreationPath != "auth/jwt/login":
t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
case wrapInfo.Token == "":
t.Fatal("wrap token is empty")
}
cloned.SetToken(wrapInfo.Token)
secret, err := cloned.Logical().Unwrap("")
if err != nil {
t.Fatal(err)
}
if ahWrapping {
switch {
case secret.Auth == nil:
t.Fatal("unwrap secret auth is nil")
case secret.Auth.ClientToken == "":
t.Fatal("unwrap token is nil")
}
return secret.Auth.ClientToken
} else {
switch {
case secret.Data == nil:
t.Fatal("unwrap secret data is nil")
case secret.Data["token"] == nil:
t.Fatal("unwrap token is nil")
}
return secret.Data["token"].(string)
}
}
time.Sleep(250 * time.Millisecond)
}
}
origToken := checkToken()
// We only check this if the renewer is actually renewing for us
if !ahWrapping {
// Period of 3 seconds, so should still be alive after 7
timeout := time.Now().Add(7 * time.Second)
cloned.SetToken(origToken)
for {
if time.Now().After(timeout) {
break
}
secret, err := cloned.Auth().Token().LookupSelf()
if err != nil {
t.Fatal(err)
}
ttl, err := secret.Data["ttl"].(json.Number).Int64()
if err != nil {
t.Fatal(err)
}
if ttl > 3 {
t.Fatalf("unexpected ttl: %v", secret.Data["ttl"])
}
}
}
// Get another token to test the backend pushing the need to authenticate
// to the handler
jwtToken, _ = GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err)
}
newToken := checkToken()
if newToken == origToken {
t.Fatal("found same token written")
}
if !ahWrapping {
// Repeat the period test. At the end the old token should have expired and
// the new token should still be alive after 7
timeout := time.Now().Add(7 * time.Second)
cloned.SetToken(newToken)
for {
if time.Now().After(timeout) {
break
}
secret, err := cloned.Auth().Token().LookupSelf()
if err != nil {
t.Fatal(err)
}
ttl, err := secret.Data["ttl"].(json.Number).Int64()
if err != nil {
t.Fatal(err)
}
if ttl > 3 {
t.Fatalf("unexpected ttl: %v", secret.Data["ttl"])
}
}
cloned.SetToken(origToken)
_, err = cloned.Auth().Token().LookupSelf()
if err == nil {
t.Fatal("expected error")
}
}
}

View File

@ -0,0 +1,112 @@
package file
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/hashicorp/errwrap"
hclog "github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/command/agent/sink"
)
// fileSink is a Sink implementation that writes a token to a file
type fileSink struct {
path string
logger hclog.Logger
}
// NewFileSink creates a new file sink with the given configuration
func NewFileSink(conf *sink.SinkConfig) (sink.Sink, error) {
if conf.Logger == nil {
return nil, errors.New("nil logger provided")
}
conf.Logger.Info("creating file sink")
f := &fileSink{
logger: conf.Logger,
}
pathRaw, ok := conf.Config["path"]
if !ok {
return nil, errors.New("'path' not specified for file sink")
}
path, ok := pathRaw.(string)
if !ok {
return nil, errors.New("could not parse 'path' as string")
}
f.path = path
if err := f.WriteToken(""); err != nil {
return nil, errwrap.Wrapf("error during write check: {{err}}", err)
}
f.logger.Info("file sink configured", "path", f.path)
return f, nil
}
// WriteToken implements the Server interface and writes the token to a path on
// disk. It writes into the path's directory into a temp file and does an
// atomic rename to ensure consistency. If a blank token is passed in, it
// performs a write check but does not write a blank value to the final
// location.
func (f *fileSink) WriteToken(token string) error {
f.logger.Trace("enter write_token", "path", f.path)
defer f.logger.Trace("exit write_token", "path", f.path)
u, err := uuid.GenerateUUID()
if err != nil {
return errwrap.Wrapf("error generating a uuid during write check: {{err}}", err)
}
targetDir := filepath.Dir(f.path)
fileName := filepath.Base(f.path)
tmpSuffix := strings.Split(u, "-")[0]
tmpFile, err := os.OpenFile(filepath.Join(targetDir, fmt.Sprintf("%s.tmp.%s", fileName, tmpSuffix)), os.O_WRONLY|os.O_CREATE, 0640)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error opening temp file in dir %s for writing: {{err}}", targetDir), err)
}
valToWrite := token
if token == "" {
valToWrite = u
}
_, err = tmpFile.WriteString(valToWrite)
if err != nil {
// Attempt closing and deleting but ignore any error
tmpFile.Close()
os.Remove(tmpFile.Name())
return errwrap.Wrapf(fmt.Sprintf("error writing to %s: {{err}}", tmpFile.Name()), err)
}
err = tmpFile.Close()
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error closing %s: {{err}}", tmpFile.Name()), err)
}
// Now, if we were just doing a write check (blank token), remove the file
// and exit; otherwise, atomically rename it
if token == "" {
err = os.Remove(tmpFile.Name())
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error removing temp file %s during write check: {{err}}", tmpFile.Name()), err)
}
return nil
}
err = os.Rename(tmpFile.Name(), f.path)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error renaming temp file %s to target file %s: {{err}}", tmpFile.Name(), f.path), err)
}
f.logger.Info("token written", "path", f.path)
return nil
}

View File

@ -0,0 +1,82 @@
package file
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"testing"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/helper/logging"
)
const (
fileServerTestDir = "vault-agent-file-test"
)
func testFileSink(t *testing.T, log hclog.Logger) (*sink.SinkConfig, string) {
tmpDir, err := ioutil.TempDir("", fmt.Sprintf("%s.", fileServerTestDir))
if err != nil {
t.Fatal(err)
}
path := filepath.Join(tmpDir, "token")
config := &sink.SinkConfig{
Logger: log.Named("sink.file"),
Config: map[string]interface{}{
"path": path,
},
}
s, err := NewFileSink(config)
if err != nil {
t.Fatal(err)
}
config.Sink = s
return config, tmpDir
}
func TestFileSink(t *testing.T) {
log := logging.NewVaultLogger(hclog.Trace)
fs, tmpDir := testFileSink(t, log)
defer os.RemoveAll(tmpDir)
path := filepath.Join(tmpDir, "token")
uuidStr, _ := uuid.GenerateUUID()
if err := fs.WriteToken(uuidStr); err != nil {
t.Fatal(err)
}
file, err := os.Open(path)
if err != nil {
t.Fatal(err)
}
fi, err := file.Stat()
if err != nil {
t.Fatal(err)
}
if fi.Mode() != os.FileMode(0640) {
t.Fatalf("wrong file mode was detected at %s", path)
}
err = file.Close()
if err != nil {
t.Fatal(err)
}
fileBytes, err := ioutil.ReadFile(path)
if err != nil {
t.Fatal(err)
}
if string(fileBytes) != uuidStr {
t.Fatalf("expected %s, got %s", uuidStr, string(fileBytes))
}
}

View File

@ -0,0 +1,121 @@
package file
import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"sync/atomic"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/helper/logging"
)
func TestSinkServer(t *testing.T) {
log := logging.NewVaultLogger(hclog.Trace)
fs1, path1 := testFileSink(t, log)
defer os.RemoveAll(path1)
fs2, path2 := testFileSink(t, log)
defer os.RemoveAll(path2)
ctx, cancelFunc := context.WithCancel(context.Background())
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: log.Named("sink.server"),
})
uuidStr, _ := uuid.GenerateUUID()
in := make(chan string)
sinks := []*sink.SinkConfig{fs1, fs2}
go ss.Run(ctx, in, sinks)
// Seed a token
in <- uuidStr
// Give it time to finish writing
time.Sleep(1 * time.Second)
// Tell it to shut down and give it time to do so
cancelFunc()
<-ss.DoneCh
for _, path := range []string{path1, path2} {
fileBytes, err := ioutil.ReadFile(fmt.Sprintf("%s/token", path))
if err != nil {
t.Fatal(err)
}
if string(fileBytes) != uuidStr {
t.Fatalf("expected %s, got %s", uuidStr, string(fileBytes))
}
}
}
type badSink struct {
tryCount uint32
logger hclog.Logger
}
func (b *badSink) WriteToken(token string) error {
switch token {
case "bad":
atomic.AddUint32(&b.tryCount, 1)
b.logger.Info("got bad")
return errors.New("bad")
case "good":
atomic.StoreUint32(&b.tryCount, 0)
b.logger.Info("got good")
return nil
default:
return errors.New("unknown case")
}
}
func TestSinkServerRetry(t *testing.T) {
log := logging.NewVaultLogger(hclog.Trace)
b1 := &badSink{logger: log.Named("b1")}
b2 := &badSink{logger: log.Named("b2")}
ctx, cancelFunc := context.WithCancel(context.Background())
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: log.Named("sink.server"),
})
in := make(chan string)
sinks := []*sink.SinkConfig{&sink.SinkConfig{Sink: b1}, &sink.SinkConfig{Sink: b2}}
go ss.Run(ctx, in, sinks)
// Seed a token
in <- "bad"
// During this time we should see it retry multiple times
time.Sleep(10 * time.Second)
if atomic.LoadUint32(&b1.tryCount) < 2 {
t.Fatal("bad try count")
}
if atomic.LoadUint32(&b2.tryCount) < 2 {
t.Fatal("bad try count")
}
in <- "good"
time.Sleep(2 * time.Second)
if atomic.LoadUint32(&b1.tryCount) != 0 {
t.Fatal("bad try count")
}
if atomic.LoadUint32(&b2.tryCount) != 0 {
t.Fatal("bad try count")
}
// Tell it to shut down and give it time to do so
cancelFunc()
<-ss.DoneCh
}

242
command/agent/sink/sink.go Normal file
View File

@ -0,0 +1,242 @@
package sink
import (
"context"
"errors"
"io/ioutil"
"math/rand"
"os"
"sync/atomic"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/dhutil"
"github.com/hashicorp/vault/helper/jsonutil"
)
type Sink interface {
WriteToken(string) error
}
type SinkConfig struct {
Sink
Logger hclog.Logger
Config map[string]interface{}
Client *api.Client
WrapTTL time.Duration
DHType string
DHPath string
AAD string
cachedRemotePubKey []byte
cachedPubKey []byte
cachedPriKey []byte
}
type SinkServerConfig struct {
Logger hclog.Logger
Client *api.Client
Context context.Context
ExitAfterAuth bool
}
// SinkServer is responsible for pushing tokens to sinks
type SinkServer struct {
DoneCh chan struct{}
logger hclog.Logger
client *api.Client
random *rand.Rand
exitAfterAuth bool
remaining *int32
}
func NewSinkServer(conf *SinkServerConfig) *SinkServer {
ss := &SinkServer{
DoneCh: make(chan struct{}),
logger: conf.Logger,
client: conf.Client,
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
exitAfterAuth: conf.ExitAfterAuth,
remaining: new(int32),
}
return ss
}
// Run executes the server's run loop, which is responsible for reading
// in new tokens and pushing them out to the various sinks.
func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) {
if incoming == nil {
panic("incoming or shutdown channel are nil")
}
ss.logger.Info("starting sink server")
defer func() {
ss.logger.Info("sink server stopped")
close(ss.DoneCh)
}()
latestToken := new(string)
sinkCh := make(chan func() error, len(sinks))
for {
select {
case <-ctx.Done():
return
case token := <-incoming:
if token != *latestToken {
// Drain the existing funcs
drainLoop:
for {
select {
case <-sinkCh:
atomic.AddInt32(ss.remaining, -1)
default:
break drainLoop
}
}
*latestToken = token
for _, s := range sinks {
sinkFunc := func(currSink *SinkConfig, currToken string) func() error {
return func() error {
if currToken != *latestToken {
return nil
}
var err error
if currSink.WrapTTL != 0 {
if currToken, err = s.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil {
return err
}
}
if s.DHType != "" {
if currToken, err = s.encryptToken(currToken); err != nil {
return err
}
}
return currSink.WriteToken(currToken)
}
}
atomic.AddInt32(ss.remaining, 1)
sinkCh <- sinkFunc(s, token)
}
}
case sinkFunc := <-sinkCh:
atomic.AddInt32(ss.remaining, -1)
select {
case <-ctx.Done():
return
default:
}
if err := sinkFunc(); err != nil {
backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second))
ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String())
select {
case <-ctx.Done():
return
case <-time.After(backoff):
atomic.AddInt32(ss.remaining, 1)
sinkCh <- sinkFunc
}
} else {
if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
return
}
}
}
}
}
func (s *SinkConfig) encryptToken(token string) (string, error) {
var aesKey []byte
var err error
resp := new(dhutil.Envelope)
switch s.DHType {
case "curve25519":
if len(s.cachedRemotePubKey) == 0 {
_, err = os.Lstat(s.DHPath)
if err != nil {
if !os.IsNotExist(err) {
return "", errwrap.Wrapf("error stat-ing dh parameters file: {{err}}", err)
}
return "", errors.New("no dh parameters file found, and no cached pub key")
}
fileBytes, err := ioutil.ReadFile(s.DHPath)
if err != nil {
return "", errwrap.Wrapf("error reading file for dh parameters: {{err}}", err)
}
theirPubKey := new(dhutil.PublicKeyInfo)
if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil {
return "", errwrap.Wrapf("error decoding public key: {{err}}", err)
}
if len(theirPubKey.Curve25519PublicKey) == 0 {
return "", errors.New("public key is nil")
}
s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey
}
if len(s.cachedPubKey) == 0 {
s.cachedPubKey, s.cachedPriKey, err = dhutil.GeneratePublicPrivateKey()
if err != nil {
return "", errwrap.Wrapf("error generating pub/pri curve25519 keys: {{err}}", err)
}
}
resp.Curve25519PublicKey = s.cachedPubKey
}
aesKey, err = dhutil.GenerateSharedKey(s.cachedPriKey, s.cachedRemotePubKey)
if err != nil {
return "", errwrap.Wrapf("error deriving shared key: {{err}}", err)
}
if len(aesKey) == 0 {
return "", errors.New("derived AES key is empty")
}
resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD))
if err != nil {
return "", errwrap.Wrapf("error encrypting with shared key: {{err}}", err)
}
m, err := jsonutil.EncodeJSON(resp)
if err != nil {
return "", errwrap.Wrapf("error encoding encrypted payload: {{err}}", err)
}
return string(m), nil
}
func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) {
wrapClient, err := client.Clone()
if err != nil {
return "", errwrap.Wrapf("error deriving client for wrapping, not writing out to sink: {{err}})", err)
}
wrapClient.SetToken(token)
wrapClient.SetWrappingLookupFunc(func(string, string) string {
return wrapTTL.String()
})
secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{
"token": token,
})
if err != nil {
return "", errwrap.Wrapf("error wrapping token, not writing out to sink: {{err}})", err)
}
if secret == nil {
return "", errors.New("nil secret returned, not writing out to sink")
}
if secret.WrapInfo == nil {
return "", errors.New("nil wrap info returned, not writing out to sink")
}
m, err := jsonutil.EncodeJSON(secret.WrapInfo)
if err != nil {
return "", errwrap.Wrapf("error marshaling token, not writing out to sink: {{err}})", err)
}
return string(m), nil
}

65
command/agent/testing.go Normal file
View File

@ -0,0 +1,65 @@
package agent
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"testing"
"time"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
func GetTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) {
t.Helper()
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: "https://team-vault.auth0.com/",
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"},
}
privateCl := struct {
User string `json:"https://vault/user"`
Groups []string `json:"https://vault/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}
var key *ecdsa.PrivateKey
block, _ := pem.Decode([]byte(TestECDSAPrivKey))
if block != nil {
var err error
key, err = x509.ParseECPrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
t.Fatal(err)
}
raw, err := jwt.Signed(sig).Claims(cl).Claims(privateCl).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return raw, key
}
const (
TestECDSAPrivKey string = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKfldwWLPYsHjRL9EVTsjSbzTtcGRu6icohNfIqcb6A+oAoGCCqGSM49
AwEHoUQDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbSq+7+1q9BFxAkzjgKnlkXk5qx
hzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END EC PRIVATE KEY-----`
TestECDSAPubKey string = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbS
q+7+1q9BFxAkzjgKnlkXk5qxhzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END PUBLIC KEY-----`
)

186
command/agent_test.go Normal file
View File

@ -0,0 +1,186 @@
package command
import (
"fmt"
"io/ioutil"
"os"
"testing"
hclog "github.com/hashicorp/go-hclog"
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent"
"github.com/hashicorp/vault/helper/logging"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)
func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCommand) {
tb.Helper()
ui := cli.NewMockUi()
return ui, &AgentCommand{
BaseCommand: &BaseCommand{
UI: ui,
},
ShutdownCh: MakeShutdownCh(),
logger: logger,
}
}
func TestExitAfterAuth(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"jwt": vaultjwt.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
// Setup Vault
err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
Type: "jwt",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
"bound_issuer": "https://team-vault.auth0.com/",
"jwt_validation_pubkeys": agent.TestECDSAPubKey,
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
"bound_audiences": "https://vault.plugin.auth.jwt.test",
"user_claim": "https://vault/user",
"groups_claim": "https://vault/groups",
"policies": "test",
"period": "3s",
})
if err != nil {
t.Fatal(err)
}
inf, err := ioutil.TempFile("", "auth.jwt.test.")
if err != nil {
t.Fatal(err)
}
in := inf.Name()
inf.Close()
os.Remove(in)
t.Logf("input: %s", in)
sink1f, err := ioutil.TempFile("", "sink1.jwt.test.")
if err != nil {
t.Fatal(err)
}
sink1 := sink1f.Name()
sink1f.Close()
os.Remove(sink1)
t.Logf("sink1: %s", sink1)
sink2f, err := ioutil.TempFile("", "sink2.jwt.test.")
if err != nil {
t.Fatal(err)
}
sink2 := sink2f.Name()
sink2f.Close()
os.Remove(sink2)
t.Logf("sink2: %s", sink2)
conff, err := ioutil.TempFile("", "conf.jwt.test.")
if err != nil {
t.Fatal(err)
}
conf := conff.Name()
conff.Close()
os.Remove(conf)
t.Logf("config: %s", conf)
jwtToken, _ := agent.GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test jwt", "path", in)
}
config := `
exit_after_auth = true
auto_auth {
method {
type = "jwt"
config = {
role = "test"
path = "%s"
}
}
sink {
type = "file"
config = {
path = "%s"
}
}
sink "file" {
config = {
path = "%s"
}
}
}
`
config = fmt.Sprintf(config, in, sink1, sink2)
if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test config", "path", conf)
}
// If this hangs forever until the test times out, exit-after-auth isn't
// working
ui, cmd := testAgentCommand(t, logger)
cmd.client = client
code := cmd.Run([]string{"-config", conf})
if code != 0 {
t.Errorf("expected %d to be %d", code, 0)
t.Logf("output from agent:\n%s", ui.OutputWriter.String())
t.Logf("error from agent:\n%s", ui.ErrorWriter.String())
}
sink1Bytes, err := ioutil.ReadFile(sink1)
if err != nil {
t.Fatal(err)
}
if len(sink1Bytes) == 0 {
t.Fatal("got no output from sink 1")
}
sink2Bytes, err := ioutil.ReadFile(sink2)
if err != nil {
t.Fatal(err)
}
if len(sink2Bytes) == 0 {
t.Fatal("got no output from sink 2")
}
if string(sink1Bytes) != string(sink2Bytes) {
t.Fatal("sink 1/2 values don't match")
}
}

View File

@ -154,7 +154,7 @@ func (c *AuditListCommand) detailedAudits(audits map[string]*api.Audit) []string
}
columns = append(columns, fmt.Sprintf("%s | %s | %s | %s | %s",
audit.Path,
path,
audit.Type,
audit.Description,
replication,

View File

@ -14,6 +14,7 @@ import (
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/token"
"github.com/hashicorp/vault/helper/namespace"
"github.com/mitchellh/cli"
"github.com/pkg/errors"
"github.com/posener/complete"
@ -37,6 +38,7 @@ type BaseCommand struct {
flagCAPath string
flagClientCert string
flagClientKey string
flagNamespace string
flagTLSServerName string
flagTLSSkipVerify bool
flagWrapTTL time.Duration
@ -118,6 +120,7 @@ func (c *BaseCommand) Client() (*api.Client, error) {
}
client.SetMFACreds(c.flagMFA)
client.SetNamespace(namespace.Canonicalize(c.flagNamespace))
c.client = client
@ -236,6 +239,16 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
"matching the client certificate from -client-cert.",
})
f.StringVar(&StringVar{
Name: "namespace",
Target: &c.flagNamespace,
Default: "",
EnvVar: "VAULT_NAMESPACE",
Completion: complete.PredictAnything,
Usage: "The namespace to use for the command. Setting this is not " +
"necessary but allows using relative paths.",
})
f.StringVar(&StringVar{
Name: "tls-server-name",
Target: &c.flagTLSServerName,

View File

@ -152,6 +152,22 @@ func parseArgsDataString(stdin io.Reader, args []string) (map[string]string, err
return result, nil
}
// parseArgsDataStringLists parses the args data and returns the values as
// string lists. If the values cannot be represented as strings, an error is
// returned.
func parseArgsDataStringLists(stdin io.Reader, args []string) (map[string][]string, error) {
raw, err := parseArgsData(stdin, args)
if err != nil {
return nil, err
}
var result map[string][]string
if err := mapstructure.WeakDecode(raw, &result); err != nil {
return nil, errors.Wrap(err, "failed to convert values to strings")
}
return result, nil
}
// truncateToSeconds truncates the given duration to the number of seconds. If
// the duration is less than 1s, it is returned as 0. The integer represents
// the whole number unit of seconds for the duration.

View File

@ -51,6 +51,7 @@ import (
credToken "github.com/hashicorp/vault/builtin/credential/token"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
physAliCloudOSS "github.com/hashicorp/vault/physical/alicloudoss"
physAzure "github.com/hashicorp/vault/physical/azure"
physCassandra "github.com/hashicorp/vault/physical/cassandra"
physCockroachDB "github.com/hashicorp/vault/physical/cockroachdb"
@ -137,6 +138,7 @@ var (
}
physicalBackends = map[string]physical.Factory{
"alicloudoss": physAliCloudOSS.NewAliCloudOSSBackend,
"azure": physAzure.NewAzureBackend,
"cassandra": physCassandra.NewCassandraBackend,
"cockroachdb": physCockroachDB.NewCockroachDBBackend,
@ -230,6 +232,14 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
}
Commands = map[string]cli.CommandFactory{
"agent": func() (cli.Command, error) {
return &AgentCommand{
BaseCommand: &BaseCommand{
UI: serverCmdUi,
},
ShutdownCh: MakeShutdownCh(),
}, nil
},
"audit": func() (cli.Command, error) {
return &AuditCommand{
BaseCommand: getBaseCommand(),
@ -313,6 +323,31 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
Handlers: loginHandlers,
}, nil
},
"namespace": func() (cli.Command, error) {
return &NamespaceCommand{
BaseCommand: getBaseCommand(),
}, nil
},
"namespace list": func() (cli.Command, error) {
return &NamespaceListCommand{
BaseCommand: getBaseCommand(),
}, nil
},
"namespace lookup": func() (cli.Command, error) {
return &NamespaceLookupCommand{
BaseCommand: getBaseCommand(),
}, nil
},
"namespace create": func() (cli.Command, error) {
return &NamespaceCreateCommand{
BaseCommand: getBaseCommand(),
}, nil
},
"namespace delete": func() (cli.Command, error) {
return &NamespaceDeleteCommand{
BaseCommand: getBaseCommand(),
}, nil
},
"operator": func() (cli.Command, error) {
return &OperatorCommand{
BaseCommand: getBaseCommand(),

View File

@ -182,6 +182,7 @@ var commonCommands = []string{
"delete",
"list",
"login",
"agent",
"server",
"status",
"unwrap",

51
command/namespace.go Normal file
View File

@ -0,0 +1,51 @@
package command
import (
"strings"
"github.com/mitchellh/cli"
)
var _ cli.Command = (*NamespaceCommand)(nil)
type NamespaceCommand struct {
*BaseCommand
}
func (c *NamespaceCommand) Synopsis() string {
return "Interact with namespaces"
}
func (c *NamespaceCommand) Help() string {
helpText := `
Usage: vault namespace <subcommand> [options] [args]
This command groups subcommands for interacting with Vault namespaces.
These set of subcommands operate on the context of the namespace that the
current logged in token belongs to.
List enabled child namespaces:
$ vault namespace list
Look up an existing namespace:
$ vault namespace lookup
Create a new namespace:
$ vault namespace create
Delete an existing namespace:
$ vault namespace delete
Please see the individual subcommand help for detailed usage information.
`
return strings.TrimSpace(helpText)
}
func (c *NamespaceCommand) Run(args []string) int {
return cli.RunResultHelp
}

View File

@ -0,0 +1,92 @@
package command
import (
"fmt"
"path"
"strings"
"github.com/mitchellh/cli"
"github.com/posener/complete"
)
var _ cli.Command = (*NamespaceCreateCommand)(nil)
var _ cli.CommandAutocomplete = (*NamespaceCreateCommand)(nil)
type NamespaceCreateCommand struct {
*BaseCommand
}
func (c *NamespaceCreateCommand) Synopsis() string {
return "Create a new namespace"
}
func (c *NamespaceCreateCommand) Help() string {
helpText := `
Usage: vault namespace create [options] PATH
Create a child namespace. The namespace created will be relative to the
namespace provided in either VAULT_NAMESPACE environemnt variable or
-namespace CLI flag.
Create a child namespace (e.g. ns1/):
$ vault namespace create ns1
Create a child namespace from a parent namespace (e.g. ns1/ns2/):
$ vault namespace create -namespace=ns1 ns2
` + c.Flags().Help()
return strings.TrimSpace(helpText)
}
func (c *NamespaceCreateCommand) Flags() *FlagSets {
return c.flagSet(FlagSetHTTP)
}
func (c *NamespaceCreateCommand) AutocompleteArgs() complete.Predictor {
return c.PredictVaultFolders()
}
func (c *NamespaceCreateCommand) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}
func (c *NamespaceCreateCommand) Run(args []string) int {
f := c.Flags()
if err := f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}
args = f.Args()
switch {
case len(args) < 1:
c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args)))
return 1
case len(args) > 1:
c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args)))
return 1
}
namespacePath := strings.TrimSpace(args[0])
client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())
return 2
}
err = client.Sys().CreateNamespace(namespacePath)
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating namespace: %s", err))
return 2
}
// Output full path
fullPath := path.Join(c.flagNamespace, namespacePath) + "/"
c.UI.Output(fmt.Sprintf("Success! Namespace created at: %s", fullPath))
return 0
}

View File

@ -0,0 +1,92 @@
package command
import (
"fmt"
"path"
"strings"
"github.com/mitchellh/cli"
"github.com/posener/complete"
)
var _ cli.Command = (*NamespaceDeleteCommand)(nil)
var _ cli.CommandAutocomplete = (*NamespaceDeleteCommand)(nil)
type NamespaceDeleteCommand struct {
*BaseCommand
}
func (c *NamespaceDeleteCommand) Synopsis() string {
return "Delete an existing namespace"
}
func (c *NamespaceDeleteCommand) Help() string {
helpText := `
Usage: vault namespace delete [options] PATH
Delete an existing namespace. The namespace deleted will be relative to the
namespace provided in either VAULT_NAMESPACE environemnt variable or
-namespace CLI flag.
Delete a namespace (e.g. ns1/):
$ vault namespace delete ns1
Delete a namespace namespace from a parent namespace (e.g. ns1/ns2/):
$ vault namespace create -namespace=ns1 ns2
` + c.Flags().Help()
return strings.TrimSpace(helpText)
}
func (c *NamespaceDeleteCommand) Flags() *FlagSets {
return c.flagSet(FlagSetHTTP)
}
func (c *NamespaceDeleteCommand) AutocompleteArgs() complete.Predictor {
return c.PredictVaultFolders()
}
func (c *NamespaceDeleteCommand) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}
func (c *NamespaceDeleteCommand) Run(args []string) int {
f := c.Flags()
if err := f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}
args = f.Args()
switch {
case len(args) < 1:
c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args)))
return 1
case len(args) > 1:
c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args)))
return 1
}
namespacePath := strings.TrimSpace(args[0])
client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())
return 2
}
err = client.Sys().DeleteNamespace(namespacePath)
if err != nil {
c.UI.Error(fmt.Sprintf("Error deleting namespace: %s", err))
return 2
}
// Output full path
fullPath := path.Join(c.flagNamespace, namespacePath) + "/"
c.UI.Output(fmt.Sprintf("Success! Namespace deleted at: %s", fullPath))
return 0
}

84
command/namespace_list.go Normal file
View File

@ -0,0 +1,84 @@
package command
import (
"fmt"
"strings"
"github.com/mitchellh/cli"
"github.com/posener/complete"
)
var _ cli.Command = (*NamespaceListCommand)(nil)
var _ cli.CommandAutocomplete = (*NamespaceListCommand)(nil)
type NamespaceListCommand struct {
*BaseCommand
}
func (c *NamespaceListCommand) Synopsis() string {
return "List child namespaces"
}
func (c *NamespaceListCommand) Help() string {
helpText := `
Usage: vault namespaces list [options]
Lists the enabled child namespaces.
List all enabled child namespaces:
$ vault namespaces list
` + c.Flags().Help()
return strings.TrimSpace(helpText)
}
func (c *NamespaceListCommand) Flags() *FlagSets {
return c.flagSet(FlagSetHTTP | FlagSetOutputFormat)
}
func (c *NamespaceListCommand) AutocompleteArgs() complete.Predictor {
return c.PredictVaultFolders()
}
func (c *NamespaceListCommand) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}
func (c *NamespaceListCommand) Run(args []string) int {
f := c.Flags()
if err := f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}
args = f.Args()
if len(args) > 0 {
c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args)))
return 1
}
client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())
return 2
}
namespaces, err := client.Sys().ListNamespaces()
if err != nil {
c.UI.Error(fmt.Sprintf("Error listing namespaces: %s", err))
return 2
}
switch Format(c.UI) {
case "table":
for _, ns := range namespaces.NamespacePaths {
c.UI.Output(ns)
}
return 0
default:
return OutputData(c.UI, namespaces)
}
}

View File

@ -0,0 +1,96 @@
package command
import (
"fmt"
"strings"
"github.com/mitchellh/cli"
"github.com/posener/complete"
)
var _ cli.Command = (*NamespaceLookupCommand)(nil)
var _ cli.CommandAutocomplete = (*NamespaceLookupCommand)(nil)
type NamespaceLookupCommand struct {
*BaseCommand
}
func (c *NamespaceLookupCommand) Synopsis() string {
return "Create a new namespace"
}
func (c *NamespaceLookupCommand) Help() string {
helpText := `
Usage: vault namespace create [options] PATH
Create a child namespace. The namespace created will be relative to the
namespace provided in either VAULT_NAMESPACE environemnt variable or
-namespace CLI flag.
Get information about the namespace of the locally authenticated token:
$ vault namespace lookup
Get information about the namespace of a particular child token (e.g. ns1/ns2/):
$ vault namespace create -namespace=ns1 ns2
` + c.Flags().Help()
return strings.TrimSpace(helpText)
}
func (c *NamespaceLookupCommand) Flags() *FlagSets {
return c.flagSet(FlagSetHTTP | FlagSetOutputFormat)
}
func (c *NamespaceLookupCommand) AutocompleteArgs() complete.Predictor {
return c.PredictVaultFolders()
}
func (c *NamespaceLookupCommand) AutocompleteFlags() complete.Flags {
return c.Flags().Completions()
}
func (c *NamespaceLookupCommand) Run(args []string) int {
f := c.Flags()
if err := f.Parse(args); err != nil {
c.UI.Error(err.Error())
return 1
}
args = f.Args()
switch {
case len(args) < 1:
c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args)))
return 1
case len(args) > 1:
c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args)))
return 1
}
namespacePath := strings.TrimSpace(args[0])
client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())
return 2
}
resp, err := client.Sys().GetNamespace(namespacePath)
if err != nil {
c.UI.Error(fmt.Sprintf("Error looking up namespace: %s", err))
return 2
}
switch Format(c.UI) {
case "table":
data := map[string]interface{}{
"path": resp.Path,
}
return OutputData(c.UI, data)
default:
return OutputData(c.UI, resp)
}
}

View File

@ -2,6 +2,8 @@ package command
import (
"fmt"
"io"
"os"
"strings"
"github.com/mitchellh/cli"
@ -13,6 +15,8 @@ var _ cli.CommandAutocomplete = (*ReadCommand)(nil)
type ReadCommand struct {
*BaseCommand
testStdin io.Reader // for tests
}
func (c *ReadCommand) Synopsis() string {
@ -63,9 +67,6 @@ func (c *ReadCommand) Run(args []string) int {
case len(args) < 1:
c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args)))
return 1
case len(args) > 1:
c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args)))
return 1
}
client, err := c.Client()
@ -74,9 +75,21 @@ func (c *ReadCommand) Run(args []string) int {
return 2
}
// Pull our fake stdin if needed
stdin := (io.Reader)(os.Stdin)
if c.testStdin != nil {
stdin = c.testStdin
}
path := sanitizePath(args[0])
secret, err := client.Logical().Read(path)
data, err := parseArgsDataStringLists(stdin, args[1:])
if err != nil {
c.UI.Error(fmt.Sprintf("Failed to parse K=V data: %s", err))
return 1
}
secret, err := client.Logical().ReadWithData(path, data)
if err != nil {
c.UI.Error(fmt.Sprintf("Error reading %s: %s", path, err))
return 2

View File

@ -34,10 +34,10 @@ func TestReadCommand_Run(t *testing.T) {
1,
},
{
"too_many_args",
[]string{"foo", "bar"},
"Too many arguments",
1,
"proper_args",
[]string{"foo", "bar=baz"},
"No value found at foo\n",
2,
},
{
"not_found",
@ -99,7 +99,7 @@ func TestReadCommand_Run(t *testing.T) {
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, tc.out) {
t.Errorf("expected %q to contain %q", combined, tc.out)
t.Errorf("%s: expected %q to contain %q", tc.name, combined, tc.out)
}
})
}

View File

@ -97,8 +97,9 @@ type ServerCommand struct {
type ServerListener struct {
net.Listener
config map[string]interface{}
maxRequestSize int64
config map[string]interface{}
maxRequestSize int64
maxRequestDuration time.Duration
}
func (c *ServerCommand) Synopsis() string {
@ -395,6 +396,10 @@ func (c *ServerCommand) Run(args []string) int {
return 1
}
if config.DefaultMaxRequestDuration != 0 {
vault.DefaultMaxRequestDuration = config.DefaultMaxRequestDuration
}
// If mlockall(2) isn't supported, show a warning. We disable this in dev
// because it is quite scary to see when first using Vault. We also disable
// this if the user has explicitly disabled mlock in configuration.
@ -738,10 +743,25 @@ CLUSTER_SYNTHESIS_COMPLETE:
}
props["max_request_size"] = fmt.Sprintf("%d", maxRequestSize)
var maxRequestDuration time.Duration = vault.DefaultMaxRequestDuration
if valRaw, ok := lnConfig.Config["max_request_duration"]; ok {
val, err := parseutil.ParseDurationSecond(valRaw)
if err != nil {
c.UI.Error(fmt.Sprintf("Could not parse max_request_duration value %v", valRaw))
return 1
}
if val >= 0 {
maxRequestDuration = val
}
}
props["max_request_duration"] = fmt.Sprintf("%s", maxRequestDuration.String())
lns = append(lns, ServerListener{
Listener: ln,
config: lnConfig.Config,
maxRequestSize: maxRequestSize,
Listener: ln,
config: lnConfig.Config,
maxRequestSize: maxRequestSize,
maxRequestDuration: maxRequestDuration,
})
// Store the listener props for output later
@ -939,6 +959,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
handler := vaulthttp.Handler(&vault.HandlerProperties{
Core: core,
MaxRequestSize: ln.maxRequestSize,
MaxRequestDuration: ln.maxRequestDuration,
DisablePrintableCheck: config.DisablePrintableCheck,
})
@ -1113,7 +1134,7 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig
"no_default_policy": true,
},
}
resp, err := core.HandleRequest(req)
resp, err := core.HandleRequest(context.Background(), req)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("failed to create root token with ID %q: {{err}}", coreConfig.DevToken), err)
}
@ -1129,7 +1150,7 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig
req.ID = "dev-revoke-init-root"
req.Path = "auth/token/revoke-self"
req.Data = nil
resp, err = core.HandleRequest(req)
resp, err = core.HandleRequest(context.Background(), req)
if err != nil {
return nil, errwrap.Wrapf("failed to revoke initial root token: {{err}}", err)
}
@ -1156,7 +1177,7 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig
},
},
}
resp, err := core.HandleRequest(req)
resp, err := core.HandleRequest(context.Background(), req)
if err != nil {
return nil, errwrap.Wrapf("error upgrading default K/V store: {{err}}", err)
}
@ -1233,7 +1254,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
"no_default_policy": true,
},
}
resp, err := testCluster.Cores[0].HandleRequest(req)
resp, err := testCluster.Cores[0].HandleRequest(context.Background(), req)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to create root token with ID %s: %s", base.DevToken, err))
return 1
@ -1252,7 +1273,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m
req.ID = "dev-revoke-init-root"
req.Path = "auth/token/revoke-self"
req.Data = nil
resp, err = testCluster.Cores[0].HandleRequest(req)
resp, err = testCluster.Cores[0].HandleRequest(context.Background(), req)
if err != nil {
c.UI.Output(fmt.Sprintf("failed to revoke initial root token: %s", err))
return 1
@ -1385,7 +1406,7 @@ func (c *ServerCommand) addPlugin(path, token string, core *vault.Core) error {
"command": name,
},
}
if _, err := core.HandleRequest(req); err != nil {
if _, err := core.HandleRequest(context.Background(), req); err != nil {
return err
}

View File

@ -16,7 +16,6 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/ast"
"github.com/hashicorp/vault/helper/hclutil"
"github.com/hashicorp/vault/helper/parseutil"
)
@ -46,6 +45,9 @@ type Config struct {
DefaultLeaseTTL time.Duration `hcl:"-"`
DefaultLeaseTTLRaw interface{} `hcl:"default_lease_ttl"`
DefaultMaxRequestDuration time.Duration `hcl:"-"`
DefaultMaxRequestDurationRaw interface{} `hcl:"default_max_request_time"`
ClusterName string `hcl:"cluster_name"`
ClusterCipherSuites string `hcl:"cluster_cipher_suites"`
@ -289,6 +291,11 @@ func (c *Config) Merge(c2 *Config) *Config {
result.DefaultLeaseTTL = c2.DefaultLeaseTTL
}
result.DefaultMaxRequestDuration = c.DefaultMaxRequestDuration
if c2.DefaultMaxRequestDuration > result.DefaultMaxRequestDuration {
result.DefaultMaxRequestDuration = c2.DefaultMaxRequestDuration
}
result.ClusterName = c.ClusterName
if c2.ClusterName != "" {
result.ClusterName = c2.ClusterName
@ -375,6 +382,12 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) {
}
}
if result.DefaultMaxRequestDurationRaw != nil {
if result.DefaultMaxRequestDuration, err = parseutil.ParseDurationSecond(result.DefaultMaxRequestDurationRaw); err != nil {
return nil, err
}
}
if result.EnableUIRaw != nil {
if result.EnableUI, err = parseutil.ParseBool(result.EnableUIRaw); err != nil {
return nil, err
@ -422,36 +435,6 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) {
return nil, fmt.Errorf("error parsing: file doesn't contain a root object")
}
valid := []string{
"storage",
"ha_storage",
"backend",
"ha_backend",
"hsm",
"seal",
"listener",
"cache_size",
"disable_cache",
"disable_mlock",
"disable_printable_check",
"ui",
"telemetry",
"default_lease_ttl",
"max_lease_ttl",
"cluster_name",
"cluster_cipher_suites",
"plugin_directory",
"pid_file",
"raw_storage_endpoint",
"api_addr",
"cluster_addr",
"disable_clustering",
"disable_sealwrap",
}
if err := hclutil.CheckHCLKeys(list, valid); err != nil {
return nil, err
}
// Look for storage but still support old backend
if o := list.Filter("storage"); len(o.Items) > 0 {
if err := parseStorage(&result, o, "storage"); err != nil {
@ -728,61 +711,16 @@ func parseSeal(result *Config, list *ast.ObjectList, blockName string) error {
key = item.Keys[0].Token.Value().(string)
}
var valid []string
// Valid parameter for the Seal types
switch key {
case "pkcs11":
valid = []string{
"lib",
"slot",
"token_label",
"pin",
"mechanism",
"hmac_mechanism",
"key_label",
"default_key_label",
"hmac_key_label",
"hmac_default_key_label",
"generate_key",
"regenerate_key",
"max_parallel",
"disable_auto_reinit_on_error",
"rsa_encrypt_local",
"rsa_oaep_hash",
}
case "awskms":
valid = []string{
"region",
"access_key",
"secret_key",
"kms_key_id",
"max_parallel",
}
case "gcpckms":
valid = []string{
"credentials",
"project",
"region",
"key_ring",
"crypto_key",
}
case "azurekeyvault":
valid = []string{
"tenant_id",
"client_id",
"client_secret",
"environment",
"vault_name",
"key_name",
}
default:
return fmt.Errorf("invalid seal type %q", key)
}
if err := hclutil.CheckHCLKeys(item.Val, valid); err != nil {
return multierror.Prefix(err, fmt.Sprintf("%s.%s:", blockName, key))
}
var m map[string]string
if err := hcl.DecodeObject(&m, item.Val); err != nil {
return multierror.Prefix(err, fmt.Sprintf("%s.%s:", blockName, key))
@ -804,34 +742,6 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
key = item.Keys[0].Token.Value().(string)
}
valid := []string{
"address",
"cluster_address",
"endpoint",
"x_forwarded_for_authorized_addrs",
"x_forwarded_for_hop_skips",
"x_forwarded_for_reject_not_authorized",
"x_forwarded_for_reject_not_present",
"infrastructure",
"max_request_size",
"node_id",
"proxy_protocol_behavior",
"proxy_protocol_authorized_addrs",
"tls_disable",
"tls_cert_file",
"tls_key_file",
"tls_min_version",
"tls_cipher_suites",
"tls_prefer_server_cipher_suites",
"tls_require_and_verify_client_cert",
"tls_disable_client_certs",
"tls_client_ca_file",
"token",
}
if err := hclutil.CheckHCLKeys(item.Val, valid); err != nil {
return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key))
}
var m map[string]interface{}
if err := hcl.DecodeObject(&m, item.Val); err != nil {
return multierror.Prefix(err, fmt.Sprintf("listeners.%s:", key))
@ -857,31 +767,6 @@ func parseTelemetry(result *Config, list *ast.ObjectList) error {
// Get our one item
item := list.Items[0]
// Check for invalid keys
valid := []string{
"circonus_api_token",
"circonus_api_app",
"circonus_api_url",
"circonus_submission_interval",
"circonus_submission_url",
"circonus_check_id",
"circonus_check_force_metric_activation",
"circonus_check_instance_id",
"circonus_check_search_tag",
"circonus_check_display_name",
"circonus_check_tags",
"circonus_broker_id",
"circonus_broker_select_tag",
"disable_hostname",
"dogstatsd_addr",
"dogstatsd_tags",
"statsd_address",
"statsite_address",
}
if err := hclutil.CheckHCLKeys(item.Val, valid); err != nil {
return multierror.Prefix(err, "telemetry:")
}
var t Telemetry
if err := hcl.DecodeObject(&t, item.Val); err != nil {
return multierror.Prefix(err, "telemetry:")

View File

@ -383,73 +383,3 @@ listener "tcp" {
}
}
func TestParseConfig_badTopLevel(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)
_, err := ParseConfig(strings.TrimSpace(`
backend {}
bad = "one"
nope = "yes"
`), logger)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), `invalid key "bad" on line 2`) {
t.Errorf("bad error: %q", err)
}
if !strings.Contains(err.Error(), `invalid key "nope" on line 3`) {
t.Errorf("bad error: %q", err)
}
}
func TestParseConfig_badListener(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)
_, err := ParseConfig(strings.TrimSpace(`
listener "tcp" {
address = "1.2.3.3"
bad = "one"
nope = "yes"
}
`), logger)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), `listeners.tcp: invalid key "bad" on line 3`) {
t.Errorf("bad error: %q", err)
}
if !strings.Contains(err.Error(), `listeners.tcp: invalid key "nope" on line 4`) {
t.Errorf("bad error: %q", err)
}
}
func TestParseConfig_badTelemetry(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)
_, err := ParseConfig(strings.TrimSpace(`
telemetry {
statsd_address = "1.2.3.3"
bad = "one"
nope = "yes"
}
`), logger)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), `telemetry: invalid key "bad" on line 3`) {
t.Errorf("bad error: %q", err)
}
if !strings.Contains(err.Error(), `telemetry: invalid key "nope" on line 4`) {
t.Errorf("bad error: %q", err)
}
}

51
helper/base62/base62.go Normal file
View File

@ -0,0 +1,51 @@
// Package base62 provides utilities for working with base62 strings.
// base62 strings will only contain characters: 0-9, a-z, A-Z
package base62
import (
"math/big"
uuid "github.com/hashicorp/go-uuid"
)
// Encode converts buf into a base62 string
//
// Note: this should only be used for reducing a string's character set range.
// It is not for use with arbitrary data since leading 0 bytes will be dropped.
func Encode(buf []byte) string {
var encoder big.Int
encoder.SetBytes(buf)
return encoder.Text(62)
}
// Decode converts input from base62 to its byte representation
// If the decoding fails, an empty slice is returned.
func Decode(input string) []byte {
var decoder big.Int
decoder.SetString(input, 62)
return decoder.Bytes()
}
// Random generates a random base62-encoded string.
// If truncate is true, the result will be a string of the requested length.
// Otherwise, it will be the encoded result of length bytes of random data.
func Random(length int, truncate bool) (string, error) {
for {
buf, err := uuid.GenerateRandomBytes(length)
if err != nil {
return "", err
}
result := Encode(buf)
if truncate {
if len(result) < length {
continue
}
result = result[:length]
}
return result, nil
}
}

View File

@ -0,0 +1,110 @@
package base62
import (
"testing"
)
func TestValid(t *testing.T) {
tCases := []struct {
in string
out string
}{
{
"",
"0",
},
{
"foo",
"sapp",
},
{
"5d5746d044b9a9429249966c9e3fee178ca679b91487b11d4b73c9865202104c",
"cozMP2pOYdDiNGeFQ2afKAOGIzO0HVpJ8OPFXuVPNbHasFyenK9CzIIPuOG7EFWOCy4YWvKGZa671N4kRSoaxZ",
},
{
"5ba33e16d742f3c785f6e7e8bb6f5fe82346ffa1c47aa8e95da4ddd5a55bb334",
"cotpEJPnhuTRofLi4lDe5iKw2fkSGc6TpUYeuWoBp8eLYJBWLRUVDZI414OjOCWXKZ0AI8gqNMoxd4eLOklwYk",
},
{
" ",
"w",
},
{
"-",
"J",
},
{
"0",
"M",
},
{
"1",
"N",
},
{
"-1",
"30B",
},
{
"11",
"3h7",
},
{
"abc",
"qMin",
},
{
"1234598760",
"1a0AFzKIPnihTq",
},
{
"abcdefghijklmnopqrstuvwxyz",
"hUBXsgd3F2swSlEgbVi2p0Ncr6kzVeJTLaW",
},
}
for _, c := range tCases {
e := Encode([]byte(c.in))
d := string(Decode(e))
if d != c.in {
t.Fatalf("decoded value didn't match input %#v %#v", c.in, d)
}
if e != c.out {
t.Fatalf("encoded value didn't match expected %#v, %#v", e, c.out)
}
}
}
func TestInvalid(t *testing.T) {
d := Decode("!0000/")
if len(d) != 0 {
t.Fatalf("Decode of invalid string should be empty, got %#v", d)
}
}
func TestRandom(t *testing.T) {
a, err1 := Random(16, true)
b, err2 := Random(16, true)
if err1 != nil || err2 != nil {
t.Fatalf("Unexpected errors: %v, %v", err1, err2)
}
if a == b {
t.Fatalf("Expected different random values. Got duplicate: %s", a)
}
for i := 0; i < 3000; i++ {
c, _ := Random(i, true)
if len(c) != i {
t.Fatalf("Expected length %d, got: %d", i, len(c))
}
}
d, _ := Random(100, false)
if len(d) < 133 || len(d) > 135 {
t.Fatalf("Expected length 133-135, got: %d", len(d))
}
}

121
helper/dhutil/dhutil.go Normal file
View File

@ -0,0 +1,121 @@
package dhutil
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"
"golang.org/x/crypto/curve25519"
)
type PublicKeyInfo struct {
Curve25519PublicKey []byte `json:"curve25519_public_key"`
}
type Envelope struct {
Curve25519PublicKey []byte `json:"curve25519_public_key"`
Nonce []byte `json:"nonce"`
EncryptedPayload []byte `json:"encrypted_payload"`
}
// generatePublicPrivateKey uses curve25519 to generate a public and private key
// pair.
func GeneratePublicPrivateKey() ([]byte, []byte, error) {
var scalar, public [32]byte
if _, err := io.ReadFull(rand.Reader, scalar[:]); err != nil {
return nil, nil, err
}
curve25519.ScalarBaseMult(&public, &scalar)
return public[:], scalar[:], nil
}
// generateSharedKey uses the private key and the other party's public key to
// generate the shared secret.
func GenerateSharedKey(ourPrivate, theirPublic []byte) ([]byte, error) {
if len(ourPrivate) != 32 {
return nil, fmt.Errorf("invalid private key length: %d", len(ourPrivate))
}
if len(theirPublic) != 32 {
return nil, fmt.Errorf("invalid public key length: %d", len(theirPublic))
}
var scalar, pub, secret [32]byte
copy(scalar[:], ourPrivate)
copy(pub[:], theirPublic)
curve25519.ScalarMult(&secret, &scalar, &pub)
return secret[:], nil
}
// Use AES256-GCM to encrypt some plaintext with a provided key. The returned values are
// the ciphertext, the nonce, and error respectively.
func EncryptAES(key, plaintext, aad []byte) ([]byte, []byte, error) {
// We enforce AES-256, so check explicitly for 32 bytes on the key
if len(key) != 32 {
return nil, nil, fmt.Errorf("invalid key length: %d", len(key))
}
if len(plaintext) == 0 {
return nil, nil, errors.New("empty plaintext provided")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
// Never use more than 2^32 random nonces with a given key because of the risk of a repeat.
nonce := make([]byte, 12)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, nil, err
}
ciphertext := aesgcm.Seal(nil, nonce, plaintext, aad)
return ciphertext, nonce, nil
}
// Use AES256-GCM to decrypt some ciphertext with a provided key and nonce. The
// returned values are the plaintext and error respectively.
func DecryptAES(key, ciphertext, nonce, aad []byte) ([]byte, error) {
// We enforce AES-256, so check explicitly for 32 bytes on the key
if len(key) != 32 {
return nil, fmt.Errorf("invalid key length: %d", len(key))
}
if len(ciphertext) == 0 {
return nil, errors.New("empty ciphertext provided")
}
if len(nonce) == 0 {
return nil, errors.New("empty nonce provided")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
plaintext, err := aesgcm.Open(nil, nonce, ciphertext, aad)
if err != nil {
return nil, err
}
return plaintext, nil
}

View File

@ -0,0 +1,204 @@
package identity
import (
"errors"
"strings"
)
var (
ErrUnbalancedTemplatingCharacter = errors.New("unbalanced templating characters")
ErrNoEntityAttachedToToken = errors.New("string contains entity template directives but no entity was provided")
ErrNoGroupsAttachedToToken = errors.New("string contains groups template directives but no groups were provided")
ErrTemplateValueNotFound = errors.New("no value could be found for one of the template directives")
)
type PopulateStringInput struct {
ValidityCheckOnly bool
String string
Entity *Entity
Groups []*Group
}
func PopulateString(p *PopulateStringInput) (bool, string, error) {
if p == nil {
return false, "", errors.New("nil input")
}
if p.String == "" {
return false, "", nil
}
var subst bool
splitStr := strings.Split(p.String, "{{")
if len(splitStr) >= 1 {
if strings.Index(splitStr[0], "}}") != -1 {
return false, "", ErrUnbalancedTemplatingCharacter
}
if len(splitStr) == 1 {
return false, p.String, nil
}
}
var b strings.Builder
if !p.ValidityCheckOnly {
b.Grow(2 * len(p.String))
}
for i, str := range splitStr {
if i == 0 {
if !p.ValidityCheckOnly {
b.WriteString(str)
}
continue
}
splitPiece := strings.Split(str, "}}")
switch len(splitPiece) {
case 2:
subst = true
if !p.ValidityCheckOnly {
tmplStr, err := performTemplating(strings.TrimSpace(splitPiece[0]), p.Entity, p.Groups)
if err != nil {
return false, "", err
}
b.WriteString(tmplStr)
b.WriteString(splitPiece[1])
}
default:
return false, "", ErrUnbalancedTemplatingCharacter
}
}
return subst, b.String(), nil
}
func performTemplating(input string, entity *Entity, groups []*Group) (string, error) {
performAliasTemplating := func(trimmed string, alias *Alias) (string, error) {
switch {
case trimmed == "id":
return alias.ID, nil
case trimmed == "name":
if alias.Name == "" {
return "", ErrTemplateValueNotFound
}
return alias.Name, nil
case strings.HasPrefix(trimmed, "metadata."):
val, ok := alias.Metadata[strings.TrimPrefix(trimmed, "metadata.")]
if !ok {
return "", ErrTemplateValueNotFound
}
return val, nil
}
return "", ErrTemplateValueNotFound
}
performEntityTemplating := func(trimmed string) (string, error) {
switch {
case trimmed == "id":
return entity.ID, nil
case trimmed == "name":
if entity.Name == "" {
return "", ErrTemplateValueNotFound
}
return entity.Name, nil
case strings.HasPrefix(trimmed, "metadata."):
val, ok := entity.Metadata[strings.TrimPrefix(trimmed, "metadata.")]
if !ok {
return "", ErrTemplateValueNotFound
}
return val, nil
case strings.HasPrefix(trimmed, "aliases."):
split := strings.SplitN(strings.TrimPrefix(trimmed, "aliases."), ".", 2)
if len(split) != 2 {
return "", errors.New("invalid alias selector")
}
var found *Alias
for _, alias := range entity.Aliases {
if split[0] == alias.MountAccessor {
found = alias
break
}
}
if found == nil {
return "", errors.New("alias not found")
}
return performAliasTemplating(split[1], found)
}
return "", ErrTemplateValueNotFound
}
performGroupsTemplating := func(trimmed string) (string, error) {
var ids bool
selectorSplit := strings.SplitN(trimmed, ".", 2)
switch {
case len(selectorSplit) != 2:
return "", errors.New("invalid groups selector")
case selectorSplit[0] == "ids":
ids = true
case selectorSplit[0] == "names":
default:
return "", errors.New("invalid groups selector")
}
trimmed = selectorSplit[1]
accessorSplit := strings.SplitN(trimmed, ".", 2)
if len(accessorSplit) != 2 {
return "", errors.New("invalid groups accessor")
}
var found *Group
for _, group := range groups {
compare := group.Name
if ids {
compare = group.ID
}
if compare == accessorSplit[0] {
found = group
break
}
}
if found == nil {
return "", errors.New("group not found")
}
trimmed = accessorSplit[1]
switch {
case trimmed == "id":
return found.ID, nil
case trimmed == "name":
if found.Name == "" {
return "", ErrTemplateValueNotFound
}
return found.Name, nil
case strings.HasPrefix(trimmed, "metadata."):
val, ok := found.Metadata[strings.TrimPrefix(trimmed, "metadata.")]
if !ok {
return "", ErrTemplateValueNotFound
}
return val, nil
}
return "", ErrTemplateValueNotFound
}
switch {
case strings.HasPrefix(input, "identity.entity."):
if entity == nil {
return "", ErrNoEntityAttachedToToken
}
return performEntityTemplating(strings.TrimPrefix(input, "identity.entity."))
case strings.HasPrefix(input, "identity.groups."):
if len(groups) == 0 {
return "", ErrNoGroupsAttachedToToken
}
return performGroupsTemplating(strings.TrimPrefix(input, "identity.groups."))
}
return "", ErrTemplateValueNotFound
}

View File

@ -0,0 +1,194 @@
package identity
import (
"errors"
"testing"
)
func TestPopulate_Basic(t *testing.T) {
var tests = []struct {
name string
input string
output string
err error
entityName string
metadata map[string]string
aliasAccessor string
aliasID string
aliasName string
nilEntity bool
validityCheckOnly bool
aliasMetadata map[string]string
groupName string
groupMetadata map[string]string
}{
{
name: "no_templating",
input: "path foobar {",
output: "path foobar {",
},
{
name: "only_closing",
input: "path foobar}} {",
err: ErrUnbalancedTemplatingCharacter,
},
{
name: "closing_in_front",
input: "path }} {{foobar}} {",
err: ErrUnbalancedTemplatingCharacter,
},
{
name: "closing_in_back",
input: "path {{foobar}} }}",
err: ErrUnbalancedTemplatingCharacter,
},
{
name: "basic",
input: "path /{{identity.entity.id}}/ {",
output: "path /entityID/ {",
},
{
name: "multiple",
input: "path {{identity.entity.name}} {\n\tval = {{identity.entity.metadata.foo}}\n}",
entityName: "entityName",
metadata: map[string]string{"foo": "bar"},
output: "path entityName {\n\tval = bar\n}",
},
{
name: "multiple_bad_name",
input: "path {{identity.entity.name}} {\n\tval = {{identity.entity.metadata.foo}}\n}",
metadata: map[string]string{"foo": "bar"},
err: ErrTemplateValueNotFound,
},
{
name: "unbalanced_close",
input: "path {{identity.entity.id}} {\n\tval = {{ent}}ity.metadata.foo}}\n}",
err: ErrUnbalancedTemplatingCharacter,
},
{
name: "unbalanced_open",
input: "path {{identity.entity.id}} {\n\tval = {{ent{{ity.metadata.foo}}\n}",
err: ErrUnbalancedTemplatingCharacter,
},
{
name: "no_entity_no_directives",
input: "path {{identity.entity.id}} {\n\tval = {{ent{{ity.metadata.foo}}\n}",
err: ErrNoEntityAttachedToToken,
nilEntity: true,
},
{
name: "no_entity_no_diretives",
input: "path name {\n\tval = foo\n}",
output: "path name {\n\tval = foo\n}",
nilEntity: true,
},
{
name: "alias_id_name",
input: "path {{ identity.entity.name}} {\n\tval = {{identity.entity.aliases.foomount.id}}\n}",
entityName: "entityName",
aliasAccessor: "foomount",
aliasID: "aliasID",
metadata: map[string]string{"foo": "bar"},
output: "path entityName {\n\tval = aliasID\n}",
},
{
name: "alias_id_name_bad_selector",
input: "path foobar {\n\tval = {{identity.entity.aliases.foomount}}\n}",
aliasAccessor: "foomount",
err: errors.New("invalid alias selector"),
},
{
name: "alias_id_name_bad_accessor",
input: "path \"foobar\" {\n\tval = {{identity.entity.aliases.barmount.id}}\n}",
aliasAccessor: "foomount",
err: errors.New("alias not found"),
},
{
name: "alias_id_name",
input: "path \"{{identity.entity.name}}\" {\n\tval = {{identity.entity.aliases.foomount.metadata.zip}}\n}",
entityName: "entityName",
aliasAccessor: "foomount",
aliasID: "aliasID",
metadata: map[string]string{"foo": "bar"},
aliasMetadata: map[string]string{"zip": "zap"},
output: "path \"entityName\" {\n\tval = zap\n}",
},
{
name: "group_name",
input: "path \"{{identity.groups.ids.groupID.name}}\" {\n\tval = {{identity.entity.name}}\n}",
entityName: "entityName",
groupName: "groupName",
output: "path \"groupName\" {\n\tval = entityName\n}",
},
{
name: "group_bad_id",
input: "path \"{{identity.groups.ids.hroupID.name}}\" {\n\tval = {{identity.entity.name}}\n}",
entityName: "entityName",
groupName: "groupName",
err: errors.New("group not found"),
},
{
name: "group_id",
input: "path \"{{identity.groups.names.groupName.id}}\" {\n\tval = {{identity.entity.name}}\n}",
entityName: "entityName",
groupName: "groupName",
output: "path \"groupID\" {\n\tval = entityName\n}",
},
{
name: "group_bad_name",
input: "path \"{{identity.groups.names.hroupName.id}}\" {\n\tval = {{identity.entity.name}}\n}",
entityName: "entityName",
groupName: "groupName",
err: errors.New("group not found"),
},
}
for _, test := range tests {
var entity *Entity
if !test.nilEntity {
entity = &Entity{
ID: "entityID",
Name: test.entityName,
Metadata: test.metadata,
}
}
if test.aliasAccessor != "" {
entity.Aliases = []*Alias{
&Alias{
MountAccessor: test.aliasAccessor,
ID: test.aliasID,
Name: test.aliasName,
Metadata: test.aliasMetadata,
},
}
}
var groups []*Group
if test.groupName != "" {
groups = append(groups, &Group{
ID: "groupID",
Name: test.groupName,
Metadata: test.groupMetadata,
})
}
subst, out, err := PopulateString(&PopulateStringInput{
ValidityCheckOnly: test.validityCheckOnly,
String: test.input,
Entity: entity,
Groups: groups,
})
if err != nil {
if test.err == nil {
t.Fatalf("%s: expected success, got error: %v", test.name, err)
}
if err.Error() != test.err.Error() {
t.Fatalf("%s: got error: %v", test.name, err)
}
}
if out != test.output {
t.Fatalf("%s: bad output: %s", test.name, out)
}
if err == nil && !subst && out != test.input {
t.Fatalf("%s: bad subst flag", test.name)
}
}
}

View File

@ -66,7 +66,7 @@ func (m *Group) Reset() { *m = Group{} }
func (m *Group) String() string { return proto.CompactTextString(m) }
func (*Group) ProtoMessage() {}
func (*Group) Descriptor() ([]byte, []int) {
return fileDescriptor_types_01b7fd3cfabd028f, []int{0}
return fileDescriptor_types_0360db4a8e77dd9b, []int{0}
}
func (m *Group) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Group.Unmarshal(m, b)
@ -223,7 +223,7 @@ func (m *Entity) Reset() { *m = Entity{} }
func (m *Entity) String() string { return proto.CompactTextString(m) }
func (*Entity) ProtoMessage() {}
func (*Entity) Descriptor() ([]byte, []int) {
return fileDescriptor_types_01b7fd3cfabd028f, []int{1}
return fileDescriptor_types_0360db4a8e77dd9b, []int{1}
}
func (m *Entity) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Entity.Unmarshal(m, b)
@ -358,7 +358,7 @@ func (m *Alias) Reset() { *m = Alias{} }
func (m *Alias) String() string { return proto.CompactTextString(m) }
func (*Alias) ProtoMessage() {}
func (*Alias) Descriptor() ([]byte, []int) {
return fileDescriptor_types_01b7fd3cfabd028f, []int{2}
return fileDescriptor_types_0360db4a8e77dd9b, []int{2}
}
func (m *Alias) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Alias.Unmarshal(m, b)
@ -457,9 +457,9 @@ func init() {
proto.RegisterMapType((map[string]string)(nil), "identity.Alias.MetadataEntry")
}
func init() { proto.RegisterFile("helper/identity/types.proto", fileDescriptor_types_01b7fd3cfabd028f) }
func init() { proto.RegisterFile("helper/identity/types.proto", fileDescriptor_types_0360db4a8e77dd9b) }
var fileDescriptor_types_01b7fd3cfabd028f = []byte{
var fileDescriptor_types_0360db4a8e77dd9b = []byte{
// 656 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xac, 0x95, 0x5d, 0x6f, 0xd3, 0x3c,
0x14, 0xc7, 0xd5, 0xa6, 0x2f, 0xe9, 0x69, 0xd7, 0xed, 0xb1, 0x1e, 0xa1, 0x50, 0x34, 0xe8, 0x26,

View File

@ -55,6 +55,12 @@ message Group {
// the memberships on the external group --for which a corresponding alias
// will be set-- will be managed automatically.
string type = 12;
// **Enterprise only**
// NamespaceID is the identifier of the namespace to which this group
// belongs to. Do not return this value over the API when reading the
// group.
//string namespace_id = 13;
}
@ -116,6 +122,12 @@ message Entity {
// Disabled indicates whether tokens associated with the account should not
// be able to be used
bool disabled = 11;
// **Enterprise only**
// NamespaceID is the identifier of the namespace to which this entity
// belongs to. Do not return this value over the API when reading the
// entity.
//string namespace_id = 12;
}
// Alias represents the alias that gets stored inside of the

View File

@ -4,12 +4,12 @@ import (
"context"
"encoding/base64"
"errors"
"math/big"
paths "path"
"sort"
"strings"
"github.com/hashicorp/golang-lru"
"github.com/hashicorp/vault/helper/base62"
"github.com/hashicorp/vault/logical"
)
@ -174,7 +174,7 @@ func (s *encryptedKeyStorage) List(ctx context.Context, prefix string) ([]string
k = strings.TrimSuffix(k, "/")
}
decoded := Base62Decode(k)
decoded := base62.Decode(k)
if len(decoded) == 0 {
return nil, errors.New("could not decode key")
}
@ -268,23 +268,9 @@ func (s *encryptedKeyStorage) encryptPath(path string) (string, error) {
return "", err
}
encPath = paths.Join(encPath, Base62Encode([]byte(ciphertext)))
encPath = paths.Join(encPath, base62.Encode([]byte(ciphertext)))
context = paths.Join(context, p)
}
return encPath, nil
}
func Base62Encode(buf []byte) string {
encoder := &big.Int{}
encoder.SetBytes(buf)
return encoder.Text(62)
}
func Base62Decode(input string) []byte {
decoder := &big.Int{}
decoder.SetString(input, 62)
return decoder.Bytes()
}

View File

@ -12,84 +12,6 @@ import (
var compilerOpt []string
func TestBase58(t *testing.T) {
tCases := []struct {
in string
out string
}{
{
"",
"0",
},
{
"foo",
"sapp",
},
{
"5d5746d044b9a9429249966c9e3fee178ca679b91487b11d4b73c9865202104c",
"cozMP2pOYdDiNGeFQ2afKAOGIzO0HVpJ8OPFXuVPNbHasFyenK9CzIIPuOG7EFWOCy4YWvKGZa671N4kRSoaxZ",
},
{
"5ba33e16d742f3c785f6e7e8bb6f5fe82346ffa1c47aa8e95da4ddd5a55bb334",
"cotpEJPnhuTRofLi4lDe5iKw2fkSGc6TpUYeuWoBp8eLYJBWLRUVDZI414OjOCWXKZ0AI8gqNMoxd4eLOklwYk",
},
{
" ",
"w",
},
{
"-",
"J",
},
{
"0",
"M",
},
{
"1",
"N",
},
{
"-1",
"30B",
},
{
"11",
"3h7",
},
{
"abc",
"qMin",
},
{
"1234598760",
"1a0AFzKIPnihTq",
},
{
"abcdefghijklmnopqrstuvwxyz",
"hUBXsgd3F2swSlEgbVi2p0Ncr6kzVeJTLaW",
},
}
for _, c := range tCases {
e := Base62Encode([]byte(c.in))
d := string(Base62Decode(e))
if d != c.in {
t.Fatalf("decoded value didn't match input %#v %#v", c.in, d)
}
if e != c.out {
t.Fatalf("encoded value didn't match expected %#v, %#v", e, c.out)
}
}
d := Base62Decode("!0000/")
if len(d) != 0 {
t.Fatalf("Decode of invalid string should be empty, got %#v", d)
}
}
func TestEncrytedKeysStorage_BadPolicy(t *testing.T) {
policy := NewPolicy(PolicyConfig{
Name: "metadata",

View File

@ -4,6 +4,7 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"fmt"
"math"
"net"
@ -118,7 +119,7 @@ func (c *Client) GetUserBindDN(cfg *ConfigEntry, conn Connection, username strin
}
result, err := conn.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Scope: ldap.ScopeWholeSubtree,
Filter: filter,
SizeLimit: math.MaxInt32,
})
@ -153,7 +154,7 @@ func (c *Client) GetUserDN(cfg *ConfigEntry, conn Connection, bindDN string) (st
}
result, err := conn.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Scope: ldap.ScopeWholeSubtree,
Filter: filter,
SizeLimit: math.MaxInt32,
})
@ -170,36 +171,15 @@ func (c *Client) GetUserDN(cfg *ConfigEntry, conn Connection, bindDN string) (st
return userDN, nil
}
/*
* getLdapGroups queries LDAP and returns a slice describing the set of groups the authenticated user is a member of.
*
* The search query is constructed according to cfg.GroupFilter, and run in context of cfg.GroupDN.
* Groups will be resolved from the query results by following the attribute defined in cfg.GroupAttr.
*
* cfg.GroupFilter is a go template and is compiled with the following context: [UserDN, Username]
* UserDN - The DN of the authenticated user
* Username - The Username of the authenticated user
*
* Example:
* cfg.GroupFilter = "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
* cfg.GroupDN = "OU=Groups,DC=myorg,DC=com"
* cfg.GroupAttr = "cn"
*
* NOTE - If cfg.GroupFilter is empty, no query is performed and an empty result slice is returned.
*
*/
func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]string, error) {
// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
ldapMap := make(map[string]bool)
func (c *Client) performLdapFilterGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]*ldap.Entry, error) {
if cfg.GroupFilter == "" {
c.Logger.Warn("groupfilter is empty, will not query server")
return make([]string, 0), nil
return make([]*ldap.Entry, 0), nil
}
if cfg.GroupDN == "" {
c.Logger.Warn("groupdn is empty, will not query server")
return make([]string, 0), nil
return make([]*ldap.Entry, 0), nil
}
// If groupfilter was defined, resolve it as a Go template and use the query for
@ -233,7 +213,7 @@ func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string,
result, err := conn.Search(&ldap.SearchRequest{
BaseDN: cfg.GroupDN,
Scope: 2, // subtree
Scope: ldap.ScopeWholeSubtree,
Filter: renderedQuery.String(),
Attributes: []string{
cfg.GroupAttr,
@ -244,7 +224,130 @@ func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string,
return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err)
}
for _, e := range result.Entries {
return result.Entries, nil
}
func sidBytesToString(b []byte) (string, error) {
reader := bytes.NewReader(b)
var revision, subAuthorityCount uint8
var identifierAuthorityParts [3]uint16
if err := binary.Read(reader, binary.LittleEndian, &revision); err != nil {
return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading Revision: {{err}}", b), err)
}
if err := binary.Read(reader, binary.LittleEndian, &subAuthorityCount); err != nil {
return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading SubAuthorityCount: {{err}}", b), err)
}
if err := binary.Read(reader, binary.BigEndian, &identifierAuthorityParts); err != nil {
return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading IdentifierAuthority: {{err}}", b), err)
}
identifierAuthority := (uint64(identifierAuthorityParts[0]) << 32) + (uint64(identifierAuthorityParts[1]) << 16) + uint64(identifierAuthorityParts[2])
subAuthority := make([]uint32, subAuthorityCount)
if err := binary.Read(reader, binary.LittleEndian, &subAuthority); err != nil {
return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading SubAuthority: {{err}}", b), err)
}
result := fmt.Sprintf("S-%d-%d", revision, identifierAuthority)
for _, subAuthorityPart := range subAuthority {
result += fmt.Sprintf("-%d", subAuthorityPart)
}
return result, nil
}
func (c *Client) performLdapTokenGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string) ([]*ldap.Entry, error) {
result, err := conn.Search(&ldap.SearchRequest{
BaseDN: userDN,
Scope: ldap.ScopeBaseObject,
Filter: "(objectClass=*)",
Attributes: []string{
"tokenGroups",
},
SizeLimit: 1,
})
if err != nil {
return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err)
}
if len(result.Entries) == 0 {
c.Logger.Warn("unable to read object for group attributes", "userdn", userDN, "groupattr", cfg.GroupAttr)
return make([]*ldap.Entry, 0), nil
}
userEntry := result.Entries[0]
groupAttrValues := userEntry.GetRawAttributeValues("tokenGroups")
groupEntries := make([]*ldap.Entry, 0, len(groupAttrValues))
for _, sidBytes := range groupAttrValues {
sidString, err := sidBytesToString(sidBytes)
if err != nil {
c.Logger.Warn("unable to read sid", "err", err)
continue
}
groupResult, err := conn.Search(&ldap.SearchRequest{
BaseDN: fmt.Sprintf("<SID=%s>", sidString),
Scope: ldap.ScopeBaseObject,
Filter: "(objectClass=*)",
Attributes: []string{
"1.1", // RFC no attributes
},
SizeLimit: 1,
})
if err != nil {
c.Logger.Warn("unable to read the group sid", "sid", sidString)
continue
}
if len(groupResult.Entries) == 0 {
c.Logger.Warn("unable to find the group", "sid", sidString)
continue
}
groupEntries = append(groupEntries, groupResult.Entries[0])
}
return groupEntries, nil
}
/*
* getLdapGroups queries LDAP and returns a slice describing the set of groups the authenticated user is a member of.
*
* If cfg.UseTokenGroups is true then the search is performed directly on the userDN.
* The values of those attributes are converted to string SIDs, and then looked up to get ldap.Entry objects.
* Otherwise, the search query is constructed according to cfg.GroupFilter, and run in context of cfg.GroupDN.
* Groups will be resolved from the query results by following the attribute defined in cfg.GroupAttr.
*
* cfg.GroupFilter is a go template and is compiled with the following context: [UserDN, Username]
* UserDN - The DN of the authenticated user
* Username - The Username of the authenticated user
*
* Example:
* cfg.GroupFilter = "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
* cfg.GroupDN = "OU=Groups,DC=myorg,DC=com"
* cfg.GroupAttr = "cn"
*
* NOTE - If cfg.GroupFilter is empty, no query is performed and an empty result slice is returned.
*
*/
func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]string, error) {
var entries []*ldap.Entry
var err error
if cfg.UseTokenGroups {
entries, err = c.performLdapTokenGroupsSearch(cfg, conn, userDN)
} else {
entries, err = c.performLdapFilterGroupsSearch(cfg, conn, userDN, username)
}
if err != nil {
return nil, err
}
// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
ldapMap := make(map[string]bool)
for _, e := range entries {
dn, err := ldap.ParseDN(e.DN)
if err != nil || len(dn.RDNs) == 0 {
continue
@ -265,7 +368,7 @@ func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string,
}
ldapGroups := make([]string, 0, len(ldapMap))
for key, _ := range ldapMap {
for key := range ldapMap {
ldapGroups = append(ldapGroups, key)
}

View File

@ -44,3 +44,20 @@ func TestGetTLSConfigs(t *testing.T) {
t.Fatal("expected TLS min and max version of 771 which corresponds with TLS 1.2 since TLS 1.1 and 1.0 have known vulnerabilities")
}
}
func TestSIDBytesToString(t *testing.T) {
testcases := map[string][]byte{
"S-1-5-21-2127521184-1604012920-1887927527-72713": []byte{0x01, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x15, 0x00, 0x00, 0x00, 0xA0, 0x65, 0xCF, 0x7E, 0x78, 0x4B, 0x9B, 0x5F, 0xE7, 0x7C, 0x87, 0x70, 0x09, 0x1C, 0x01, 0x00},
"S-1-1-0": []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
"S-1-5": []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05},
}
for answer, test := range testcases {
res, err := sidBytesToString(test)
if err != nil {
t.Errorf("Failed to conver %#v: %s", test, err)
} else if answer != res {
t.Errorf("Failed to convert %#v: %s != %s", test, res, answer)
}
}
}

View File

@ -115,6 +115,12 @@ Default: cn`,
Type: framework.TypeBool,
Description: "If true, case sensitivity will be used when comparing usernames and groups for matching policies.",
},
"use_token_groups": {
Type: framework.TypeBool,
Default: false,
Description: "If true, use the Active Directory tokenGroups constructed attribute of the user to find the group memberships. This will find all security groups including nested ones.",
},
}
}
@ -231,26 +237,32 @@ func NewConfigEntry(d *framework.FieldData) (*ConfigEntry, error) {
*cfg.CaseSensitiveNames = caseSensitiveNames.(bool)
}
useTokenGroups := d.Get("use_token_groups").(bool)
if useTokenGroups {
cfg.UseTokenGroups = useTokenGroups
}
return cfg, nil
}
type ConfigEntry struct {
Url string `json:"url"`
UserDN string `json:"userdn"`
GroupDN string `json:"groupdn"`
GroupFilter string `json:"groupfilter"`
GroupAttr string `json:"groupattr"`
UPNDomain string `json:"upndomain"`
UserAttr string `json:"userattr"`
Certificate string `json:"certificate"`
InsecureTLS bool `json:"insecure_tls"`
StartTLS bool `json:"starttls"`
BindDN string `json:"binddn"`
BindPassword string `json:"bindpass"`
DenyNullBind bool `json:"deny_null_bind"`
DiscoverDN bool `json:"discoverdn"`
TLSMinVersion string `json:"tls_min_version"`
TLSMaxVersion string `json:"tls_max_version"`
Url string `json:"url"`
UserDN string `json:"userdn"`
GroupDN string `json:"groupdn"`
GroupFilter string `json:"groupfilter"`
GroupAttr string `json:"groupattr"`
UPNDomain string `json:"upndomain"`
UserAttr string `json:"userattr"`
Certificate string `json:"certificate"`
InsecureTLS bool `json:"insecure_tls"`
StartTLS bool `json:"starttls"`
BindDN string `json:"binddn"`
BindPassword string `json:"bindpass"`
DenyNullBind bool `json:"deny_null_bind"`
DiscoverDN bool `json:"discoverdn"`
TLSMinVersion string `json:"tls_min_version"`
TLSMaxVersion string `json:"tls_max_version"`
UseTokenGroups bool `json:"use_token_groups"`
// This json tag deviates from snake case because there was a past issue
// where the tag was being ignored, causing it to be jsonified as "CaseSensitiveNames".
@ -267,21 +279,22 @@ func (c *ConfigEntry) Map() map[string]interface{} {
func (c *ConfigEntry) PasswordlessMap() map[string]interface{} {
m := map[string]interface{}{
"url": c.Url,
"userdn": c.UserDN,
"groupdn": c.GroupDN,
"groupfilter": c.GroupFilter,
"groupattr": c.GroupAttr,
"upndomain": c.UPNDomain,
"userattr": c.UserAttr,
"certificate": c.Certificate,
"insecure_tls": c.InsecureTLS,
"starttls": c.StartTLS,
"binddn": c.BindDN,
"deny_null_bind": c.DenyNullBind,
"discoverdn": c.DiscoverDN,
"tls_min_version": c.TLSMinVersion,
"tls_max_version": c.TLSMaxVersion,
"url": c.Url,
"userdn": c.UserDN,
"groupdn": c.GroupDN,
"groupfilter": c.GroupFilter,
"groupattr": c.GroupAttr,
"upndomain": c.UPNDomain,
"userattr": c.UserAttr,
"certificate": c.Certificate,
"insecure_tls": c.InsecureTLS,
"starttls": c.StartTLS,
"binddn": c.BindDN,
"deny_null_bind": c.DenyNullBind,
"discoverdn": c.DiscoverDN,
"tls_min_version": c.TLSMinVersion,
"tls_max_version": c.TLSMaxVersion,
"use_token_groups": c.UseTokenGroups,
}
if c.CaseSensitiveNames != nil {
m["case_sensitive_names"] = *c.CaseSensitiveNames

View File

@ -1,6 +1,9 @@
package ldaputil
import "testing"
import (
"encoding/json"
"testing"
)
func TestCertificateValidation(t *testing.T) {
// certificate should default to "" without error if it doesn't exist
@ -25,6 +28,18 @@ func TestCertificateValidation(t *testing.T) {
}
}
func TestUseTokenGroupsDefault(t *testing.T) {
config := testConfig()
if config.UseTokenGroups {
t.Errorf("expected false UseTokenGroups but got %t", config.UseTokenGroups)
}
config = testJSONConfig(t)
if config.UseTokenGroups {
t.Errorf("expected false UseTokenGroups from JSON but got %t", config.UseTokenGroups)
}
}
func testConfig() *ConfigEntry {
return &ConfigEntry{
Url: "ldap://138.91.247.105",
@ -36,6 +51,14 @@ func testConfig() *ConfigEntry {
}
}
func testJSONConfig(t *testing.T) *ConfigEntry {
config := new(ConfigEntry)
if err := json.Unmarshal(jsonConfig, config); err != nil {
t.Fatal(err)
}
return config
}
const validCertificate = `
-----BEGIN CERTIFICATE-----
MIIF7zCCA9egAwIBAgIJAOY2qjn64Qq5MA0GCSqGSIb3DQEBCwUAMIGNMQswCQYD
@ -72,3 +95,14 @@ d6TqelcRw9WnDsb9IPxRwaXhvGljnYVAgXXlJEI/6nxj2T4wdmL1LWAr6C7DuWGz
Beq3QOqp2+dga36IzQybzPQ8QtotrpSJ3q82zztEvyWiJ7E=
-----END CERTIFICATE-----
`
var jsonConfig = []byte(`
{
"url": "ldap://138.91.247.105",
"userdn": "example,com",
"binddn": "kitty",
"bindpass": "cats",
"tls_max_version": "tls12",
"tls_min_version": "tls12"
}
`)

View File

@ -0,0 +1,107 @@
package namespace
import (
"context"
"errors"
"strings"
)
type nsContext struct {
context.Context
// Note: this is currently not locked because we think all uses will take
// place within a single goroutine. If that isn't the case, this should be
// protected by an atomic.Value.
cachedNS *Namespace
}
type contextValues struct{}
const (
RootNamespaceID = "root"
)
var (
contextNamespace contextValues = struct{}{}
ErrNoNamespace error = errors.New("no namespace")
)
type Namespace struct {
ID string `json:"id"`
Path string `json:"path"`
}
func New(id, path string) *Namespace {
return &Namespace{
ID: id,
Path: path,
}
}
func (n *Namespace) HasParent(possibleParent *Namespace) bool {
switch {
case n.Path == "":
return false
case possibleParent.Path == "":
return true
default:
return strings.HasPrefix(n.Path, possibleParent.Path)
}
}
func (n *Namespace) TrimmedPath(path string) string {
return strings.TrimPrefix(path, n.Path)
}
func ContextWithNamespace(ctx context.Context, ns *Namespace) context.Context {
nsCtx := context.WithValue(ctx, contextNamespace, ns)
return &nsContext{
Context: nsCtx,
cachedNS: ns,
}
}
func FromContext(ctx context.Context) (*Namespace, error) {
if ctx == nil {
return nil, errors.New("context was nil")
}
nsCtx, ok := ctx.(*nsContext)
if ok {
if nsCtx.cachedNS != nil {
return nsCtx.cachedNS, nil
}
}
ns := ctx.Value(contextNamespace)
if ns == nil {
return nil, ErrNoNamespace
}
if ok {
nsCtx.cachedNS = ns.(*Namespace)
}
return ns.(*Namespace), nil
}
func TestContext() context.Context {
return ContextWithNamespace(context.Background(), New(RootNamespaceID, ""))
}
// Canonicalize trims any prefix '/' and adds a trailing '/' to the
// provided string
func Canonicalize(nsPath string) string {
if nsPath == "" {
return ""
}
// Canonicalize the path to not have a '/' prefix
nsPath = strings.TrimPrefix(nsPath, "/")
// Canonicalize the path to always having a '/' suffix
if !strings.HasSuffix(nsPath, "/") {
nsPath += "/"
}
return nsPath
}

View File

@ -28,7 +28,7 @@ func ParseDurationSecond(in interface{}) (time.Duration, error) {
}
var err error
// Look for a suffix otherwise its a plain second value
if strings.HasSuffix(inp, "s") || strings.HasSuffix(inp, "m") || strings.HasSuffix(inp, "h") {
if strings.HasSuffix(inp, "s") || strings.HasSuffix(inp, "m") || strings.HasSuffix(inp, "h") || strings.HasSuffix(inp, "ms") {
dur, err = time.ParseDuration(inp)
if err != nil {
return dur, err

View File

@ -43,9 +43,9 @@ func StrListSubset(super, sub []string) bool {
return true
}
// Parses a comma separated list of strings into a slice of strings.
// The return slice will be sorted and will not contain duplicate or
// empty items.
// ParseDedupAndSortStrings parses a comma separated list of strings
// into a slice of strings. The return slice will be sorted and will
// not contain duplicate or empty items.
func ParseDedupAndSortStrings(input string, sep string) []string {
input = strings.TrimSpace(input)
parsed := []string{}
@ -56,9 +56,10 @@ func ParseDedupAndSortStrings(input string, sep string) []string {
return RemoveDuplicates(strings.Split(input, sep), false)
}
// Parses a comma separated list of strings into a slice of strings.
// The return slice will be sorted and will not contain duplicate or
// empty items. The values will be converted to lower case.
// ParseDedupLowercaseAndSortStrings parses a comma separated list of
// strings into a slice of strings. The return slice will be sorted and
// will not contain duplicate or empty items. The values will be converted
// to lower case.
func ParseDedupLowercaseAndSortStrings(input string, sep string) []string {
input = strings.TrimSpace(input)
parsed := []string{}
@ -69,8 +70,8 @@ func ParseDedupLowercaseAndSortStrings(input string, sep string) []string {
return RemoveDuplicates(strings.Split(input, sep), true)
}
// Parses a comma separated list of `<key>=<value>` tuples into a
// map[string]string.
// ParseKeyValues parses a comma separated list of `<key>=<value>` tuples
// into a map[string]string.
func ParseKeyValues(input string, out map[string]string, sep string) error {
if out == nil {
return fmt.Errorf("'out is nil")
@ -97,8 +98,8 @@ func ParseKeyValues(input string, out map[string]string, sep string) error {
return nil
}
// Parses arbitrary <key,value> tuples. The input can be one of
// the following:
// ParseArbitraryKeyValues parses arbitrary <key,value> tuples. The input
// can be one of the following:
// * JSON string
// * Base64 encoded JSON string
// * Comma separated list of `<key>=<value>` pairs
@ -144,8 +145,8 @@ func ParseArbitraryKeyValues(input string, out map[string]string, sep string) er
return nil
}
// Parses a `sep`-separated list of strings into a
// []string.
// ParseStringSlice parses a `sep`-separated list of strings into a
// []string with surrounding whitespace removed.
//
// The output will always be a valid slice but may be of length zero.
func ParseStringSlice(input string, sep string) []string {
@ -157,14 +158,14 @@ func ParseStringSlice(input string, sep string) []string {
splitStr := strings.Split(input, sep)
ret := make([]string, len(splitStr))
for i, val := range splitStr {
ret[i] = val
ret[i] = strings.TrimSpace(val)
}
return ret
}
// Parses arbitrary string slice. The input can be one of
// the following:
// ParseArbitraryStringSlice parses arbitrary string slice. The input
// can be one of the following:
// * JSON string
// * Base64 encoded JSON string
// * `sep` separated list of values
@ -215,8 +216,9 @@ func TrimStrings(items []string) []string {
return ret
}
// Removes duplicate and empty elements from a slice of strings. This also may
// convert the items in the slice to lower case and returns a sorted slice.
// RemoveDuplicates removes duplicate and empty elements from a slice of
// strings. This also may convert the items in the slice to lower case and
// returns a sorted slice.
func RemoveDuplicates(items []string, lowercase bool) []string {
itemsMap := map[string]bool{}
for _, item := range items {
@ -230,7 +232,7 @@ func RemoveDuplicates(items []string, lowercase bool) []string {
itemsMap[item] = true
}
items = make([]string, 0, len(itemsMap))
for item, _ := range itemsMap {
for item := range itemsMap {
items = append(items, item)
}
sort.Strings(items)
@ -260,10 +262,10 @@ func EquivalentSlices(a, b []string) bool {
// Now we'll build our checking slices
var sortedA, sortedB []string
for keyA, _ := range mapA {
for keyA := range mapA {
sortedA = append(sortedA, keyA)
}
for keyB, _ := range mapB {
for keyB := range mapB {
sortedB = append(sortedB, keyB)
}
sort.Strings(sortedA)
@ -299,6 +301,8 @@ func StrListDelete(s []string, d string) []string {
return s
}
// GlobbedStringsMatch compares item to val with support for a leading and/or
// trailing wildcard '*' in item.
func GlobbedStringsMatch(item, val string) bool {
if len(item) < 2 {
return val == item
@ -325,3 +329,20 @@ func AppendIfMissing(slice []string, i string) []string {
}
return append(slice, i)
}
// MergeSlices adds an arbitrary number of slices together, uniquely
func MergeSlices(args ...[]string) []string {
all := map[string]struct{}{}
for _, slice := range args {
for _, v := range slice {
all[v] = struct{}{}
}
}
result := make([]string, 0, len(all))
for k, _ := range all {
result = append(result, k)
}
sort.Strings(result)
return result
}

View File

@ -423,3 +423,39 @@ func TestStrUtil_RemoveDuplicates(t *testing.T) {
}
}
}
func TestStrUtil_ParseStringSlice(t *testing.T) {
type tCase struct {
input string
sep string
expect []string
}
tCases := []tCase{
tCase{"", "", []string{}},
tCase{" ", ",", []string{}},
tCase{", ", ",", []string{"", ""}},
tCase{"a", ",", []string{"a"}},
tCase{" a, b, c ", ",", []string{"a", "b", "c"}},
tCase{" a; b; c ", ";", []string{"a", "b", "c"}},
tCase{" a :: b :: c ", "::", []string{"a", "b", "c"}},
}
for _, tc := range tCases {
actual := ParseStringSlice(tc.input, tc.sep)
if !reflect.DeepEqual(actual, tc.expect) {
t.Fatalf("Bad testcase %#v, expected %v, got %v", tc, tc.expect, actual)
}
}
}
func TestStrUtil_MergeSlices(t *testing.T) {
res := MergeSlices([]string{"a", "c", "d"}, []string{}, []string{"c", "f", "a"}, nil, []string{"foo"})
expect := []string{"a", "c", "d", "f", "foo"}
if !reflect.DeepEqual(res, expect) {
t.Fatalf("expected %v, got %v", expect, res)
}
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"runtime"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/version"
)
@ -23,7 +24,24 @@ var (
)
// String returns the consistent user-agent string for Vault.
//
// e.g. Vault/0.10.4 (+https://www.vaultproject.io/; go1.10.1)
func String() string {
return fmt.Sprintf("Vault/%s (+%s; %s)",
versionFunc(), projectURL, rt)
}
// PluginString is usable by plugins to return a user-agent string reflecting
// the running Vault version and an optional plugin name.
//
// e.g. Vault/0.10.4 (+https://www.vaultproject.io/; azure-auth; go1.10.1)
func PluginString(env *logical.PluginEnvironment, pluginName string) string {
var name string
if pluginName != "" {
name = pluginName + "; "
}
return fmt.Sprintf("Vault/%s (+%s; %s%s)",
env.VaultVersion, projectURL, name, rt)
}

View File

@ -2,6 +2,8 @@ package useragent
import (
"testing"
"github.com/hashicorp/vault/logical"
)
func TestUserAgent(t *testing.T) {
@ -16,3 +18,27 @@ func TestUserAgent(t *testing.T) {
t.Errorf("expected %q to be %q", act, exp)
}
}
func TestUserAgentPlugin(t *testing.T) {
projectURL = "https://vault-test.com"
rt = "go5.0"
env := &logical.PluginEnvironment{
VaultVersion: "1.2.3",
}
pluginName := "azure-auth"
act := PluginString(env, pluginName)
exp := "Vault/1.2.3 (+https://vault-test.com; azure-auth; go5.0)"
if exp != act {
t.Errorf("expected %q to be %q", act, exp)
}
pluginName = ""
act = PluginString(env, pluginName)
exp = "Vault/1.2.3 (+https://vault-test.com; go5.0)"
if exp != act {
t.Errorf("expected %q to be %q", act, exp)
}
}

View File

@ -90,14 +90,11 @@ func Handler(props *vault.HandlerProperties) http.Handler {
mux.Handle("/v1/sys/rekey-recovery-key/init", handleRequestForwarding(core, handleSysRekeyInit(core, true)))
mux.Handle("/v1/sys/rekey-recovery-key/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, true)))
mux.Handle("/v1/sys/rekey-recovery-key/verify", handleRequestForwarding(core, handleSysRekeyVerify(core, true)))
mux.Handle("/v1/sys/wrapping/lookup", handleRequestForwarding(core, handleLogical(core, false, wrappingVerificationFunc)))
mux.Handle("/v1/sys/wrapping/rewrap", handleRequestForwarding(core, handleLogical(core, false, wrappingVerificationFunc)))
mux.Handle("/v1/sys/wrapping/unwrap", handleRequestForwarding(core, handleLogical(core, false, wrappingVerificationFunc)))
for _, path := range injectDataIntoTopRoutes {
mux.Handle(path, handleRequestForwarding(core, handleLogical(core, true, nil)))
}
mux.Handle("/v1/sys/", handleRequestForwarding(core, handleLogical(core, false, nil)))
mux.Handle("/v1/", handleRequestForwarding(core, handleLogical(core, false, nil)))
mux.Handle("/v1/sys/wrapping/lookup", handleRequestForwarding(core, handleLogical(core, wrappingVerificationFunc)))
mux.Handle("/v1/sys/wrapping/rewrap", handleRequestForwarding(core, handleLogical(core, wrappingVerificationFunc)))
mux.Handle("/v1/sys/wrapping/unwrap", handleRequestForwarding(core, handleLogical(core, wrappingVerificationFunc)))
mux.Handle("/v1/sys/", handleRequestForwarding(core, handleLogical(core, nil)))
mux.Handle("/v1/", handleRequestForwarding(core, handleLogical(core, nil)))
if core.UIEnabled() == true {
if uiBuiltIn {
mux.Handle("/ui/", http.StripPrefix("/ui/", gziphandler.GzipHandler(handleUIHeaders(core, handleUI(http.FileServer(&UIAssetWrapper{FileSystem: assetFS()}))))))
@ -113,7 +110,7 @@ func Handler(props *vault.HandlerProperties) http.Handler {
// Wrap the help wrapped handler with another layer with a generic
// handler
genericWrappedHandler := wrapGenericHandler(corsWrappedHandler, props.MaxRequestSize)
genericWrappedHandler := wrapGenericHandler(corsWrappedHandler, props.MaxRequestSize, props.MaxRequestDuration)
// Wrap the handler with PrintablePathCheckHandler to check for non-printable
// characters in the request path.
@ -128,20 +125,27 @@ func Handler(props *vault.HandlerProperties) http.Handler {
// wrapGenericHandler wraps the handler with an extra layer of handler where
// tasks that should be commonly handled for all the requests and/or responses
// are performed.
func wrapGenericHandler(h http.Handler, maxRequestSize int64) http.Handler {
func wrapGenericHandler(h http.Handler, maxRequestSize int64, maxRequestDuration time.Duration) http.Handler {
if maxRequestDuration == 0 {
maxRequestDuration = vault.DefaultMaxRequestDuration
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set the Cache-Control header for all the responses returned
// by Vault
w.Header().Set("Cache-Control", "no-store")
// Add a context and put the request limit for this handler in it
// Start with the request context
ctx := r.Context()
var cancelFunc context.CancelFunc
// Add our timeout
ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration)
// Add a size limiter if desired
if maxRequestSize > 0 {
ctx := context.WithValue(r.Context(), "max_request_size", maxRequestSize)
h.ServeHTTP(w, r.WithContext(ctx))
} else {
h.ServeHTTP(w, r)
ctx = context.WithValue(ctx, "max_request_size", maxRequestSize)
}
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
cancelFunc()
return
})
}
@ -432,7 +436,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
// request is a helper to perform a request and properly exit in the
// case of an error.
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
resp, err := core.HandleRequest(r)
resp, err := core.HandleRequest(rawReq.Context(), r)
if errwrap.Contains(err, consts.ErrStandby.Error()) {
respondStandby(core, w, rawReq.URL)
return resp, false
@ -585,27 +589,3 @@ func respondOk(w http.ResponseWriter, body interface{}) {
type ErrorResponse struct {
Errors []string `json:"errors"`
}
var injectDataIntoTopRoutes = []string{
"/v1/sys/audit",
"/v1/sys/audit/",
"/v1/sys/audit-hash/",
"/v1/sys/auth",
"/v1/sys/auth/",
"/v1/sys/config/cors",
"/v1/sys/config/auditing/request-headers/",
"/v1/sys/config/auditing/request-headers",
"/v1/sys/capabilities",
"/v1/sys/capabilities-accessor",
"/v1/sys/capabilities-self",
"/v1/sys/key-status",
"/v1/sys/mounts",
"/v1/sys/mounts/",
"/v1/sys/policy",
"/v1/sys/policy/",
"/v1/sys/rekey/backup",
"/v1/sys/rekey/recovery-key-backup",
"/v1/sys/remount",
"/v1/sys/rotate",
"/v1/sys/wrapping/wrap",
}

View File

@ -230,58 +230,6 @@ func TestSysMounts_headerAuth(t *testing.T) {
"options": interface{}(nil),
},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
"type": "system",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
"type": "cubbyhole",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": true,
"seal_wrap": false,
"options": interface{}(nil),
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -291,7 +239,6 @@ func TestSysMounts_headerAuth(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}

View File

@ -37,7 +37,7 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, req *http.Request) {
Connection: getConnection(req),
})
resp, err := core.HandleRequest(lreq)
resp, err := core.HandleRequest(req.Context(), lreq)
if err != nil {
respondErrorCommon(w, lreq, resp, err)
return

View File

@ -28,75 +28,79 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
return nil, http.StatusNotFound, nil
}
var data map[string]interface{}
// Determine the operation
var op logical.Operation
switch r.Method {
case "DELETE":
op = logical.DeleteOperation
case "GET":
op = logical.ReadOperation
// Need to call ParseForm to get query params loaded
queryVals := r.URL.Query()
var list bool
var err error
listStr := queryVals.Get("list")
if listStr != "" {
list, err := strconv.ParseBool(listStr)
list, err = strconv.ParseBool(listStr)
if err != nil {
return nil, http.StatusBadRequest, nil
}
if list {
op = logical.ListOperation
if !strings.HasSuffix(path, "/") {
path += "/"
}
}
}
if !list {
getData := map[string]interface{}{}
for k, v := range r.URL.Query() {
// Skip the help key as this is a reserved parameter
if k == "help" {
continue
}
switch {
case len(v) == 0:
case len(v) == 1:
getData[k] = v[0]
default:
getData[k] = v
}
}
if len(getData) > 0 {
data = getData
}
}
case "POST", "PUT":
op = logical.UpdateOperation
// Parse the request if we can
if op == logical.UpdateOperation {
err := parseRequest(r, w, &data)
if err == io.EOF {
data = nil
err = nil
}
if err != nil {
return nil, http.StatusBadRequest, err
}
}
case "LIST":
op = logical.ListOperation
case "OPTIONS":
default:
return nil, http.StatusMethodNotAllowed, nil
}
if op == logical.ListOperation {
if !strings.HasSuffix(path, "/") {
path += "/"
}
}
// Parse the request if we can
var data map[string]interface{}
if op == logical.UpdateOperation {
err := parseRequest(r, w, &data)
if err == io.EOF {
data = nil
err = nil
}
if err != nil {
return nil, http.StatusBadRequest, err
}
}
// If we are a read operation, try and parse any parameters
if op == logical.ReadOperation {
getData := map[string]interface{}{}
for k, v := range r.URL.Query() {
// Skip the help key as this is a reserved parameter
if k == "help" {
continue
}
switch {
case len(v) == 0:
case len(v) == 1:
getData[k] = v[0]
default:
getData[k] = v
}
}
if len(getData) > 0 {
data = getData
}
case "OPTIONS":
default:
return nil, http.StatusMethodNotAllowed, nil
}
var err error
@ -122,7 +126,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
return req, 0, nil
}
func handleLogical(core *vault.Core, injectDataIntoTopLevel bool, prepareRequestCallback PrepareRequestFunc) http.Handler {
func handleLogical(core *vault.Core, prepareRequestCallback PrepareRequestFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 {
@ -151,11 +155,11 @@ func handleLogical(core *vault.Core, injectDataIntoTopLevel bool, prepareRequest
}
// Build the proper response
respondLogical(w, r, req, injectDataIntoTopLevel, resp)
respondLogical(w, r, req, resp)
})
}
func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request, injectDataIntoTopLevel bool, resp *logical.Response) {
func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request, resp *logical.Response) {
var httpResp *logical.HTTPResponse
var ret interface{}
@ -190,13 +194,6 @@ func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request
}
ret = httpResp
if injectDataIntoTopLevel {
injector := logical.HTTPSysInjector{
Response: httpResp,
}
ret = injector
}
}
// Respond

View File

@ -318,7 +318,7 @@ func TestLogical_RespondWithStatusCode(t *testing.T) {
}
w := httptest.NewRecorder()
respondLogical(w, nil, nil, false, resp404)
respondLogical(w, nil, nil, resp404)
if w.Code != 404 {
t.Fatalf("Bad Status code: %d", w.Code)

View File

@ -38,13 +38,6 @@ func TestSysAudit(t *testing.T) {
"local": false,
},
},
"noop/": map[string]interface{}{
"path": "noop/",
"type": "noop",
"description": "",
"options": map[string]interface{}{},
"local": false,
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -119,7 +112,6 @@ func TestSysAuditHash(t *testing.T) {
"data": map[string]interface{}{
"hash": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
},
"hash": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)

View File

@ -38,18 +38,6 @@ func TestSysAuth(t *testing.T) {
"options": interface{}(nil),
},
},
"token/": map[string]interface{}{
"description": "token based credentials",
"type": "token",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -59,7 +47,6 @@ func TestSysAuth(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -116,30 +103,6 @@ func TestSysEnableAuth(t *testing.T) {
"options": interface{}(nil),
},
},
"foo/": map[string]interface{}{
"description": "foo",
"type": "noop",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{},
},
"token/": map[string]interface{}{
"description": "token based credentials",
"type": "token",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -149,7 +112,6 @@ func TestSysEnableAuth(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -197,18 +159,6 @@ func TestSysDisableAuth(t *testing.T) {
"options": interface{}(nil),
},
},
"token/": map[string]interface{}{
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"plugin_name": "",
},
"description": "token based credentials",
"type": "token",
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -218,7 +168,6 @@ func TestSysDisableAuth(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -264,11 +213,6 @@ func TestSysTuneAuth_nonHMACKeys(t *testing.T) {
"audit_non_hmac_request_keys": []interface{}{"foo"},
"audit_non_hmac_response_keys": []interface{}{"bar"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"audit_non_hmac_request_keys": []interface{}{"foo"},
"audit_non_hmac_response_keys": []interface{}{"bar"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -303,9 +247,6 @@ func TestSysTuneAuth_nonHMACKeys(t *testing.T) {
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -337,9 +278,6 @@ func TestSysTuneAuth_showUIMount(t *testing.T) {
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -371,10 +309,6 @@ func TestSysTuneAuth_showUIMount(t *testing.T) {
"force_no_cache": false,
"listing_visibility": "unauth",
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"listing_visibility": "unauth",
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]

View File

@ -61,9 +61,6 @@ func TestSysConfigCors(t *testing.T) {
"allowed_origins": []interface{}{addr},
"allowed_headers": expectedHeaders,
},
"enabled": true,
"allowed_origins": []interface{}{addr},
"allowed_headers": expectedHeaders,
}
testResponseStatus(t, resp, 200)

View File

@ -79,58 +79,6 @@ func TestSysMounts(t *testing.T) {
"options": interface{}(nil),
},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
"type": "system",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
"type": "cubbyhole",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": true,
"seal_wrap": false,
"options": interface{}(nil),
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -139,7 +87,6 @@ func TestSysMounts(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -240,71 +187,6 @@ func TestSysMount(t *testing.T) {
"options": interface{}(nil),
},
},
"foo/": map[string]interface{}{
"description": "foo",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
"type": "system",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
"type": "cubbyhole",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": true,
"seal_wrap": false,
"options": interface{}(nil),
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -313,7 +195,6 @@ func TestSysMount(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -433,71 +314,6 @@ func TestSysRemount(t *testing.T) {
"options": interface{}(nil),
},
},
"bar/": map[string]interface{}{
"description": "foo",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
"type": "system",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
"type": "cubbyhole",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": true,
"seal_wrap": false,
"options": interface{}(nil),
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -506,7 +322,6 @@ func TestSysRemount(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -594,58 +409,6 @@ func TestSysUnmount(t *testing.T) {
"options": interface{}(nil),
},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
"type": "system",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
"type": "cubbyhole",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": true,
"seal_wrap": false,
"options": interface{}(nil),
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -654,7 +417,6 @@ func TestSysUnmount(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -701,10 +463,6 @@ func TestSysTuneMount_Options(t *testing.T) {
"force_no_cache": false,
"options": map[string]interface{}{"test": "true"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"options": map[string]interface{}{"test": "true"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -736,10 +494,6 @@ func TestSysTuneMount_Options(t *testing.T) {
"force_no_cache": false,
"options": map[string]interface{}{"test": "true"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"options": map[string]interface{}{"test": "true"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -837,71 +591,6 @@ func TestSysTuneMount(t *testing.T) {
"options": interface{}(nil),
},
},
"foo/": map[string]interface{}{
"description": "foo",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
"type": "system",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
"type": "cubbyhole",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": true,
"seal_wrap": false,
"options": interface{}(nil),
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
@ -910,7 +599,6 @@ func TestSysTuneMount(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -1037,71 +725,6 @@ func TestSysTuneMount(t *testing.T) {
"options": interface{}(nil),
},
},
"foo/": map[string]interface{}{
"description": "foo",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("259196400"),
"max_lease_ttl": json.Number("259200000"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"secret/": map[string]interface{}{
"description": "key/value secret storage",
"type": "kv",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": map[string]interface{}{"version": "1"},
},
"sys/": map[string]interface{}{
"description": "system endpoints used for control, policy and debugging",
"type": "system",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
"cubbyhole/": map[string]interface{}{
"description": "per-token private secret storage",
"type": "cubbyhole",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": true,
"seal_wrap": false,
"options": interface{}(nil),
},
"identity/": map[string]interface{}{
"description": "identity store",
"type": "identity",
"config": map[string]interface{}{
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"plugin_name": "",
},
"local": false,
"seal_wrap": false,
"options": interface{}(nil),
},
}
testResponseStatus(t, resp, 200)
@ -1111,7 +734,6 @@ func TestSysTuneMount(t *testing.T) {
if v.(map[string]interface{})["accessor"] == "" {
t.Fatalf("no accessor from %s", k)
}
expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
@ -1135,10 +757,6 @@ func TestSysTuneMount(t *testing.T) {
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
},
"default_lease_ttl": json.Number("259196400"),
"max_lease_ttl": json.Number("259200000"),
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
}
testResponseStatus(t, resp, 200)
@ -1170,10 +788,6 @@ func TestSysTuneMount(t *testing.T) {
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
},
"default_lease_ttl": json.Number("40"),
"max_lease_ttl": json.Number("80"),
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
}
testResponseStatus(t, resp, 200)
@ -1267,12 +881,6 @@ func TestSysTuneMount_nonHMACKeys(t *testing.T) {
"audit_non_hmac_response_keys": []interface{}{"bar"},
"options": map[string]interface{}{"version": "1"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"audit_non_hmac_request_keys": []interface{}{"foo"},
"audit_non_hmac_response_keys": []interface{}{"bar"},
"options": map[string]interface{}{"version": "1"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -1309,10 +917,6 @@ func TestSysTuneMount_nonHMACKeys(t *testing.T) {
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -1345,10 +949,6 @@ func TestSysTuneMount_listingVisibility(t *testing.T) {
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -1381,11 +981,6 @@ func TestSysTuneMount_listingVisibility(t *testing.T) {
"listing_visibility": "unauth",
"options": map[string]interface{}{"version": "1"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"listing_visibility": "unauth",
"options": map[string]interface{}{"version": "1"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -1425,11 +1020,6 @@ func TestSysTuneMount_passthroughRequestHeaders(t *testing.T) {
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"X-Vault-Foo"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"options": map[string]interface{}{"version": "1"},
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"X-Vault-Foo"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
@ -1461,10 +1051,6 @@ func TestSysTuneMount_passthroughRequestHeaders(t *testing.T) {
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
},
"default_lease_ttl": json.Number("2764800"),
"max_lease_ttl": json.Number("2764800"),
"force_no_cache": false,
"options": map[string]interface{}{"version": "1"},
}
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]

Some files were not shown because too many files have changed in this diff Show More