Merge remote-tracking branch 'upstream/master'

mwoolsey · 2016-11-20 18:31:55 -08:00
commit 3e72e50fa5
320 changed files with 55380 additions and 18209 deletions


@@ -1,3 +1,40 @@
+## Next (Unreleased)
+
+DEPRECATIONS/CHANGES:
+
+ * http: impose a maximum request size of 32MB to prevent a denial of service
+   with arbitrarily large requests. [GH-2108]
+
+IMPROVEMENTS:
+
+ * auth/github: Policies can now be assigned to users as well as to teams
+   [GH-2079]
+ * cli: Set the number of retries on 500 down to 0 by default (no retrying). It
+   can be very confusing to users when there is a pause while the retries
+   happen if they haven't explicitly set it. With request forwarding the need
+   for this is lessened anyways. [GH-2093]
+ * core: Response wrapping is now allowed to be specified by backend responses
+   (requires backends gaining support) [GH-2088]
+ * secret/consul: Added listing functionality to roles [GH-2065]
+ * secret/postgresql: Added `revocation_sql` parameter on the role endpoint to
+   enable customization of user revocation SQL statements [GH-2033]
+ * secret/transit: Add listing of keys [GH-1987]
+
+BUG FIXES:
+
+ * auth/approle: Creating the index for the role_id properly [GH-2004]
+ * auth/aws-ec2: Handle the case of multiple upgrade attempts when setting the
+   instance-profile ARN [GH-2035]
+ * api/unwrap, command/unwrap: Fix compatibility of `unwrap` command with Vault
+   0.6.1 and older [GH-2014]
+ * api/unwrap, command/unwrap: Fix error when no client token exists [GH-2077]
+ * command/ssh: Use temporary file for identity and ensure its deletion before
+   the command returns [GH-2016]
+ * core: Fix bug where a failure to come up as active node (e.g. if an audit
+   backend failed) could lead to deadlock [GH-2083]
+ * physical/mysql: Fix potential crash during setup due to a query failure
+   [GH-2105]
+
 ## 0.6.2 (October 5, 2016)
 
 DEPRECATIONS/CHANGES:
@@ -140,11 +177,11 @@ DEPRECATIONS/CHANGES:
  * Status codes for sealed/uninitialized Vaults have changed to `503`/`501`
    respectively. See the [version-specific upgrade
    guide](https://www.vaultproject.io/docs/install/upgrade-to-0.6.1.html) for
    more details.
  * Root tokens (tokens with the `root` policy) can no longer be created except
    by another root token or the `generate-root` endpoint.
  * Issued certificates from the `pki` backend against new roles created or
    modified after upgrading will contain a set of default key usages.
  * The `dynamodb` physical data store no longer supports HA by default. It has
    some non-ideal behavior around failover that was causing confusion. See the
    [documentation](https://www.vaultproject.io/docs/config/index.html#ha_enabled)
@@ -214,7 +251,7 @@ IMPROVEMENTS:
    the request portion of the response. [GH-1650]
  * auth/aws-ec2: Added a new constraint `bound_account_id` to the role
    [GH-1523]
  * auth/aws-ec2: Added a new constraint `bound_iam_role_arn` to the role
    [GH-1522]
  * auth/aws-ec2: Added `ttl` field for the role [GH-1703]
  * auth/ldap, secret/cassandra, physical/consul: Clients with `tls.Config`
@@ -258,7 +295,7 @@ IMPROVEMENTS:
    configuration [GH-1581]
  * secret/mssql,mysql,postgresql: Reading of connection settings is supported
    in all the sql backends [GH-1515]
  * secret/mysql: Added optional maximum idle connections value to MySQL
    connection configuration [GH-1635]
  * secret/mysql: Use a combination of the role name and token display name in
    generated user names and allow the length to be controlled [GH-1604]
@@ -601,7 +638,7 @@ BUG FIXES:
    during renewals [GH-1176]
 
 ## 0.5.1 (February 25th, 2016)
 
 DEPRECATIONS/CHANGES:
 
  * RSA keys less than 2048 bits are no longer supported in the PKI backend.
@@ -631,7 +668,7 @@ IMPROVEMENTS:
  * api/health: Add the server's time in UTC to health responses [GH-1117]
  * command/rekey and command/generate-root: These now return the status at
    attempt initialization time, rather than requiring a separate fetch for the
    nonce [GH-1054]
  * credential/cert: Don't require root/sudo tokens for the `certs/` and `crls/`
    paths; use normal ACL behavior instead [GH-468]
  * credential/github: The validity of the token used for login will be checked
@@ -761,7 +798,7 @@ FEATURES:
    documentation](https://vaultproject.io/docs/config/index.html) for details.
    [GH-945]
  * **STS Support in AWS Secret Backend**: You can now use the AWS secret
    backend to fetch STS tokens rather than IAM users. [GH-927]
  * **Speedups in the transit backend**: The `transit` backend has gained a
    cache, and now loads only the working set of keys (e.g. from the
    `min_decryption_version` to the current key version) into its working set.


@@ -60,4 +60,8 @@ bootstrap:
 	    go get -u $$tool; \
 	done
 
+proto:
+	protoc -I helper/forwarding -I vault -I ../../.. vault/request_forwarding_service.proto --go_out=plugins=grpc:vault
+	protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
+
 .PHONY: bin default generate test vet bootstrap


@@ -48,7 +48,7 @@ type Config struct {
 	redirectSetup sync.Once
 
 	// MaxRetries controls the maximum number of times to retry when a 5xx error
-	// occurs. Set to 0 or less to disable retrying.
+	// occurs. Set to 0 or less to disable retrying. Defaults to 0.
 	MaxRetries int
 }
@@ -99,8 +99,6 @@ func DefaultConfig() *Config {
 		config.Address = v
 	}
 
-	config.MaxRetries = pester.DefaultClient.MaxRetries
-
 	return config
 }


@@ -120,8 +120,12 @@ func (c *Logical) Delete(path string) (*Secret, error) {
 func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
 	var data map[string]interface{}
 	if wrappingToken != "" {
-		data = map[string]interface{}{
-			"token": wrappingToken,
+		if c.c.Token() == "" {
+			c.c.SetToken(wrappingToken)
+		} else if wrappingToken != c.c.Token() {
+			data = map[string]interface{}{
+				"token": wrappingToken,
+			}
 		}
 	}
@@ -146,7 +150,7 @@ func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
 		return nil, nil
 	}
 
-	if wrappingToken == "" {
+	if wrappingToken != "" {
 		origToken := c.c.Token()
 		defer c.c.SetToken(origToken)
 		c.c.SetToken(wrappingToken)

@@ -64,9 +64,18 @@ func (f *AuditFormatter) FormatRequest(
 		if err := Hash(config.Salt, auth); err != nil {
 			return err
 		}
 
+		// Cache and restore accessor in the request
+		var clientTokenAccessor string
+		if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" {
+			clientTokenAccessor = req.ClientTokenAccessor
+		}
+
 		if err := Hash(config.Salt, req); err != nil {
 			return err
 		}
 
+		if clientTokenAccessor != "" {
+			req.ClientTokenAccessor = clientTokenAccessor
+		}
 	}
 
 	// If auth is nil, make an empty one
@@ -89,13 +98,14 @@ func (f *AuditFormatter) FormatRequest(
 		},
 
 		Request: AuditRequest{
-			ID:          req.ID,
-			ClientToken: req.ClientToken,
-			Operation:   req.Operation,
-			Path:        req.Path,
-			Data:        req.Data,
-			RemoteAddr:  getRemoteAddr(req),
-			WrapTTL:     int(req.WrapTTL / time.Second),
+			ID:                  req.ID,
+			ClientToken:         req.ClientToken,
+			ClientTokenAccessor: req.ClientTokenAccessor,
+			Operation:           req.Operation,
+			Path:                req.Path,
+			Data:                req.Data,
+			RemoteAddr:          getRemoteAddr(req),
+			WrapTTL:             int(req.WrapTTL / time.Second),
 		},
 	}
@@ -167,9 +177,17 @@ func (f *AuditFormatter) FormatResponse(
 			auth.Accessor = accessor
 		}
 
+		// Cache and restore accessor in the request
+		var clientTokenAccessor string
+		if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" {
+			clientTokenAccessor = req.ClientTokenAccessor
+		}
+
 		if err := Hash(config.Salt, req); err != nil {
 			return err
 		}
 
+		if clientTokenAccessor != "" {
+			req.ClientTokenAccessor = clientTokenAccessor
+		}
+
 		// Cache and restore accessor in the response
 		accessor = ""
@@ -241,13 +259,14 @@ func (f *AuditFormatter) FormatResponse(
 		},
 
 		Request: AuditRequest{
-			ID:          req.ID,
-			ClientToken: req.ClientToken,
-			Operation:   req.Operation,
-			Path:        req.Path,
-			Data:        req.Data,
-			RemoteAddr:  getRemoteAddr(req),
-			WrapTTL:     int(req.WrapTTL / time.Second),
+			ID:                  req.ID,
+			ClientToken:         req.ClientToken,
+			ClientTokenAccessor: req.ClientTokenAccessor,
+			Operation:           req.Operation,
+			Path:                req.Path,
+			Data:                req.Data,
+			RemoteAddr:          getRemoteAddr(req),
+			WrapTTL:             int(req.WrapTTL / time.Second),
 		},
 
 		Response: AuditResponse{
@@ -286,13 +305,14 @@ type AuditResponseEntry struct {
 }
 
 type AuditRequest struct {
-	ID          string                 `json:"id"`
-	Operation   logical.Operation      `json:"operation"`
-	ClientToken string                 `json:"client_token"`
-	Path        string                 `json:"path"`
-	Data        map[string]interface{} `json:"data"`
-	RemoteAddr  string                 `json:"remote_address"`
-	WrapTTL     int                    `json:"wrap_ttl"`
+	ID                  string                 `json:"id"`
+	Operation           logical.Operation      `json:"operation"`
+	ClientToken         string                 `json:"client_token"`
+	ClientTokenAccessor string                 `json:"client_token_accessor"`
+	Path                string                 `json:"path"`
+	Data                map[string]interface{} `json:"data"`
+	RemoteAddr          string                 `json:"remote_address"`
+	WrapTTL             int                    `json:"wrap_ttl"`
 }
 
 type AuditResponse struct {


@@ -32,7 +32,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) {
 			},
 			errors.New("this is an error"),
 			"",
-			`<json:object name="auth"><json:string name="accessor"></json:string><json:string name="client_token"></json:string><json:string name="display_name"></json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:null name="data" /><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
+			`<json:object name="auth"><json:string name="accessor"></json:string><json:string name="client_token"></json:string><json:string name="display_name"></json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:string name="client_token_accessor"></json:string><json:null name="data" /><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
 		},
 	}


@@ -49,6 +49,10 @@ func Hash(salter *salt.Salt, raw interface{}) error {
 			s.ClientToken = fn(s.ClientToken)
 		}
 
+		if s.ClientTokenAccessor != "" {
+			s.ClientTokenAccessor = fn(s.ClientTokenAccessor)
+		}
+
 		data, err := HashStructure(s.Data, fn)
 		if err != nil {
 			return err
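The `fn` applied above is the audit backend's salted hash function. A standalone sketch of that style of keyed hashing of a token accessor, assuming an HMAC-SHA256 construction and an `hmac-sha256:` output prefix (both are assumptions for illustration, not taken from this diff):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// hmacAccessor sketches what a salted audit hash does to a token
// accessor: a keyed SHA-256 digest, so logs never contain the raw
// value but identical inputs remain correlatable within one mount.
func hmacAccessor(salt, accessor string) string {
	mac := hmac.New(sha256.New, []byte(salt))
	mac.Write([]byte(accessor))
	return "hmac-sha256:" + hex.EncodeToString(mac.Sum(nil))
}

func main() {
	fmt.Println(hmacAccessor("per-mount-salt", "accessor-uuid"))
}
```

The cache-and-restore dance in the hunk above exists because `Hash` mutates the request in place; when `HMACAccessor` is false, the raw accessor is saved beforehand and written back afterwards.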


@@ -121,7 +121,7 @@ will expire. Defaults to 0 meaning that the the secret_id is of unlimited use.`,
 			"secret_id_ttl": &framework.FieldSchema{
 				Type: framework.TypeDurationSecond,
 				Description: `Duration in seconds after which the issued SecretID should expire. Defaults
-to 0, in which case the value will fall back to the system/mount defaults.`,
+to 0, meaning no expiration.`,
 			},
 			"token_ttl": &framework.FieldSchema{
 				Type: framework.TypeDurationSecond,
@@ -249,7 +249,7 @@ addresses which can perform the login operation`,
 			"secret_id_ttl": &framework.FieldSchema{
 				Type: framework.TypeDurationSecond,
 				Description: `Duration in seconds after which the issued SecretID should expire. Defaults
-to 0, in which case the value will fall back to the system/mount defaults.`,
+to 0, meaning no expiration.`,
 			},
 		},
 		Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -640,7 +640,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
 	}
 
 	// If previousRoleID is still intact, don't create another one
-	if previousRoleID != "" {
+	if previousRoleID != "" && previousRoleID == role.RoleID {
 		return nil
 	}


@@ -111,6 +111,77 @@ func TestAppRole_RoleConstraints(t *testing.T) {
 	}
 }
 
+func TestAppRole_RoleIDUpdate(t *testing.T) {
+	var resp *logical.Response
+	var err error
+	b, storage := createBackendWithStorage(t)
+
+	roleData := map[string]interface{}{
+		"role_id":            "role-id-123",
+		"policies":           "a,b",
+		"secret_id_num_uses": 10,
+		"secret_id_ttl":      300,
+		"token_ttl":          400,
+		"token_max_ttl":      500,
+	}
+	roleReq := &logical.Request{
+		Operation: logical.CreateOperation,
+		Path:      "role/testrole1",
+		Storage:   storage,
+		Data:      roleData,
+	}
+	resp, err = b.HandleRequest(roleReq)
+	if err != nil || (resp != nil && resp.IsError()) {
+		t.Fatalf("err:%v resp:%#v", err, resp)
+	}
+
+	roleIDUpdateReq := &logical.Request{
+		Operation: logical.UpdateOperation,
+		Path:      "role/testrole1/role-id",
+		Storage:   storage,
+		Data: map[string]interface{}{
+			"role_id": "customroleid",
+		},
+	}
+	resp, err = b.HandleRequest(roleIDUpdateReq)
+	if err != nil || (resp != nil && resp.IsError()) {
+		t.Fatalf("err:%v resp:%#v", err, resp)
+	}
+
+	secretIDReq := &logical.Request{
+		Operation: logical.UpdateOperation,
+		Storage:   storage,
+		Path:      "role/testrole1/secret-id",
+	}
+	resp, err = b.HandleRequest(secretIDReq)
+	if err != nil || (resp != nil && resp.IsError()) {
+		t.Fatalf("err:%v resp:%#v", err, resp)
+	}
+	secretID := resp.Data["secret_id"].(string)
+
+	loginData := map[string]interface{}{
+		"role_id":   "customroleid",
+		"secret_id": secretID,
+	}
+	loginReq := &logical.Request{
+		Operation: logical.UpdateOperation,
+		Path:      "login",
+		Storage:   storage,
+		Data:      loginData,
+		Connection: &logical.Connection{
+			RemoteAddr: "127.0.0.1",
+		},
+	}
+	resp, err = b.HandleRequest(loginReq)
+	if err != nil || (resp != nil && resp.IsError()) {
+		t.Fatalf("err:%v resp:%#v", err, resp)
+	}
+	if resp.Auth == nil {
+		t.Fatalf("expected a non-nil auth object in the response")
+	}
+}
+
 func TestAppRole_RoleIDUniqueness(t *testing.T) {
 	var resp *logical.Response
 	var err error


@@ -207,11 +207,6 @@ func (b *backend) nonLockedAWSRole(s logical.Storage, roleName string) (*awsRole
 	// Check if the value held by role ARN field is actually an instance profile ARN
 	if result.BoundIamRoleARN != "" && strings.Contains(result.BoundIamRoleARN, ":instance-profile/") {
-		// For sanity
-		if result.BoundIamInstanceProfileARN != "" {
-			return nil, fmt.Errorf("bound_iam_role_arn contains instance profile ARN and bound_iam_instance_profile_arn is non empty")
-		}
-
 		// If yes, move it to the correct field
 		result.BoundIamInstanceProfileARN = result.BoundIamRoleARN


@@ -14,12 +14,22 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
 func Backend() *backend {
 	var b backend
-	b.Map = &framework.PolicyMap{
+	b.TeamMap = &framework.PolicyMap{
 		PathMap: framework.PathMap{
 			Name: "teams",
 		},
 		DefaultKey: "default",
 	}
 
+	b.UserMap = &framework.PolicyMap{
+		PathMap: framework.PathMap{
+			Name: "users",
+		},
+		DefaultKey: "default",
+	}
+
+	allPaths := append(b.TeamMap.Paths(), b.UserMap.Paths()...)
+
 	b.Backend = &framework.Backend{
 		Help: backendHelp,
@@ -32,7 +42,7 @@ func Backend() *backend {
 		Paths: append([]*framework.Path{
 			pathConfig(&b),
 			pathLogin(&b),
-		}, b.Map.Paths()...),
+		}, allPaths...),
 		AuthRenew: b.pathLoginRenew,
 	}
@@ -43,7 +53,9 @@ func Backend() *backend {
 type backend struct {
 	*framework.Backend
 
-	Map *framework.PolicyMap
+	TeamMap *framework.PolicyMap
+	UserMap *framework.PolicyMap
 }
 
 // Client returns the GitHub client to communicate to GitHub via the


@@ -110,17 +110,21 @@ func TestBackend_basic(t *testing.T) {
 		Backend: b,
 		Steps: []logicaltest.TestStep{
 			testAccStepConfig(t, false),
-			testAccMap(t, "default", "root"),
-			testAccMap(t, "oWnErs", "root"),
-			testAccLogin(t, []string{"root"}),
+			testAccMap(t, "default", "fakepol"),
+			testAccMap(t, "oWnErs", "fakepol"),
+			testAccLogin(t, []string{"default", "fakepol"}),
 			testAccStepConfig(t, true),
-			testAccMap(t, "default", "root"),
-			testAccMap(t, "oWnErs", "root"),
-			testAccLogin(t, []string{"root"}),
+			testAccMap(t, "default", "fakepol"),
+			testAccMap(t, "oWnErs", "fakepol"),
+			testAccLogin(t, []string{"default", "fakepol"}),
 			testAccStepConfigWithBaseURL(t),
-			testAccMap(t, "default", "root"),
-			testAccMap(t, "oWnErs", "root"),
-			testAccLogin(t, []string{"root"}),
+			testAccMap(t, "default", "fakepol"),
+			testAccMap(t, "oWnErs", "fakepol"),
+			testAccLogin(t, []string{"default", "fakepol"}),
+			testAccMap(t, "default", "fakepol"),
+			testAccStepConfig(t, true),
+			mapUserToPolicy(t, os.Getenv("GITHUB_USER"), "userpolicy"),
+			testAccLogin(t, []string{"default", "fakepol", "userpolicy"}),
 		},
 	})
 }
@@ -174,7 +178,17 @@ func testAccMap(t *testing.T, k string, v string) logicaltest.TestStep {
 	}
 }
 
-func testAccLogin(t *testing.T, keys []string) logicaltest.TestStep {
+func mapUserToPolicy(t *testing.T, k string, v string) logicaltest.TestStep {
+	return logicaltest.TestStep{
+		Operation: logical.UpdateOperation,
+		Path:      "map/users/" + k,
+		Data: map[string]interface{}{
+			"value": v,
+		},
+	}
+}
+
+func testAccLogin(t *testing.T, policies []string) logicaltest.TestStep {
 	return logicaltest.TestStep{
 		Operation: logical.UpdateOperation,
 		Path:      "login",
@@ -183,6 +197,6 @@ func testAccLogin(t *testing.T, keys []string) logicaltest.TestStep {
 		},
 		Unauthenticated: true,
-		Check: logicaltest.TestCheckAuth(keys),
+		Check: logicaltest.TestCheckAuth(policies),
 	}
 }


@@ -194,14 +194,22 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
 		}
 	}
 
-	policiesList, err := b.Map.Policies(req.Storage, teamNames...)
+	groupPoliciesList, err := b.TeamMap.Policies(req.Storage, teamNames...)
 	if err != nil {
 		return nil, nil, err
 	}
 
+	userPoliciesList, err := b.UserMap.Policies(req.Storage, []string{*user.Login}...)
+	if err != nil {
+		return nil, nil, err
+	}
+
 	return &verifyCredentialsResp{
 		User:     user,
 		Org:      org,
-		Policies: policiesList,
+		Policies: append(groupPoliciesList, userPoliciesList...),
 	}, nil, nil
 }


@@ -7,6 +7,7 @@ import (
 	"github.com/go-ldap/ldap"
 	"github.com/hashicorp/vault/helper/mfa"
+	"github.com/hashicorp/vault/helper/strutil"
 	"github.com/hashicorp/vault/logical"
 	"github.com/hashicorp/vault/logical/framework"
 )
@@ -158,6 +159,9 @@ func (b *backend) Login(req *logical.Request, username string, password string)
 		}
 	}
 
+	// Policies from each group may overlap
+	policies = strutil.RemoveDuplicates(policies)
+
 	if len(policies) == 0 {
 		errStr := "user is not a member of any authorized group"
 		if len(ldapResponse.Warnings()) > 0 {
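A minimal sketch of what a de-duplication helper like `strutil.RemoveDuplicates` does to the merged group policy list (an illustrative stand-in, not the actual `strutil` implementation, which may differ in normalization and ordering):

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// removeDuplicates illustrates the call above: policies granted via
// several LDAP groups may overlap, so the list is trimmed, de-duplicated,
// and sorted before the length-zero check runs.
func removeDuplicates(items []string) []string {
	seen := make(map[string]struct{})
	out := make([]string, 0, len(items))
	for _, item := range items {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		if _, ok := seen[item]; ok {
			continue
		}
		seen[item] = struct{}{}
		out = append(out, item)
	}
	sort.Strings(out)
	return out
}

func main() {
	fmt.Println(removeDuplicates([]string{"dev", "ops", "dev", " ops "}))
}
```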


@@ -100,6 +100,12 @@ Default: cn`,
 				Default:     "tls12",
 				Description: "Minimum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
 			},
+			"tls_max_version": &framework.FieldSchema{
+				Type:        framework.TypeString,
+				Default:     "tls12",
+				Description: "Maximum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
+			},
 		},
 
 		Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -225,6 +231,19 @@ func (b *backend) newConfigEntry(d *framework.FieldData) (*ConfigEntry, error) {
 		return nil, fmt.Errorf("invalid 'tls_min_version'")
 	}
 
+	cfg.TLSMaxVersion = d.Get("tls_max_version").(string)
+	if cfg.TLSMaxVersion == "" {
+		return nil, fmt.Errorf("failed to get 'tls_max_version' value")
+	}
+
+	_, ok = tlsutil.TLSLookup[cfg.TLSMaxVersion]
+	if !ok {
+		return nil, fmt.Errorf("invalid 'tls_max_version'")
+	}
+	if cfg.TLSMaxVersion < cfg.TLSMinVersion {
+		return nil, fmt.Errorf("'tls_max_version' must be greater than or equal to 'tls_min_version'")
+	}
+
 	startTLS := d.Get("starttls").(bool)
 	if startTLS {
 		cfg.StartTLS = startTLS
@@ -280,6 +299,7 @@ type ConfigEntry struct {
 	BindPassword  string `json:"bindpass" structs:"bindpass" mapstructure:"bindpass"`
 	DiscoverDN    bool   `json:"discoverdn" structs:"discoverdn" mapstructure:"discoverdn"`
 	TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
+	TLSMaxVersion string `json:"tls_max_version" structs:"tls_max_version" mapstructure:"tls_max_version"`
 }
 
 func (c *ConfigEntry) GetTLSConfig(host string) (*tls.Config, error) {
@@ -295,6 +315,14 @@ func (c *ConfigEntry) GetTLSConfig(host string) (*tls.Config, error) {
 		tlsConfig.MinVersion = tlsMinVersion
 	}
 
+	if c.TLSMaxVersion != "" {
+		tlsMaxVersion, ok := tlsutil.TLSLookup[c.TLSMaxVersion]
+		if !ok {
+			return nil, fmt.Errorf("invalid 'tls_max_version' in config")
+		}
+		tlsConfig.MaxVersion = tlsMaxVersion
+	}
+
 	if c.InsecureTLS {
 		tlsConfig.InsecureSkipVerify = true
 	}
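The validation above leans on `tlsutil.TLSLookup` and on the fact that the `tls10`/`tls11`/`tls12` keys sort lexically in version order, which is why the plain string comparison works as a range check. A self-contained sketch with an illustrative table standing in for `tlsutil.TLSLookup`:

```go
package main

import (
	"crypto/tls"
	"fmt"
)

// tlsLookup stands in for the kind of table tlsutil.TLSLookup provides.
// The string keys happen to sort in version order, so comparing the
// raw strings is enough to reject max < min.
var tlsLookup = map[string]uint16{
	"tls10": tls.VersionTLS10,
	"tls11": tls.VersionTLS11,
	"tls12": tls.VersionTLS12,
}

func tlsRange(minStr, maxStr string) (*tls.Config, error) {
	minVer, ok := tlsLookup[minStr]
	if !ok {
		return nil, fmt.Errorf("invalid 'tls_min_version'")
	}
	maxVer, ok := tlsLookup[maxStr]
	if !ok {
		return nil, fmt.Errorf("invalid 'tls_max_version'")
	}
	if maxStr < minStr { // lexical order matches version order here
		return nil, fmt.Errorf("'tls_max_version' must be greater than or equal to 'tls_min_version'")
	}
	return &tls.Config{MinVersion: minVer, MaxVersion: maxVer}, nil
}

func main() {
	cfg, err := tlsRange("tls11", "tls12")
	fmt.Println(cfg.MinVersion, cfg.MaxVersion, err)
}
```

Setting `tls.Config.MaxVersion` caps negotiation the same way `MinVersion` floors it; leaving it zero lets the Go runtime pick its default maximum.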


@@ -14,6 +14,7 @@ func Backend() *backend {
 	b.Backend = &framework.Backend{
 		Paths: []*framework.Path{
 			pathConfigAccess(),
+			pathListRoles(&b),
 			pathRoles(),
 			pathToken(&b),
 		},


@@ -293,7 +293,10 @@ func TestBackend_crud(t *testing.T) {
 		Backend: b,
 		Steps: []logicaltest.TestStep{
 			testAccStepWritePolicy(t, "test", testPolicy, ""),
+			testAccStepWritePolicy(t, "test2", testPolicy, ""),
+			testAccStepWritePolicy(t, "test3", testPolicy, ""),
 			testAccStepReadPolicy(t, "test", testPolicy, 0),
+			testAccStepListPolicy(t, []string{"test", "test2", "test3"}),
 			testAccStepDeletePolicy(t, "test"),
 		},
 	})
@@ -443,6 +446,20 @@ func testAccStepReadPolicy(t *testing.T, name string, policy string, lease time.
 	}
 }
 
+func testAccStepListPolicy(t *testing.T, names []string) logicaltest.TestStep {
+	return logicaltest.TestStep{
+		Operation: logical.ListOperation,
+		Path:      "roles/",
+		Check: func(resp *logical.Response) error {
+			respKeys := resp.Data["keys"].([]string)
+			if !reflect.DeepEqual(respKeys, names) {
+				return fmt.Errorf("mismatch: %#v %#v", respKeys, names)
+			}
+			return nil
+		},
+	}
+}
+
 func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
 	return logicaltest.TestStep{
 		Operation: logical.DeleteOperation,


@@ -9,6 +9,16 @@ import (
 	"github.com/hashicorp/vault/logical/framework"
 )
 
+func pathListRoles(b *backend) *framework.Path {
+	return &framework.Path{
+		Pattern: "roles/?$",
+
+		Callbacks: map[logical.Operation]framework.OperationFunc{
+			logical.ListOperation: b.pathRoleList,
+		},
+	}
+}
+
 func pathRoles() *framework.Path {
 	return &framework.Path{
 		Pattern: "roles/" + framework.GenericNameRegex("name"),
@@ -47,6 +57,16 @@ Defaults to 'client'.`,
 	}
 }
 
+func (b *backend) pathRoleList(
+	req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
+	entries, err := req.Storage.List("policy/")
+	if err != nil {
+		return nil, err
+	}
+
+	return logical.ListResponse(entries), nil
+}
+
 func pathRolesRead(
 	req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
 	name := d.Get("name").(string)


@@ -55,8 +55,9 @@ func (b *backend) pathTokenRead(
 		return logical.ErrorResponse(err.Error()), nil
 	}
 
-	// Generate a random name for the token
-	tokenName := fmt.Sprintf("Vault %s %d", req.DisplayName, time.Now().Unix())
+	// Generate a name for the token
+	tokenName := fmt.Sprintf("Vault %s %s %d", name, req.DisplayName, time.Now().UnixNano())
 
 	// Create it
 	token, _, err := c.ACL().Create(&api.ACLEntry{
 		Name: tokenName,

@@ -37,8 +37,11 @@ func pathRoles(b *backend) *framework.Path {
 			},
 
 			"revocation_sql": {
 				Type: framework.TypeString,
-				Description: "SQL string to revoke a user. This is in beta; use with caution.",
+				Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
+string, a base64-encoded semicolon-separated string, a serialized JSON string
+array, or a base64-encoded serialized JSON string array. The '{{name}}' value
+will be substituted.`,
 			},
 		},
@@ -90,19 +93,12 @@ func (b *backend) pathRoleRead(
 		return nil, nil
 	}
 
-	resp := &logical.Response{
+	return &logical.Response{
 		Data: map[string]interface{}{
-			"sql": role.SQL,
+			"sql":            role.SQL,
+			"revocation_sql": role.RevocationSQL,
 		},
-	}
-
-	// TODO: This is separate because this is in beta in 0.6.2, so we don't
-	// want it to show up in the general case.
-	if role.RevocationSQL != "" {
-		resp.Data["revocation_sql"] = role.RevocationSQL
-	}
-
-	return resp, nil
+	}, nil
 }
 
 func (b *backend) pathRoleList(
@@ -193,4 +189,12 @@ Example of a decent SQL query to use:
 
 Note the above user would be able to access everything in schema public.
 For more complex GRANT clauses, see the PostgreSQL manual.
+
+The "revocation_sql" parameter customizes the SQL string used to revoke a user.
+Example of a decent revocation SQL query to use:
+
+	REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
+	REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
+	REVOKE USAGE ON SCHEMA public FROM {{name}};
+	DROP ROLE IF EXISTS {{name}};
 `
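A sketch of the `{{name}}` substitution that `revocation_sql` describes, for the plain semicolon-separated form only (illustrative; per the description above the backend also accepts base64-encoded and JSON-array encodings, which are not handled here):

```go
package main

import (
	"fmt"
	"strings"
)

// renderRevocationSQL splits a semicolon-separated revocation_sql
// string into statements, drops empty fragments, and substitutes the
// '{{name}}' placeholder with the database username being revoked.
func renderRevocationSQL(raw, username string) []string {
	var stmts []string
	for _, stmt := range strings.Split(raw, ";") {
		stmt = strings.TrimSpace(stmt)
		if stmt == "" {
			continue
		}
		stmts = append(stmts, strings.Replace(stmt, "{{name}}", username, -1))
	}
	return stmts
}

func main() {
	raw := `REVOKE USAGE ON SCHEMA public FROM {{name}}; DROP ROLE IF EXISTS {{name}};`
	for _, s := range renderRevocationSQL(raw, "vault-user-abc") {
		fmt.Println(s)
	}
}
```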


@@ -7,8 +7,8 @@ const (
 #!/bin/bash
 #
 # This is a default script which installs or uninstalls an RSA public key to/from
-# authoried_keys file in a typical linux machine.
+# authorized_keys file in a typical linux machine.
 #
 # If the platform differs or if the binaries used in this script are not available
 # in target machine, use the 'install_script' parameter with 'roles/' endpoint to
 # register a custom script (applicable for Dynamic type only).
@@ -51,7 +51,7 @@ fi
 # Create the .ssh directory and authorized_keys file if it does not exist
 SSH_DIR=$(dirname $AUTH_KEYS_FILE)
 sudo mkdir -p "$SSH_DIR"
 sudo touch "$AUTH_KEYS_FILE"
 
 # Remove the key from authorized_keys file if it is already present.
@@ -72,7 +72,7 @@ func (b *backend) pathConfigZeroAddressWrite(req *logical.Request, d *framework.
return nil, err return nil, err
} }
if role == nil { if role == nil {
return logical.ErrorResponse(fmt.Sprintf("Role [%s] does not exist", item)), nil return logical.ErrorResponse(fmt.Sprintf("Role %q does not exist", item)), nil
} }
} }
@@ -55,10 +55,10 @@ func (b *backend) pathCredsCreateWrite(
role, err := b.getRole(req.Storage, roleName) role, err := b.getRole(req.Storage, roleName)
if err != nil { if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err) return nil, fmt.Errorf("error retrieving role: %v", err)
} }
if role == nil { if role == nil {
return logical.ErrorResponse(fmt.Sprintf("Role '%s' not found", roleName)), nil return logical.ErrorResponse(fmt.Sprintf("Role %q not found", roleName)), nil
} }
// username is an optional parameter. // username is an optional parameter.
@@ -89,7 +89,7 @@ func (b *backend) pathCredsCreateWrite(
// Validate the IP address // Validate the IP address
ipAddr := net.ParseIP(ipRaw) ipAddr := net.ParseIP(ipRaw)
if ipAddr == nil { if ipAddr == nil {
return logical.ErrorResponse(fmt.Sprintf("Invalid IP '%s'", ipRaw)), nil return logical.ErrorResponse(fmt.Sprintf("Invalid IP %q", ipRaw)), nil
} }
// Check if the IP belongs to the registered list of CIDR blocks under the role // Check if the IP belongs to the registered list of CIDR blocks under the role
@@ -97,7 +97,7 @@ func (b *backend) pathCredsCreateWrite(
zeroAddressEntry, err := b.getZeroAddressRoles(req.Storage) zeroAddressEntry, err := b.getZeroAddressRoles(req.Storage)
if err != nil { if err != nil {
return nil, fmt.Errorf("error retrieving zero-address roles: %s", err) return nil, fmt.Errorf("error retrieving zero-address roles: %v", err)
} }
var zeroAddressRoles []string var zeroAddressRoles []string
if zeroAddressEntry != nil { if zeroAddressEntry != nil {
@@ -106,7 +106,7 @@ func (b *backend) pathCredsCreateWrite(
err = validateIP(ip, roleName, role.CIDRList, role.ExcludeCIDRList, zeroAddressRoles) err = validateIP(ip, roleName, role.CIDRList, role.ExcludeCIDRList, zeroAddressRoles)
if err != nil { if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error validating IP: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("Error validating IP: %v", err)), nil
} }
var result *logical.Response var result *logical.Response
@@ -171,22 +171,22 @@ func (b *backend) GenerateDynamicCredential(req *logical.Request, role *sshRole,
// Fetch the host key to be used for dynamic key installation // Fetch the host key to be used for dynamic key installation
keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", role.KeyName)) keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", role.KeyName))
if err != nil { if err != nil {
return "", "", fmt.Errorf("key '%s' not found. err:%s", role.KeyName, err) return "", "", fmt.Errorf("key %q not found. err: %v", role.KeyName, err)
} }
if keyEntry == nil { if keyEntry == nil {
return "", "", fmt.Errorf("key '%s' not found", role.KeyName) return "", "", fmt.Errorf("key %q not found", role.KeyName)
} }
var hostKey sshHostKey var hostKey sshHostKey
if err := keyEntry.DecodeJSON(&hostKey); err != nil { if err := keyEntry.DecodeJSON(&hostKey); err != nil {
return "", "", fmt.Errorf("error reading the host key: %s", err) return "", "", fmt.Errorf("error reading the host key: %v", err)
} }
// Generate a new RSA key pair with the given key length. // Generate a new RSA key pair with the given key length.
dynamicPublicKey, dynamicPrivateKey, err := generateRSAKeys(role.KeyBits) dynamicPublicKey, dynamicPrivateKey, err := generateRSAKeys(role.KeyBits)
if err != nil { if err != nil {
return "", "", fmt.Errorf("error generating key: %s", err) return "", "", fmt.Errorf("error generating key: %v", err)
} }
if len(role.KeyOptionSpecs) != 0 { if len(role.KeyOptionSpecs) != 0 {
@@ -196,7 +196,7 @@ func (b *backend) GenerateDynamicCredential(req *logical.Request, role *sshRole,
// Add the public key to authorized_keys file in target machine // Add the public key to authorized_keys file in target machine
err = b.installPublicKeyInTarget(role.AdminUser, username, ip, role.Port, hostKey.Key, dynamicPublicKey, role.InstallScript, true) err = b.installPublicKeyInTarget(role.AdminUser, username, ip, role.Port, hostKey.Key, dynamicPublicKey, role.InstallScript, true)
if err != nil { if err != nil {
return "", "", fmt.Errorf("error adding public key to authorized_keys file in target") return "", "", fmt.Errorf("failed to add public key to authorized_keys file in target: %v", err)
} }
return dynamicPublicKey, dynamicPrivateKey, nil return dynamicPublicKey, dynamicPrivateKey, nil
} }
@@ -32,7 +32,7 @@ func (b *backend) pathLookupWrite(req *logical.Request, d *framework.FieldData)
} }
ip := net.ParseIP(ipAddr) ip := net.ParseIP(ipAddr)
if ip == nil { if ip == nil {
return logical.ErrorResponse(fmt.Sprintf("Invalid IP '%s'", ip.String())), nil return logical.ErrorResponse(fmt.Sprintf("Invalid IP %q", ip.String())), nil
} }
// Get all the roles created in the backend. // Get all the roles created in the backend.
@@ -232,7 +232,7 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
} }
keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", keyName)) keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", keyName))
if err != nil || keyEntry == nil { if err != nil || keyEntry == nil {
return logical.ErrorResponse(fmt.Sprintf("invalid 'key': '%s'", keyName)), nil return logical.ErrorResponse(fmt.Sprintf("invalid 'key': %q", keyName)), nil
} }
installScript := d.Get("install_script").(string) installScript := d.Get("install_script").(string)
@@ -55,10 +55,10 @@ func (b *backend) secretDynamicKeyRevoke(req *logical.Request, d *framework.Fiel
// Fetch the host key using the key name // Fetch the host key using the key name
hostKey, err := b.getKey(req.Storage, intSec.HostKeyName) hostKey, err := b.getKey(req.Storage, intSec.HostKeyName)
if err != nil { if err != nil {
return nil, fmt.Errorf("key '%s' not found error:%s", intSec.HostKeyName, err) return nil, fmt.Errorf("key %q not found error: %v", intSec.HostKeyName, err)
} }
if hostKey == nil { if hostKey == nil {
return nil, fmt.Errorf("key '%s' not found", intSec.HostKeyName) return nil, fmt.Errorf("key %q not found", intSec.HostKeyName)
} }

@@ -23,7 +23,7 @@ import (
func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, err error) { func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, err error) {
privateKey, err := rsa.GenerateKey(rand.Reader, keyBits) privateKey, err := rsa.GenerateKey(rand.Reader, keyBits)
if err != nil { if err != nil {
return "", "", fmt.Errorf("error generating RSA key-pair: %s", err) return "", "", fmt.Errorf("error generating RSA key-pair: %v", err)
} }
privateKeyRsa = string(pem.EncodeToMemory(&pem.Block{ privateKeyRsa = string(pem.EncodeToMemory(&pem.Block{
@@ -33,7 +33,7 @@ func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, er
sshPublicKey, err := ssh.NewPublicKey(privateKey.Public()) sshPublicKey, err := ssh.NewPublicKey(privateKey.Public())
if err != nil { if err != nil {
return "", "", fmt.Errorf("error generating RSA key-pair: %s", err) return "", "", fmt.Errorf("error generating RSA key-pair: %v", err)
} }
publicKeyRsa = "ssh-rsa " + base64.StdEncoding.EncodeToString(sshPublicKey.Marshal()) publicKeyRsa = "ssh-rsa " + base64.StdEncoding.EncodeToString(sshPublicKey.Marshal())
return return
@@ -61,7 +61,7 @@ func (b *backend) installPublicKeyInTarget(adminUser, username, ip string, port
err = comm.Upload(publicKeyFileName, bytes.NewBufferString(dynamicPublicKey), nil) err = comm.Upload(publicKeyFileName, bytes.NewBufferString(dynamicPublicKey), nil)
if err != nil { if err != nil {
return fmt.Errorf("error uploading public key: %s", err) return fmt.Errorf("error uploading public key: %v", err)
} }
// Transfer the script required to install or uninstall the key to the remote // Transfer the script required to install or uninstall the key to the remote
@@ -70,14 +70,14 @@ func (b *backend) installPublicKeyInTarget(adminUser, username, ip string, port
scriptFileName := fmt.Sprintf("%s.sh", publicKeyFileName) scriptFileName := fmt.Sprintf("%s.sh", publicKeyFileName)
err = comm.Upload(scriptFileName, bytes.NewBufferString(installScript), nil) err = comm.Upload(scriptFileName, bytes.NewBufferString(installScript), nil)
if err != nil { if err != nil {
return fmt.Errorf("error uploading install script: %s", err) return fmt.Errorf("error uploading install script: %v", err)
} }
// Create a session to run remote command that triggers the script to install // Create a session to run remote command that triggers the script to install
// or uninstall the key. // or uninstall the key.
session, err := comm.NewSession() session, err := comm.NewSession()
if err != nil { if err != nil {
return fmt.Errorf("unable to create SSH Session using public keys: %s", err) return fmt.Errorf("unable to create SSH Session using public keys: %v", err)
} }
if session == nil { if session == nil {
return fmt.Errorf("invalid session object") return fmt.Errorf("invalid session object")
@@ -116,15 +116,15 @@ func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error)
roleEntry, err := s.Get(fmt.Sprintf("roles/%s", roleName)) roleEntry, err := s.Get(fmt.Sprintf("roles/%s", roleName))
if err != nil { if err != nil {
return false, fmt.Errorf("error retrieving role '%s'", err) return false, fmt.Errorf("error retrieving role %v", err)
} }
if roleEntry == nil { if roleEntry == nil {
return false, fmt.Errorf("role '%s' not found", roleName) return false, fmt.Errorf("role %q not found", roleName)
} }
var role sshRole var role sshRole
if err := roleEntry.DecodeJSON(&role); err != nil { if err := roleEntry.DecodeJSON(&role); err != nil {
return false, fmt.Errorf("error decoding role '%s'", roleName) return false, fmt.Errorf("error decoding role %q", roleName)
} }
if matched, err := cidrListContainsIP(ip, role.CIDRList); err != nil { if matched, err := cidrListContainsIP(ip, role.CIDRList); err != nil {
@@ -143,7 +143,7 @@ func cidrListContainsIP(ip, cidrList string) (bool, error) {
for _, item := range strings.Split(cidrList, ",") { for _, item := range strings.Split(cidrList, ",") {
_, cidrIPNet, err := net.ParseCIDR(item) _, cidrIPNet, err := net.ParseCIDR(item)
if err != nil { if err != nil {
return false, fmt.Errorf("invalid CIDR entry '%s'", item) return false, fmt.Errorf("invalid CIDR entry %q", item)
} }
if cidrIPNet.Contains(net.ParseIP(ip)) { if cidrIPNet.Contains(net.ParseIP(ip)) {
return true, nil return true, nil
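The CIDR membership check touched in this hunk can be sketched as a standalone program (same logic as the backend helper, minus storage):

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// cidrListContainsIP reports whether any comma-separated CIDR block in
// cidrList contains ip; an unparsable entry is an error.
func cidrListContainsIP(ip, cidrList string) (bool, error) {
	for _, item := range strings.Split(cidrList, ",") {
		_, ipnet, err := net.ParseCIDR(strings.TrimSpace(item))
		if err != nil {
			return false, fmt.Errorf("invalid CIDR entry %q", item)
		}
		if ipnet.Contains(net.ParseIP(ip)) {
			return true, nil
		}
	}
	return false, nil
}

func main() {
	ok, _ := cidrListContainsIP("10.0.1.5", "10.0.0.0/16,192.168.0.0/24")
	fmt.Println(ok)
}
```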
@@ -1,6 +1,7 @@
package transit package transit
import ( import (
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@@ -25,6 +26,7 @@ func Backend(conf *logical.BackendConfig) *backend {
b.pathRotate(), b.pathRotate(),
b.pathRewrap(), b.pathRewrap(),
b.pathKeys(), b.pathKeys(),
b.pathListKeys(),
b.pathEncrypt(), b.pathEncrypt(),
b.pathDecrypt(), b.pathDecrypt(),
b.pathDatakey(), b.pathDatakey(),
@@ -38,12 +40,12 @@ func Backend(conf *logical.BackendConfig) *backend {
Secrets: []*framework.Secret{}, Secrets: []*framework.Secret{},
} }
b.lm = newLockManager(conf.System.CachingDisabled()) b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
return &b return &b
} }
type backend struct { type backend struct {
*framework.Backend *framework.Backend
lm *lockManager lm *keysutil.LockManager
} }
@@ -12,6 +12,7 @@ import (
"time" "time"
uuid "github.com/hashicorp/go-uuid" uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
logicaltest "github.com/hashicorp/vault/logical/testing" logicaltest "github.com/hashicorp/vault/logical/testing"
@@ -27,7 +28,9 @@ func TestBackend_basic(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory, Factory: Factory,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false), testAccStepWritePolicy(t, "test", false),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false), testAccStepReadPolicy(t, "test", false, false),
testAccStepEncrypt(t, "test", testPlaintext, decryptData), testAccStepEncrypt(t, "test", testPlaintext, decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData), testAccStepDecrypt(t, "test", testPlaintext, decryptData),
@@ -53,7 +56,9 @@ func TestBackend_upsert(t *testing.T) {
Factory: Factory, Factory: Factory,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testAccStepReadPolicy(t, "test", true, false), testAccStepReadPolicy(t, "test", true, false),
testAccStepListPolicy(t, "test", true),
testAccStepEncryptUpsert(t, "test", testPlaintext, decryptData), testAccStepEncryptUpsert(t, "test", testPlaintext, decryptData),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false), testAccStepReadPolicy(t, "test", false, false),
testAccStepDecrypt(t, "test", testPlaintext, decryptData), testAccStepDecrypt(t, "test", testPlaintext, decryptData),
}, },
@@ -65,7 +70,9 @@ func TestBackend_datakey(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory, Factory: Factory,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false), testAccStepWritePolicy(t, "test", false),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false), testAccStepReadPolicy(t, "test", false, false),
testAccStepWriteDatakey(t, "test", false, 256, dataKeyInfo), testAccStepWriteDatakey(t, "test", false, 256, dataKeyInfo),
testAccStepDecryptDatakey(t, "test", dataKeyInfo), testAccStepDecryptDatakey(t, "test", dataKeyInfo),
@@ -80,7 +87,9 @@ func TestBackend_rotation(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory, Factory: Factory,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false), testAccStepWritePolicy(t, "test", false),
testAccStepListPolicy(t, "test", false),
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory), testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory),
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 1, encryptHistory), testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 1, encryptHistory),
testAccStepRotate(t, "test"), // now v2 testAccStepRotate(t, "test"), // now v2
@@ -128,6 +137,7 @@ func TestBackend_rotation(t *testing.T) {
testAccStepEnableDeletion(t, "test"), testAccStepEnableDeletion(t, "test"),
testAccStepDeletePolicy(t, "test"), testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true, false), testAccStepReadPolicy(t, "test", true, false),
testAccStepListPolicy(t, "test", true),
}, },
}) })
} }
@@ -137,7 +147,9 @@ func TestBackend_basic_derived(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory, Factory: Factory,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", true), testAccStepWritePolicy(t, "test", true),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, true), testAccStepReadPolicy(t, "test", false, true),
testAccStepEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData), testAccStepEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData), testAccStepDecrypt(t, "test", testPlaintext, decryptData),
@@ -158,6 +170,42 @@ func testAccStepWritePolicy(t *testing.T, name string, derived bool) logicaltest
} }
} }
func testAccStepListPolicy(t *testing.T, name string, expectNone bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ListOperation,
Path: "keys",
Check: func(resp *logical.Response) error {
if resp == nil {
return fmt.Errorf("missing response")
}
if expectNone {
keysRaw, ok := resp.Data["keys"]
if ok || keysRaw != nil {
return fmt.Errorf("response data when expecting none")
}
return nil
}
if len(resp.Data) == 0 {
return fmt.Errorf("no data returned")
}
var d struct {
Keys []string `mapstructure:"keys"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if len(d.Keys) > 0 && d.Keys[0] != name {
return fmt.Errorf("bad name: %#v", d)
}
if len(d.Keys) != 1 {
return fmt.Errorf("only 1 key expected, %d returned", len(d.Keys))
}
return nil
},
}
}
func testAccStepAdjustPolicy(t *testing.T, name string, minVer int) logicaltest.TestStep { func testAccStepAdjustPolicy(t *testing.T, name string, minVer int) logicaltest.TestStep {
return logicaltest.TestStep{ return logicaltest.TestStep{
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
@@ -242,7 +290,7 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool)
if d.Name != name { if d.Name != name {
return fmt.Errorf("bad name: %#v", d) return fmt.Errorf("bad name: %#v", d)
} }
if d.Type != KeyType(keyType_AES256_GCM96).String() { if d.Type != keysutil.KeyType(keysutil.KeyType_AES256_GCM96).String() {
return fmt.Errorf("bad key type: %#v", d) return fmt.Errorf("bad key type: %#v", d)
} }
// Should NOT get a key back // Should NOT get a key back
@@ -536,13 +584,13 @@ func testAccStepDecryptDatakey(t *testing.T, name string,
func TestKeyUpgrade(t *testing.T) { func TestKeyUpgrade(t *testing.T) {
key, _ := uuid.GenerateRandomBytes(32) key, _ := uuid.GenerateRandomBytes(32)
p := &policy{ p := &keysutil.Policy{
Name: "test", Name: "test",
Key: key, Key: key,
Type: keyType_AES256_GCM96, Type: keysutil.KeyType_AES256_GCM96,
} }
p.migrateKeyToKeysMap() p.MigrateKeyToKeysMap()
if p.Key != nil || if p.Key != nil ||
p.Keys == nil || p.Keys == nil ||
@@ -557,18 +605,18 @@ func TestDerivedKeyUpgrade(t *testing.T) {
key, _ := uuid.GenerateRandomBytes(32) key, _ := uuid.GenerateRandomBytes(32)
context, _ := uuid.GenerateRandomBytes(32) context, _ := uuid.GenerateRandomBytes(32)
p := &policy{ p := &keysutil.Policy{
Name: "test", Name: "test",
Key: key, Key: key,
Type: keyType_AES256_GCM96, Type: keysutil.KeyType_AES256_GCM96,
Derived: true, Derived: true,
} }
p.migrateKeyToKeysMap() p.MigrateKeyToKeysMap()
p.upgrade(storage) // Need to run the upgrade code to make the migration stick p.Upgrade(storage) // Need to run the upgrade code to make the migration stick
if p.KDF != kdf_hmac_sha256_counter { if p.KDF != keysutil.Kdf_hmac_sha256_counter {
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", kdf_hmac_sha256_counter, p.KDF, *p) t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", keysutil.Kdf_hmac_sha256_counter, p.KDF, *p)
} }
derBytesOld, err := p.DeriveKey(context, 1) derBytesOld, err := p.DeriveKey(context, 1)
@@ -585,8 +633,8 @@ func TestDerivedKeyUpgrade(t *testing.T) {
t.Fatal("mismatch of same context alg") t.Fatal("mismatch of same context alg")
} }
p.KDF = kdf_hkdf_sha256 p.KDF = keysutil.Kdf_hkdf_sha256
if p.needsUpgrade() { if p.NeedsUpgrade() {
t.Fatal("expected no upgrade needed") t.Fatal("expected no upgrade needed")
} }
@@ -645,15 +693,15 @@ func testConvergentEncryptionCommon(t *testing.T, ver int) {
t.Fatalf("bad: expected error response, got %#v", *resp) t.Fatalf("bad: expected error response, got %#v", *resp)
} }
p := &policy{ p := &keysutil.Policy{
Name: "testkey", Name: "testkey",
Type: keyType_AES256_GCM96, Type: keysutil.KeyType_AES256_GCM96,
Derived: true, Derived: true,
ConvergentEncryption: true, ConvergentEncryption: true,
ConvergentVersion: ver, ConvergentVersion: ver,
} }
err = p.rotate(storage) err = p.Rotate(storage)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -929,7 +977,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {
resp, err := be.pathDecryptWrite(req, fd) resp, err := be.pathDecryptWrite(req, fd)
if err != nil { if err != nil {
// This could well happen since the min version is jumping around // This could well happen since the min version is jumping around
if resp.Data["error"].(string) == ErrTooOld { if resp.Data["error"].(string) == keysutil.ErrTooOld {
continue continue
} }
t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id) t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id)
@@ -6,6 +6,7 @@ import (
"sync" "sync"
"github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@@ -116,7 +117,7 @@ func (b *backend) pathEncryptWrite(
} }
// Get the policy // Get the policy
var p *policy var p *keysutil.Policy
var lock *sync.RWMutex var lock *sync.RWMutex
var upserted bool var upserted bool
if req.Operation == logical.CreateOperation { if req.Operation == logical.CreateOperation {
@@ -125,17 +126,17 @@ func (b *backend) pathEncryptWrite(
return logical.ErrorResponse("convergent encryption requires derivation to be enabled, so context is required"), nil return logical.ErrorResponse("convergent encryption requires derivation to be enabled, so context is required"), nil
} }
polReq := policyRequest{ polReq := keysutil.PolicyRequest{
storage: req.Storage, Storage: req.Storage,
name: name, Name: name,
derived: len(context) != 0, Derived: len(context) != 0,
convergent: convergent, Convergent: convergent,
} }
keyType := d.Get("type").(string) keyType := d.Get("type").(string)
switch keyType { switch keyType {
case "aes256-gcm96": case "aes256-gcm96":
polReq.keyType = keyType_AES256_GCM96 polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "ecdsa-p256": case "ecdsa-p256":
return logical.ErrorResponse(fmt.Sprintf("key type %v not supported for this operation", keyType)), logical.ErrInvalidRequest return logical.ErrorResponse(fmt.Sprintf("key type %v not supported for this operation", keyType)), logical.ErrInvalidRequest
default: default:
@@ -124,7 +124,7 @@ func TestTransit_HMAC(t *testing.T) {
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
// Rotate // Rotate
err = p.rotate(storage) err = p.Rotate(storage)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -5,10 +5,24 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
func (b *backend) pathListKeys() *framework.Path {
return &framework.Path{
Pattern: "keys/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathKeysList,
},
HelpSynopsis: pathPolicyHelpSyn,
HelpDescription: pathPolicyHelpDesc,
}
}
func (b *backend) pathKeys() *framework.Path { func (b *backend) pathKeys() *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: "keys/" + framework.GenericNameRegex("name"), Pattern: "keys/" + framework.GenericNameRegex("name"),
@@ -61,6 +75,16 @@ impact the ciphertext's security.`,
} }
} }
func (b *backend) pathKeysList(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("policy/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
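The new list handler returns the stored policy names via `logical.ListResponse`. A loose, self-contained mimic of the response shape it produces (this is an illustration, not the framework's actual type):

```go
package main

import "fmt"

// listResponse loosely mimics the shape logical.ListResponse produces:
// the entries come back under a "keys" field in the response data.
func listResponse(entries []string) map[string]interface{} {
	return map[string]interface{}{"keys": entries}
}

func main() {
	resp := listResponse([]string{"test", "other"})
	fmt.Println(resp["keys"])
}
```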
func (b *backend) pathPolicyWrite( func (b *backend) pathPolicyWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) { req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string) name := d.Get("name").(string)
@@ -72,17 +96,17 @@ func (b *backend) pathPolicyWrite(
return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
} }
polReq := policyRequest{ polReq := keysutil.PolicyRequest{
storage: req.Storage, Storage: req.Storage,
name: name, Name: name,
derived: derived, Derived: derived,
convergent: convergent, Convergent: convergent,
} }
switch keyType { switch keyType {
case "aes256-gcm96": case "aes256-gcm96":
polReq.keyType = keyType_AES256_GCM96 polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "ecdsa-p256": case "ecdsa-p256":
polReq.keyType = keyType_ECDSA_P256 polReq.KeyType = keysutil.KeyType_ECDSA_P256
default: default:
return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest
} }
@@ -135,10 +159,10 @@ func (b *backend) pathPolicyRead(
if p.Derived { if p.Derived {
switch p.KDF { switch p.KDF {
case kdf_hmac_sha256_counter: case keysutil.Kdf_hmac_sha256_counter:
resp.Data["kdf"] = "hmac-sha256-counter" resp.Data["kdf"] = "hmac-sha256-counter"
resp.Data["kdf_mode"] = "hmac-sha256-counter" resp.Data["kdf_mode"] = "hmac-sha256-counter"
case kdf_hkdf_sha256: case keysutil.Kdf_hkdf_sha256:
resp.Data["kdf"] = "hkdf_sha256" resp.Data["kdf"] = "hkdf_sha256"
} }
resp.Data["convergent_encryption"] = p.ConvergentEncryption resp.Data["convergent_encryption"] = p.ConvergentEncryption
@@ -148,14 +172,14 @@ func (b *backend) pathPolicyRead(
} }
switch p.Type { switch p.Type {
case keyType_AES256_GCM96: case keysutil.KeyType_AES256_GCM96:
retKeys := map[string]int64{} retKeys := map[string]int64{}
for k, v := range p.Keys { for k, v := range p.Keys {
retKeys[strconv.Itoa(k)] = v.CreationTime retKeys[strconv.Itoa(k)] = v.CreationTime
} }
resp.Data["keys"] = retKeys resp.Data["keys"] = retKeys
case keyType_ECDSA_P256: case keysutil.KeyType_ECDSA_P256:
type ecdsaKey struct { type ecdsaKey struct {
Name string `json:"name"` Name string `json:"name"`
PublicKey string `json:"public_key"` PublicKey string `json:"public_key"`
@@ -41,7 +41,7 @@ func (b *backend) pathRotateWrite(
} }
// Rotate the policy // Rotate the policy
err = p.rotate(req.Storage) err = p.Rotate(req.Storage)
return nil, err return nil, err
} }
@@ -177,11 +177,11 @@ func TestTransit_SignVerify(t *testing.T) {
signRequest(req, true, "") signRequest(req, true, "")
// Rotate and set min decryption version // Rotate and set min decryption version
err = p.rotate(storage) err = p.Rotate(storage)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = p.rotate(storage) err = p.Rotate(storage)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -820,6 +820,8 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
cfg.CheckManager.Check.ForceMetricActivation = telConfig.CirconusCheckForceMetricActivation cfg.CheckManager.Check.ForceMetricActivation = telConfig.CirconusCheckForceMetricActivation
cfg.CheckManager.Check.InstanceID = telConfig.CirconusCheckInstanceID cfg.CheckManager.Check.InstanceID = telConfig.CirconusCheckInstanceID
cfg.CheckManager.Check.SearchTag = telConfig.CirconusCheckSearchTag cfg.CheckManager.Check.SearchTag = telConfig.CirconusCheckSearchTag
cfg.CheckManager.Check.DisplayName = telConfig.CirconusCheckDisplayName
cfg.CheckManager.Check.Tags = telConfig.CirconusCheckTags
cfg.CheckManager.Broker.ID = telConfig.CirconusBrokerID cfg.CheckManager.Broker.ID = telConfig.CirconusBrokerID
cfg.CheckManager.Broker.SelectTag = telConfig.CirconusBrokerSelectTag cfg.CheckManager.Broker.SelectTag = telConfig.CirconusBrokerSelectTag
@@ -149,6 +149,13 @@ type Telemetry struct {
// narrow down the search results when neither a Submission URL or Check ID is provided. // narrow down the search results when neither a Submission URL or Check ID is provided.
// Default: service:app (e.g. service:consul) // Default: service:app (e.g. service:consul)
CirconusCheckSearchTag string `hcl:"circonus_check_search_tag"` CirconusCheckSearchTag string `hcl:"circonus_check_search_tag"`
// CirconusCheckTags is a comma separated list of tags to apply to the check. Note that
// the value of CirconusCheckSearchTag will always be added to the check.
// Default: none
CirconusCheckTags string `mapstructure:"circonus_check_tags"`
// CirconusCheckDisplayName is the name for the check which will be displayed in the Circonus UI.
// Default: value of CirconusCheckInstanceID
CirconusCheckDisplayName string `mapstructure:"circonus_check_display_name"`
// CirconusBrokerID is an explicit broker to use when creating a new check. The numeric portion // CirconusBrokerID is an explicit broker to use when creating a new check. The numeric portion
// of broker._cid. If metric management is enabled and neither a Submission URL nor Check ID // of broker._cid. If metric management is enabled and neither a Submission URL nor Check ID
// is provided, an attempt will be made to search for an existing check using Instance ID and // is provided, an attempt will be made to search for an existing check using Instance ID and
@@ -597,6 +604,8 @@ func parseTelemetry(result *Config, list *ast.ObjectList) error {
"circonus_check_force_metric_activation", "circonus_check_force_metric_activation",
"circonus_check_instance_id", "circonus_check_instance_id",
"circonus_check_search_tag", "circonus_check_search_tag",
"circonus_check_display_name",
"circonus_check_tags",
"circonus_broker_id", "circonus_broker_id",
"circonus_broker_select_tag", "circonus_broker_select_tag",
"disable_hostname", "disable_hostname",
@@ -122,6 +122,8 @@ func TestLoadConfigFile_json(t *testing.T) {
CirconusCheckForceMetricActivation: "", CirconusCheckForceMetricActivation: "",
CirconusCheckInstanceID: "", CirconusCheckInstanceID: "",
CirconusCheckSearchTag: "", CirconusCheckSearchTag: "",
CirconusCheckDisplayName: "",
CirconusCheckTags: "",
CirconusBrokerID: "", CirconusBrokerID: "",
CirconusBrokerSelectTag: "", CirconusBrokerSelectTag: "",
}, },
@@ -191,6 +193,8 @@ func TestLoadConfigFile_json2(t *testing.T) {
CirconusCheckForceMetricActivation: "true", CirconusCheckForceMetricActivation: "true",
CirconusCheckInstanceID: "node1:vault", CirconusCheckInstanceID: "node1:vault",
CirconusCheckSearchTag: "service:vault", CirconusCheckSearchTag: "service:vault",
CirconusCheckDisplayName: "node1:vault",
CirconusCheckTags: "cat1:tag1,cat2:tag2",
CirconusBrokerID: "0", CirconusBrokerID: "0",
CirconusBrokerSelectTag: "dc:sfo", CirconusBrokerSelectTag: "dc:sfo",
}, },
@@ -3,18 +3,26 @@ package server
import ( import (
"io" "io"
"net" "net"
"strings"
"time" "time"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
) )
func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) { func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) {
bind_proto := "tcp"
addr, ok := config["address"] addr, ok := config["address"]
if !ok { if !ok {
addr = "127.0.0.1:8200" addr = "127.0.0.1:8200"
} }
ln, err := net.Listen("tcp", addr)
// If they've passed 0.0.0.0, we only want to bind on IPv4
// rather than golang's dual stack default
if strings.HasPrefix(addr, "0.0.0.0:") {
bind_proto = "tcp4"
}
ln, err := net.Listen(bind_proto, addr)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
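The hunk above special-cases `0.0.0.0` because Go's `net.Listen("tcp", ...)` opens a dual-stack (IPv4+IPv6) socket by default. A minimal sketch of that selection logic (the helper name `bindProto` is illustrative, not from the commit):

```go
package main

import (
	"fmt"
	"strings"
)

// bindProto mirrors the listener change: an address like "0.0.0.0:8200"
// passed to net.Listen("tcp", ...) would also accept IPv6 connections on
// golang's dual-stack default, so "tcp4" is selected to keep the bind
// IPv4-only, matching what the address visually promises.
func bindProto(addr string) string {
	if strings.HasPrefix(addr, "0.0.0.0:") {
		return "tcp4"
	}
	return "tcp"
}

func main() {
	fmt.Println(bindProto("0.0.0.0:8200"))   // tcp4
	fmt.Println(bindProto("127.0.0.1:8200")) // tcp
}
```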


@ -36,6 +36,8 @@
"circonus_check_force_metric_activation": "true", "circonus_check_force_metric_activation": "true",
"circonus_check_instance_id": "node1:vault", "circonus_check_instance_id": "node1:vault",
"circonus_check_search_tag": "service:vault", "circonus_check_search_tag": "service:vault",
"circonus_check_display_name": "node1:vault",
"circonus_check_tags": "cat1:tag1,cat2:tag2",
"circonus_broker_id": "0", "circonus_broker_id": "0",
"circonus_broker_select_tag": "dc:sfo" "circonus_broker_select_tag": "dc:sfo"
} }


@ -34,7 +34,6 @@ func (c *SSHCommand) Run(args []string) int {
var role, mountPoint, format, userKnownHostsFile, strictHostKeyChecking string var role, mountPoint, format, userKnownHostsFile, strictHostKeyChecking string
var noExec bool var noExec bool
var sshCmdArgs []string var sshCmdArgs []string
var sshDynamicKeyFileName string
flags := c.Meta.FlagSet("ssh", meta.FlagSetDefault) flags := c.Meta.FlagSet("ssh", meta.FlagSetDefault)
flags.StringVar(&strictHostKeyChecking, "strict-host-key-checking", "", "") flags.StringVar(&strictHostKeyChecking, "strict-host-key-checking", "", "")
flags.StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "") flags.StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "")
@ -76,7 +75,7 @@ func (c *SSHCommand) Run(args []string) int {
client, err := c.Client() client, err := c.Client()
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Error initializing client: %s", err)) c.Ui.Error(fmt.Sprintf("Error initializing client: %v", err))
return 1 return 1
} }
@ -92,7 +91,7 @@ func (c *SSHCommand) Run(args []string) int {
if len(input) == 1 { if len(input) == 1 {
u, err := user.Current() u, err := user.Current()
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Error fetching username: %s", err)) c.Ui.Error(fmt.Sprintf("Error fetching username: %v", err))
return 1 return 1
} }
username = u.Username username = u.Username
@ -101,7 +100,7 @@ func (c *SSHCommand) Run(args []string) int {
username = input[0] username = input[0]
ipAddr = input[1] ipAddr = input[1]
} else { } else {
c.Ui.Error(fmt.Sprintf("Invalid parameter: %s", args[0])) c.Ui.Error(fmt.Sprintf("Invalid parameter: %q", args[0]))
return 1 return 1
} }
@ -109,7 +108,7 @@ func (c *SSHCommand) Run(args []string) int {
// Vault only deals with IP addresses. // Vault only deals with IP addresses.
ip, err := net.ResolveIPAddr("ip", ipAddr) ip, err := net.ResolveIPAddr("ip", ipAddr)
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Error resolving IP Address: %s", err)) c.Ui.Error(fmt.Sprintf("Error resolving IP Address: %v", err))
return 1 return 1
} }
@ -120,14 +119,14 @@ func (c *SSHCommand) Run(args []string) int {
if role == "" { if role == "" {
role, err = c.defaultRole(mountPoint, ip.String()) role, err = c.defaultRole(mountPoint, ip.String())
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Error choosing role: %s", err)) c.Ui.Error(fmt.Sprintf("Error choosing role: %v", err))
return 1 return 1
} }
// Print the default role chosen so that user knows the role name // Print the default role chosen so that user knows the role name
// if something doesn't work. If the role chosen is not allowed to // if something doesn't work. If the role chosen is not allowed to
// be used by the user (ACL enforcement), then user should see an // be used by the user (ACL enforcement), then user should see an
// error message accordingly. // error message accordingly.
c.Ui.Output(fmt.Sprintf("Vault SSH: Role: %s", role)) c.Ui.Output(fmt.Sprintf("Vault SSH: Role: %q", role))
} }
data := map[string]interface{}{ data := map[string]interface{}{
@ -137,7 +136,7 @@ func (c *SSHCommand) Run(args []string) int {
keySecret, err := client.SSHWithMountPoint(mountPoint).Credential(role, data) keySecret, err := client.SSHWithMountPoint(mountPoint).Credential(role, data)
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Error getting key for SSH session:%s", err)) c.Ui.Error(fmt.Sprintf("Error getting key for SSH session: %v", err))
return 1 return 1
} }
@ -152,7 +151,7 @@ func (c *SSHCommand) Run(args []string) int {
} }
var resp SSHCredentialResp var resp SSHCredentialResp
if err := mapstructure.Decode(keySecret.Data, &resp); err != nil { if err := mapstructure.Decode(keySecret.Data, &resp); err != nil {
c.Ui.Error(fmt.Sprintf("Error parsing the credential response:%s", err)) c.Ui.Error(fmt.Sprintf("Error parsing the credential response: %v", err))
return 1 return 1
} }
@ -161,9 +160,21 @@ func (c *SSHCommand) Run(args []string) int {
c.Ui.Error(fmt.Sprintf("Invalid key")) c.Ui.Error(fmt.Sprintf("Invalid key"))
return 1 return 1
} }
sshDynamicKeyFileName = fmt.Sprintf("vault_ssh_%s_%s", username, ip.String())
err = ioutil.WriteFile(sshDynamicKeyFileName, []byte(resp.Key), 0600)
sshCmdArgs = append(sshCmdArgs, []string{"-i", sshDynamicKeyFileName}...)
sshDynamicKeyFile, err := ioutil.TempFile("", fmt.Sprintf("vault_ssh_%s_%s_", username, ip.String()))
if err != nil {
c.Ui.Error(fmt.Sprintf("Error creating temporary file: %v", err))
return 1
}
// Ensure that we delete the temporary file
defer os.Remove(sshDynamicKeyFile.Name())
if err = ioutil.WriteFile(sshDynamicKeyFile.Name(),
[]byte(resp.Key), 0600); err != nil {
c.Ui.Error(fmt.Sprintf("Error storing the dynamic key into the temporary file: %v", err))
return 1
}
sshCmdArgs = append(sshCmdArgs, []string{"-i", sshDynamicKeyFile.Name()}...)
} else if resp.KeyType == ssh.KeyTypeOTP { } else if resp.KeyType == ssh.KeyTypeOTP {
// Check if the application 'sshpass' is installed in the client machine. // Check if the application 'sshpass' is installed in the client machine.
@ -182,7 +193,7 @@ func (c *SSHCommand) Run(args []string) int {
sshCmd.Stdout = os.Stdout sshCmd.Stdout = os.Stdout
err = sshCmd.Run() err = sshCmd.Run()
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to establish SSH connection:%s", err)) c.Ui.Error(fmt.Sprintf("Failed to establish SSH connection: %q", err))
} }
return 0 return 0
} }
@ -204,15 +215,7 @@ func (c *SSHCommand) Run(args []string) int {
// to establish an independent session like this. // to establish an independent session like this.
err = sshCmd.Run() err = sshCmd.Run()
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Error while running ssh command:%s", err)) c.Ui.Error(fmt.Sprintf("Error while running ssh command: %q", err))
}
// Delete the temporary key file generated by the command.
if resp.KeyType == ssh.KeyTypeDynamic {
// Ignoring the error from the below call since it is not a security
// issue if the deletion of file is not successful. User is authorized
// to have this secret.
os.Remove(sshDynamicKeyFileName)
} }
// If the session established was longer than the lease expiry, the secret // If the session established was longer than the lease expiry, the secret
@ -222,7 +225,7 @@ func (c *SSHCommand) Run(args []string) int {
// is run, a fresh credential is generated anyways. // is run, a fresh credential is generated anyways.
err = client.Sys().Revoke(keySecret.LeaseID) err = client.Sys().Revoke(keySecret.LeaseID)
if err != nil { if err != nil {
c.Ui.Error(fmt.Sprintf("Error revoking the key: %s", err)) c.Ui.Error(fmt.Sprintf("Error revoking the key: %q", err))
} }
return 0 return 0
@ -241,15 +244,15 @@ func (c *SSHCommand) defaultRole(mountPoint, ip string) (string, error) {
} }
secret, err := client.Logical().Write(mountPoint+"/lookup", data) secret, err := client.Logical().Write(mountPoint+"/lookup", data)
if err != nil { if err != nil {
return "", fmt.Errorf("Error finding roles for IP %s: %s", ip, err) return "", fmt.Errorf("Error finding roles for IP %q: %q", ip, err)
} }
if secret == nil { if secret == nil {
return "", fmt.Errorf("Error finding roles for IP %s: %s", ip, err) return "", fmt.Errorf("Error finding roles for IP %q: %q", ip, err)
} }
if secret.Data["roles"] == nil { if secret.Data["roles"] == nil {
return "", fmt.Errorf("No matching roles found for IP %s", ip) return "", fmt.Errorf("No matching roles found for IP %q", ip)
} }
if len(secret.Data["roles"].([]interface{})) == 1 { if len(secret.Data["roles"].([]interface{})) == 1 {
@ -260,7 +263,7 @@ func (c *SSHCommand) defaultRole(mountPoint, ip string) (string, error) {
roleNames += item.(string) + ", " roleNames += item.(string) + ", "
} }
roleNames = strings.TrimRight(roleNames, ", ") roleNames = strings.TrimRight(roleNames, ", ")
return "", fmt.Errorf("Roles:[%s]"+` return "", fmt.Errorf("Roles:%q. "+`
Multiple roles are registered for this IP. Multiple roles are registered for this IP.
Select a role using '-role' option. Select a role using '-role' option.
Note that all roles may not be permitted, based on ACLs.`, roleNames) Note that all roles may not be permitted, based on ACLs.`, roleNames)


@ -49,6 +49,13 @@ func (m *Request) String() string { return proto.CompactTextString(m)
func (*Request) ProtoMessage() {} func (*Request) ProtoMessage() {}
func (*Request) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } func (*Request) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Request) GetMethod() string {
if m != nil {
return m.Method
}
return ""
}
func (m *Request) GetUrl() *URL { func (m *Request) GetUrl() *URL {
if m != nil { if m != nil {
return m.Url return m.Url
@ -63,6 +70,34 @@ func (m *Request) GetHeaderEntries() map[string]*HeaderEntry {
return nil return nil
} }
func (m *Request) GetBody() []byte {
if m != nil {
return m.Body
}
return nil
}
func (m *Request) GetHost() string {
if m != nil {
return m.Host
}
return ""
}
func (m *Request) GetRemoteAddr() string {
if m != nil {
return m.RemoteAddr
}
return ""
}
func (m *Request) GetPeerCertificates() [][]byte {
if m != nil {
return m.PeerCertificates
}
return nil
}
type URL struct { type URL struct {
Scheme string `protobuf:"bytes,1,opt,name=scheme" json:"scheme,omitempty"` Scheme string `protobuf:"bytes,1,opt,name=scheme" json:"scheme,omitempty"`
Opaque string `protobuf:"bytes,2,opt,name=opaque" json:"opaque,omitempty"` Opaque string `protobuf:"bytes,2,opt,name=opaque" json:"opaque,omitempty"`
@ -83,6 +118,55 @@ func (m *URL) String() string { return proto.CompactTextString(m) }
func (*URL) ProtoMessage() {} func (*URL) ProtoMessage() {}
func (*URL) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } func (*URL) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *URL) GetScheme() string {
if m != nil {
return m.Scheme
}
return ""
}
func (m *URL) GetOpaque() string {
if m != nil {
return m.Opaque
}
return ""
}
func (m *URL) GetHost() string {
if m != nil {
return m.Host
}
return ""
}
func (m *URL) GetPath() string {
if m != nil {
return m.Path
}
return ""
}
func (m *URL) GetRawPath() string {
if m != nil {
return m.RawPath
}
return ""
}
func (m *URL) GetRawQuery() string {
if m != nil {
return m.RawQuery
}
return ""
}
func (m *URL) GetFragment() string {
if m != nil {
return m.Fragment
}
return ""
}
type HeaderEntry struct { type HeaderEntry struct {
Values []string `protobuf:"bytes,1,rep,name=values" json:"values,omitempty"` Values []string `protobuf:"bytes,1,rep,name=values" json:"values,omitempty"`
} }
@ -92,6 +176,13 @@ func (m *HeaderEntry) String() string { return proto.CompactTextStrin
func (*HeaderEntry) ProtoMessage() {} func (*HeaderEntry) ProtoMessage() {}
func (*HeaderEntry) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } func (*HeaderEntry) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *HeaderEntry) GetValues() []string {
if m != nil {
return m.Values
}
return nil
}
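The getters added in this generated protobuf file all follow the same nil-receiver pattern, which is what makes chained access like `req.GetUrl().GetHost()` safe on a partially populated or nil message. A minimal sketch (the struct shapes are simplified stand-ins for the generated types):

```go
package main

import "fmt"

// URL and Request mimic the generated message shapes; the point is the
// nil check inside each getter, which protoc-gen-go emits so that calling
// a getter on a nil message yields a zero value instead of panicking.
type URL struct{ Host string }

func (u *URL) GetHost() string {
	if u != nil {
		return u.Host
	}
	return ""
}

type Request struct{ Url *URL }

func (r *Request) GetUrl() *URL {
	if r != nil {
		return r.Url
	}
	return nil
}

func main() {
	var r *Request // nil message
	// Chained getters never panic; they just produce zero values.
	fmt.Printf("%q\n", r.GetUrl().GetHost())
}
```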
type Response struct { type Response struct {
// Not used right now but reserving in case it turns out that streaming // Not used right now but reserving in case it turns out that streaming
// makes things more economical on the gRPC side // makes things more economical on the gRPC side
@ -108,6 +199,20 @@ func (m *Response) String() string { return proto.CompactTextString(m
func (*Response) ProtoMessage() {} func (*Response) ProtoMessage() {}
func (*Response) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } func (*Response) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *Response) GetStatusCode() uint32 {
if m != nil {
return m.StatusCode
}
return 0
}
func (m *Response) GetBody() []byte {
if m != nil {
return m.Body
}
return nil
}
func (m *Response) GetHeaderEntries() map[string]*HeaderEntry { func (m *Response) GetHeaderEntries() map[string]*HeaderEntry {
if m != nil { if m != nil {
return m.HeaderEntries return m.HeaderEntries


@ -1,4 +1,4 @@
package transit package keysutil
import ( import (
"errors" "errors"
@ -18,29 +18,29 @@ var (
errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation") errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
) )
// policyRequest holds values used when requesting a policy. Most values are // PolicyRequest holds values used when requesting a policy. Most values are
// only used during an upsert. // only used during an upsert.
type policyRequest struct { type PolicyRequest struct {
// The storage to use // The storage to use
storage logical.Storage Storage logical.Storage
// The name of the policy // The name of the policy
name string Name string
// The key type // The key type
keyType KeyType KeyType KeyType
// Whether it should be derived // Whether it should be derived
derived bool Derived bool
// Whether to enable convergent encryption // Whether to enable convergent encryption
convergent bool Convergent bool
// Whether to upsert // Whether to upsert
upsert bool Upsert bool
} }
type lockManager struct { type LockManager struct {
// A lock for each named key // A lock for each named key
locks map[string]*sync.RWMutex locks map[string]*sync.RWMutex
@ -48,27 +48,27 @@ type lockManager struct {
locksMutex sync.RWMutex locksMutex sync.RWMutex
// If caching is enabled, the map of name to in-memory policy cache // If caching is enabled, the map of name to in-memory policy cache
cache map[string]*policy cache map[string]*Policy
// Used for global locking, and as the cache map mutex // Used for global locking, and as the cache map mutex
cacheMutex sync.RWMutex cacheMutex sync.RWMutex
} }
func newLockManager(cacheDisabled bool) *lockManager { func NewLockManager(cacheDisabled bool) *LockManager {
lm := &lockManager{ lm := &LockManager{
locks: map[string]*sync.RWMutex{}, locks: map[string]*sync.RWMutex{},
} }
if !cacheDisabled { if !cacheDisabled {
lm.cache = map[string]*policy{} lm.cache = map[string]*Policy{}
} }
return lm return lm
} }
func (lm *lockManager) CacheActive() bool { func (lm *LockManager) CacheActive() bool {
return lm.cache != nil return lm.cache != nil
} }
func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex { func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
lm.locksMutex.RLock() lm.locksMutex.RLock()
lock := lm.locks[name] lock := lm.locks[name]
if lock != nil { if lock != nil {
@ -115,7 +115,7 @@ func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
return lock return lock
} }
func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) { func (lm *LockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
if lockType == exclusive { if lockType == exclusive {
lock.Unlock() lock.Unlock()
} else { } else {
@ -126,10 +126,10 @@ func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
// Get the policy with a read lock. If we get an error saying an exclusive lock // Get the policy with a read lock. If we get an error saying an exclusive lock
// is needed (for instance, for an upgrade/migration), give up the read lock, // is needed (for instance, for an upgrade/migration), give up the read lock,
// call again with an exclusive lock, then swap back out for a read lock. // call again with an exclusive lock, then swap back out for a read lock.
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*policy, *sync.RWMutex, error) { func (lm *LockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(policyRequest{ p, lock, _, err := lm.getPolicyCommon(PolicyRequest{
storage: storage, Storage: storage,
name: name, Name: name,
}, shared) }, shared)
if err == nil || if err == nil ||
(err != nil && err != errNeedExclusiveLock) { (err != nil && err != errNeedExclusiveLock) {
@ -137,9 +137,9 @@ func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*p
} }
// Try again while asking for an exclusive lock // Try again while asking for an exclusive lock
p, lock, _, err = lm.getPolicyCommon(policyRequest{ p, lock, _, err = lm.getPolicyCommon(PolicyRequest{
storage: storage, Storage: storage,
name: name, Name: name,
}, exclusive) }, exclusive)
if err != nil || p == nil || lock == nil { if err != nil || p == nil || lock == nil {
return p, lock, err return p, lock, err
@ -147,18 +147,18 @@ func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*p
lock.Unlock() lock.Unlock()
p, lock, _, err = lm.getPolicyCommon(policyRequest{ p, lock, _, err = lm.getPolicyCommon(PolicyRequest{
storage: storage, Storage: storage,
name: name, Name: name,
}, shared) }, shared)
return p, lock, err return p, lock, err
} }
// Get the policy with an exclusive lock // Get the policy with an exclusive lock
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*policy, *sync.RWMutex, error) { func (lm *LockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(policyRequest{ p, lock, _, err := lm.getPolicyCommon(PolicyRequest{
storage: storage, Storage: storage,
name: name, Name: name,
}, exclusive) }, exclusive)
return p, lock, err return p, lock, err
} }
@ -166,8 +166,8 @@ func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string)
// Get the policy with a read lock; if it returns that an exclusive lock is // Get the policy with a read lock; if it returns that an exclusive lock is
// needed, retry. If successful, call one more time to get a read lock and // needed, retry. If successful, call one more time to get a read lock and
// return the value. // return the value.
func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMutex, bool, error) { func (lm *LockManager) GetPolicyUpsert(req PolicyRequest) (*Policy, *sync.RWMutex, bool, error) {
req.upsert = true req.Upsert = true
p, lock, _, err := lm.getPolicyCommon(req, shared) p, lock, _, err := lm.getPolicyCommon(req, shared)
if err == nil || if err == nil ||
@ -182,7 +182,7 @@ func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMute
} }
lock.Unlock() lock.Unlock()
req.upsert = false req.Upsert = false
// Now get a shared lock for the return, but preserve the value of upserted // Now get a shared lock for the return, but preserve the value of upserted
p, lock, _, err = lm.getPolicyCommon(req, shared) p, lock, _, err = lm.getPolicyCommon(req, shared)
@ -191,16 +191,16 @@ func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMute
// When the function returns, a lock will be held on the policy if err == nil. // When the function returns, a lock will be held on the policy if err == nil.
// It is the caller's responsibility to unlock. // It is the caller's responsibility to unlock.
func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*policy, *sync.RWMutex, bool, error) { func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
lock := lm.policyLock(req.name, lockType) lock := lm.policyLock(req.Name, lockType)
var p *policy var p *Policy
var err error var err error
// Check if it's in our cache. If so, return right away. // Check if it's in our cache. If so, return right away.
if lm.CacheActive() { if lm.CacheActive() {
lm.cacheMutex.RLock() lm.cacheMutex.RLock()
p = lm.cache[req.name] p = lm.cache[req.Name]
if p != nil { if p != nil {
lm.cacheMutex.RUnlock() lm.cacheMutex.RUnlock()
return p, lock, false, nil return p, lock, false, nil
@ -209,7 +209,7 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
} }
// Load it from storage // Load it from storage
p, err = lm.getStoredPolicy(req.storage, req.name) p, err = lm.getStoredPolicy(req.Storage, req.Name)
if err != nil { if err != nil {
lm.UnlockPolicy(lock, lockType) lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err return nil, nil, false, err
@ -218,7 +218,7 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
if p == nil { if p == nil {
// This is the only place we upsert a new policy, so if upsert is not // This is the only place we upsert a new policy, so if upsert is not
// specified, or the lock type is wrong, unlock before returning // specified, or the lock type is wrong, unlock before returning
if !req.upsert { if !req.Upsert {
lm.UnlockPolicy(lock, lockType) lm.UnlockPolicy(lock, lockType)
return nil, nil, false, nil return nil, nil, false, nil
} }
@ -228,33 +228,33 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
return nil, nil, false, errNeedExclusiveLock return nil, nil, false, errNeedExclusiveLock
} }
switch req.keyType { switch req.KeyType {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
if req.convergent && !req.derived { if req.Convergent && !req.Derived {
return nil, nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled") return nil, nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
} }
case keyType_ECDSA_P256: case KeyType_ECDSA_P256:
if req.derived || req.convergent { if req.Derived || req.Convergent {
return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %s", keyType_ECDSA_P256) return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %s", KeyType_ECDSA_P256)
} }
default: default:
return nil, nil, false, fmt.Errorf("unsupported key type %v", req.keyType) return nil, nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
} }
p = &policy{ p = &Policy{
Name: req.name, Name: req.Name,
Type: req.keyType, Type: req.KeyType,
Derived: req.derived, Derived: req.Derived,
} }
if req.derived { if req.Derived {
p.KDF = kdf_hkdf_sha256 p.KDF = Kdf_hkdf_sha256
p.ConvergentEncryption = req.convergent p.ConvergentEncryption = req.Convergent
p.ConvergentVersion = 2 p.ConvergentVersion = 2
} }
err = p.rotate(req.storage) err = p.Rotate(req.Storage)
if err != nil { if err != nil {
lm.UnlockPolicy(lock, lockType) lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err return nil, nil, false, err
@ -267,12 +267,12 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
defer lm.cacheMutex.Unlock() defer lm.cacheMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if // Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so assume it's good and return that // there was no error, so assume it's good and return that
exp := lm.cache[req.name] exp := lm.cache[req.Name]
if exp != nil { if exp != nil {
return exp, lock, false, nil return exp, lock, false, nil
} }
if err == nil { if err == nil {
lm.cache[req.name] = p lm.cache[req.Name] = p
} }
} }
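The cache writes in getPolicyCommon re-check the map under the mutex and prefer an entry that appeared concurrently ("Make sure a policy didn't appear") over the one just built. A minimal sketch of that double-checked insert, with a simplified string-valued cache standing in for the policy map:

```go
package main

import (
	"fmt"
	"sync"
)

type cache struct {
	mu sync.Mutex
	m  map[string]string
}

// set stores built under name unless another goroutine already populated
// the entry while we were building ours; in that case the existing entry
// wins and is returned, so all callers converge on one value.
func (c *cache) set(name, built string) string {
	c.mu.Lock()
	defer c.mu.Unlock()
	if exp, ok := c.m[name]; ok {
		return exp // someone else won the race
	}
	c.m[name] = built
	return built
}

func main() {
	c := &cache{m: map[string]string{"k": "existing"}}
	fmt.Println(c.set("k", "mine"))   // existing
	fmt.Println(c.set("k2", "mine2")) // mine2
}
```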
@ -280,13 +280,13 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
return p, lock, true, nil return p, lock, true, nil
} }
if p.needsUpgrade() { if p.NeedsUpgrade() {
if lockType == shared { if lockType == shared {
lm.UnlockPolicy(lock, lockType) lm.UnlockPolicy(lock, lockType)
return nil, nil, false, errNeedExclusiveLock return nil, nil, false, errNeedExclusiveLock
} }
err = p.upgrade(req.storage) err = p.Upgrade(req.Storage)
if err != nil { if err != nil {
lm.UnlockPolicy(lock, lockType) lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err return nil, nil, false, err
@ -300,25 +300,25 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
defer lm.cacheMutex.Unlock() defer lm.cacheMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if // Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so assume it's good and return that // there was no error, so assume it's good and return that
exp := lm.cache[req.name] exp := lm.cache[req.Name]
if exp != nil { if exp != nil {
return exp, lock, false, nil return exp, lock, false, nil
} }
if err == nil { if err == nil {
lm.cache[req.name] = p lm.cache[req.Name] = p
} }
} }
return p, lock, false, nil return p, lock, false, nil
} }
func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error { func (lm *LockManager) DeletePolicy(storage logical.Storage, name string) error {
lm.cacheMutex.Lock() lm.cacheMutex.Lock()
lock := lm.policyLock(name, exclusive) lock := lm.policyLock(name, exclusive)
defer lock.Unlock() defer lock.Unlock()
defer lm.cacheMutex.Unlock() defer lm.cacheMutex.Unlock()
var p *policy var p *Policy
var err error var err error
if lm.CacheActive() { if lm.CacheActive() {
@ -355,7 +355,7 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error
return nil return nil
} }
func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*policy, error) { func (lm *LockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) {
// Check if the policy already exists // Check if the policy already exists
raw, err := storage.Get("policy/" + name) raw, err := storage.Get("policy/" + name)
if err != nil { if err != nil {
@ -366,7 +366,7 @@ func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*p
} }
// Decode the policy // Decode the policy
policy := &policy{ policy := &Policy{
Keys: keyEntryMap{}, Keys: keyEntryMap{},
} }
err = jsonutil.DecodeJSON(raw.Value, policy) err = jsonutil.DecodeJSON(raw.Value, policy)


@ -1,4 +1,4 @@
package transit package keysutil
import ( import (
"bytes" "bytes"
@ -33,14 +33,14 @@ import (
// Careful with iota; don't put anything before it in this const block because // Careful with iota; don't put anything before it in this const block because
// we need the default of zero to be the old-style KDF // we need the default of zero to be the old-style KDF
const ( const (
kdf_hmac_sha256_counter = iota // built-in helper Kdf_hmac_sha256_counter = iota // built-in helper
kdf_hkdf_sha256 // golang.org/x/crypto/hkdf Kdf_hkdf_sha256 // golang.org/x/crypto/hkdf
) )
// Or this one...we need the default of zero to be the original AES256-GCM96 // Or this one...we need the default of zero to be the original AES256-GCM96
const ( const (
keyType_AES256_GCM96 = iota KeyType_AES256_GCM96 = iota
keyType_ECDSA_P256 KeyType_ECDSA_P256
) )
const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)" const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"
@ -53,7 +53,7 @@ type KeyType int
func (kt KeyType) EncryptionSupported() bool { func (kt KeyType) EncryptionSupported() bool {
switch kt { switch kt {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
return true return true
} }
return false return false
@ -61,7 +61,7 @@ func (kt KeyType) EncryptionSupported() bool {
func (kt KeyType) DecryptionSupported() bool { func (kt KeyType) DecryptionSupported() bool {
switch kt { switch kt {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
return true return true
} }
return false return false
@ -69,7 +69,7 @@ func (kt KeyType) DecryptionSupported() bool {
func (kt KeyType) SigningSupported() bool { func (kt KeyType) SigningSupported() bool {
switch kt { switch kt {
case keyType_ECDSA_P256: case KeyType_ECDSA_P256:
return true return true
} }
return false return false
@ -77,7 +77,7 @@ func (kt KeyType) SigningSupported() bool {
func (kt KeyType) DerivationSupported() bool { func (kt KeyType) DerivationSupported() bool {
switch kt { switch kt {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
return true return true
} }
return false return false
@ -85,17 +85,17 @@ func (kt KeyType) DerivationSupported() bool {
func (kt KeyType) String() string { func (kt KeyType) String() string {
switch kt { switch kt {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
return "aes256-gcm96" return "aes256-gcm96"
case keyType_ECDSA_P256: case KeyType_ECDSA_P256:
return "ecdsa-p256" return "ecdsa-p256"
} }
return "[unknown]" return "[unknown]"
} }
// keyEntry stores the key and metadata // KeyEntry stores the key and metadata
type keyEntry struct { type KeyEntry struct {
AESKey []byte `json:"key"` AESKey []byte `json:"key"`
HMACKey []byte `json:"hmac_key"` HMACKey []byte `json:"hmac_key"`
CreationTime int64 `json:"creation_time"` CreationTime int64 `json:"creation_time"`
@ -106,11 +106,11 @@ type keyEntry struct {
} }
// keyEntryMap is used to allow JSON marshal/unmarshal // keyEntryMap is used to allow JSON marshal/unmarshal
type keyEntryMap map[int]keyEntry type keyEntryMap map[int]KeyEntry
// MarshalJSON implements JSON marshaling // MarshalJSON implements JSON marshaling
func (kem keyEntryMap) MarshalJSON() ([]byte, error) { func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
intermediate := map[string]keyEntry{} intermediate := map[string]KeyEntry{}
for k, v := range kem { for k, v := range kem {
intermediate[strconv.Itoa(k)] = v intermediate[strconv.Itoa(k)] = v
} }
@ -119,7 +119,7 @@ func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
// MarshalJSON implements JSON unmarshaling // MarshalJSON implements JSON unmarshaling
func (kem keyEntryMap) UnmarshalJSON(data []byte) error { func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
intermediate := map[string]keyEntry{} intermediate := map[string]KeyEntry{}
if err := jsonutil.DecodeJSON(data, &intermediate); err != nil { if err := jsonutil.DecodeJSON(data, &intermediate); err != nil {
return err return err
} }
@ -135,7 +135,7 @@ func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
} }
// Policy is the struct used to store metadata // Policy is the struct used to store metadata
type policy struct { type Policy struct {
Name string `json:"name"` Name string `json:"name"`
Key []byte `json:"key,omitempty"` //DEPRECATED Key []byte `json:"key,omitempty"` //DEPRECATED
Keys keyEntryMap `json:"keys"` Keys keyEntryMap `json:"keys"`
@ -171,10 +171,10 @@ type policy struct {
// ArchivedKeys stores old keys. This is used to keep the key loading time sane // ArchivedKeys stores old keys. This is used to keep the key loading time sane
// when there are huge numbers of rotations. // when there are huge numbers of rotations.
type archivedKeys struct { type archivedKeys struct {
Keys []keyEntry `json:"keys"` Keys []KeyEntry `json:"keys"`
} }
func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) { func (p *Policy) LoadArchive(storage logical.Storage) (*archivedKeys, error) {
archive := &archivedKeys{} archive := &archivedKeys{}
raw, err := storage.Get("archive/" + p.Name) raw, err := storage.Get("archive/" + p.Name)
@ -182,7 +182,7 @@ func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
return nil, err return nil, err
} }
if raw == nil { if raw == nil {
archive.Keys = make([]keyEntry, 0) archive.Keys = make([]KeyEntry, 0)
return archive, nil return archive, nil
} }
@ -193,7 +193,7 @@ func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
return archive, nil return archive, nil
} }
func (p *policy) storeArchive(archive *archivedKeys, storage logical.Storage) error { func (p *Policy) storeArchive(archive *archivedKeys, storage logical.Storage) error {
// Encode the policy // Encode the policy
buf, err := json.Marshal(archive) buf, err := json.Marshal(archive)
if err != nil { if err != nil {
@ -215,7 +215,7 @@ func (p *policy) storeArchive(archive *archivedKeys, storage logical.Storage) er
// handleArchiving manages the movement of keys to and from the policy archive. // handleArchiving manages the movement of keys to and from the policy archive.
// This should *ONLY* be called from Persist() since it assumes that the policy // This should *ONLY* be called from Persist() since it assumes that the policy
// will be persisted afterwards. // will be persisted afterwards.
func (p *policy) handleArchiving(storage logical.Storage) error { func (p *Policy) handleArchiving(storage logical.Storage) error {
// We need to move keys that are no longer accessible to archivedKeys, and keys // We need to move keys that are no longer accessible to archivedKeys, and keys
// that now need to be accessible back here. // that now need to be accessible back here.
// //
@ -241,7 +241,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
p.MinDecryptionVersion, p.LatestVersion) p.MinDecryptionVersion, p.LatestVersion)
} }
archive, err := p.loadArchive(storage) archive, err := p.LoadArchive(storage)
if err != nil { if err != nil {
return err return err
} }
@ -263,7 +263,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
// key version // key version
if len(archive.Keys) < p.LatestVersion+1 { if len(archive.Keys) < p.LatestVersion+1 {
// Increase the size of the archive slice // Increase the size of the archive slice
newKeys := make([]keyEntry, p.LatestVersion+1) newKeys := make([]KeyEntry, p.LatestVersion+1)
copy(newKeys, archive.Keys) copy(newKeys, archive.Keys)
archive.Keys = newKeys archive.Keys = newKeys
} }
@ -289,7 +289,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
return nil return nil
} }
func (p *policy) Persist(storage logical.Storage) error { func (p *Policy) Persist(storage logical.Storage) error {
err := p.handleArchiving(storage) err := p.handleArchiving(storage)
if err != nil { if err != nil {
return err return err
@ -313,11 +313,11 @@ func (p *policy) Persist(storage logical.Storage) error {
return nil return nil
} }
func (p *policy) Serialize() ([]byte, error) { func (p *Policy) Serialize() ([]byte, error) {
return json.Marshal(p) return json.Marshal(p)
} }
func (p *policy) needsUpgrade() bool { func (p *Policy) NeedsUpgrade() bool {
// Ensure we've moved from Key -> Keys // Ensure we've moved from Key -> Keys
if p.Key != nil && len(p.Key) > 0 { if p.Key != nil && len(p.Key) > 0 {
return true return true
@ -352,11 +352,11 @@ func (p *policy) needsUpgrade() bool {
return false return false
} }
func (p *policy) upgrade(storage logical.Storage) error { func (p *Policy) Upgrade(storage logical.Storage) error {
persistNeeded := false persistNeeded := false
// Ensure we've moved from Key -> Keys // Ensure we've moved from Key -> Keys
if p.Key != nil && len(p.Key) > 0 { if p.Key != nil && len(p.Key) > 0 {
p.migrateKeyToKeysMap() p.MigrateKeyToKeysMap()
persistNeeded = true persistNeeded = true
} }
@ -409,7 +409,7 @@ func (p *policy) upgrade(storage logical.Storage) error {
// on the policy. If derivation is disabled the raw key is used and no context // on the policy. If derivation is disabled the raw key is used and no context
// is required, otherwise the KDF mode is used with the context to derive the // is required, otherwise the KDF mode is used with the context to derive the
// proper key. // proper key.
func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) { func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
if !p.Type.DerivationSupported() { if !p.Type.DerivationSupported() {
return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)} return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
} }
@ -433,11 +433,11 @@ func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
} }
switch p.KDF { switch p.KDF {
case kdf_hmac_sha256_counter: case Kdf_hmac_sha256_counter:
prf := kdf.HMACSHA256PRF prf := kdf.HMACSHA256PRF
prfLen := kdf.HMACSHA256PRFLen prfLen := kdf.HMACSHA256PRFLen
return kdf.CounterMode(prf, prfLen, p.Keys[ver].AESKey, context, 256) return kdf.CounterMode(prf, prfLen, p.Keys[ver].AESKey, context, 256)
case kdf_hkdf_sha256: case Kdf_hkdf_sha256:
reader := hkdf.New(sha256.New, p.Keys[ver].AESKey, nil, context) reader := hkdf.New(sha256.New, p.Keys[ver].AESKey, nil, context)
derBytes := bytes.NewBuffer(nil) derBytes := bytes.NewBuffer(nil)
derBytes.Grow(32) derBytes.Grow(32)
@ -458,14 +458,14 @@ func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
} }
} }
func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) { func (p *Policy) Encrypt(context, nonce []byte, value string) (string, error) {
if !p.Type.EncryptionSupported() { if !p.Type.EncryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)} return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
} }
// Guard against a potentially invalid key type // Guard against a potentially invalid key type
switch p.Type { switch p.Type {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
default: default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
} }
@ -484,7 +484,7 @@ func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
// Guard against a potentially invalid key type // Guard against a potentially invalid key type
switch p.Type { switch p.Type {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
default: default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
} }
@ -539,7 +539,7 @@ func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
return encoded, nil return encoded, nil
} }
func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) { func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
if !p.Type.DecryptionSupported() { if !p.Type.DecryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)} return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
} }
@ -585,7 +585,7 @@ func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
// Guard against a potentially invalid key type // Guard against a potentially invalid key type
switch p.Type { switch p.Type {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
default: default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)} return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
} }
@ -626,7 +626,7 @@ func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
return base64.StdEncoding.EncodeToString(plain), nil return base64.StdEncoding.EncodeToString(plain), nil
} }
func (p *policy) HMACKey(version int) ([]byte, error) { func (p *Policy) HMACKey(version int) ([]byte, error) {
if version < p.MinDecryptionVersion { if version < p.MinDecryptionVersion {
return nil, fmt.Errorf("key version disallowed by policy (minimum is %d)", p.MinDecryptionVersion) return nil, fmt.Errorf("key version disallowed by policy (minimum is %d)", p.MinDecryptionVersion)
} }
@ -642,14 +642,14 @@ func (p *policy) HMACKey(version int) ([]byte, error) {
return p.Keys[version].HMACKey, nil return p.Keys[version].HMACKey, nil
} }
func (p *policy) Sign(hashedInput []byte) (string, error) { func (p *Policy) Sign(hashedInput []byte) (string, error) {
if !p.Type.SigningSupported() { if !p.Type.SigningSupported() {
return "", fmt.Errorf("message signing not supported for key type %v", p.Type) return "", fmt.Errorf("message signing not supported for key type %v", p.Type)
} }
var sig []byte var sig []byte
switch p.Type { switch p.Type {
case keyType_ECDSA_P256: case KeyType_ECDSA_P256:
keyParams := p.Keys[p.LatestVersion] keyParams := p.Keys[p.LatestVersion]
key := &ecdsa.PrivateKey{ key := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{ PublicKey: ecdsa.PublicKey{
@ -685,7 +685,7 @@ func (p *policy) Sign(hashedInput []byte) (string, error) {
return encoded, nil return encoded, nil
} }
func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) { func (p *Policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
if !p.Type.SigningSupported() { if !p.Type.SigningSupported() {
return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)} return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
} }
@ -714,7 +714,7 @@ func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
} }
switch p.Type { switch p.Type {
case keyType_ECDSA_P256: case KeyType_ECDSA_P256:
asn1Sig, err := base64.StdEncoding.DecodeString(splitVerSig[1]) asn1Sig, err := base64.StdEncoding.DecodeString(splitVerSig[1])
if err != nil { if err != nil {
return false, errutil.UserError{Err: "invalid base64 signature value"} return false, errutil.UserError{Err: "invalid base64 signature value"}
@ -744,7 +744,7 @@ func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
return false, errutil.InternalError{Err: "no valid key type found"} return false, errutil.InternalError{Err: "no valid key type found"}
} }
func (p *policy) rotate(storage logical.Storage) error { func (p *Policy) Rotate(storage logical.Storage) error {
if p.Keys == nil { if p.Keys == nil {
// This is an initial key rotation when generating a new policy. We // This is an initial key rotation when generating a new policy. We
// don't need to call migrate here because if we've called getPolicy to // don't need to call migrate here because if we've called getPolicy to
@ -753,7 +753,7 @@ func (p *policy) rotate(storage logical.Storage) error {
} }
p.LatestVersion += 1 p.LatestVersion += 1
entry := keyEntry{ entry := KeyEntry{
CreationTime: time.Now().Unix(), CreationTime: time.Now().Unix(),
} }
@ -764,7 +764,7 @@ func (p *policy) rotate(storage logical.Storage) error {
entry.HMACKey = hmacKey entry.HMACKey = hmacKey
switch p.Type { switch p.Type {
case keyType_AES256_GCM96: case KeyType_AES256_GCM96:
// Generate a 256bit key // Generate a 256bit key
newKey, err := uuid.GenerateRandomBytes(32) newKey, err := uuid.GenerateRandomBytes(32)
if err != nil { if err != nil {
@ -772,7 +772,7 @@ func (p *policy) rotate(storage logical.Storage) error {
} }
entry.AESKey = newKey entry.AESKey = newKey
case keyType_ECDSA_P256: case KeyType_ECDSA_P256:
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return err return err
@ -807,9 +807,9 @@ func (p *policy) rotate(storage logical.Storage) error {
return p.Persist(storage) return p.Persist(storage)
} }
func (p *policy) migrateKeyToKeysMap() { func (p *Policy) MigrateKeyToKeysMap() {
p.Keys = keyEntryMap{ p.Keys = keyEntryMap{
1: keyEntry{ 1: KeyEntry{
AESKey: p.Key, AESKey: p.Key,
CreationTime: time.Now().Unix(), CreationTime: time.Now().Unix(),
}, },


@ -1,4 +1,4 @@
package transit package keysutil
import ( import (
"reflect" "reflect"
@ -8,24 +8,24 @@ import (
) )
var ( var (
keysArchive []keyEntry keysArchive []KeyEntry
) )
func resetKeysArchive() { func resetKeysArchive() {
keysArchive = []keyEntry{keyEntry{}} keysArchive = []KeyEntry{KeyEntry{}}
} }
func Test_KeyUpgrade(t *testing.T) { func Test_KeyUpgrade(t *testing.T) {
testKeyUpgradeCommon(t, newLockManager(false)) testKeyUpgradeCommon(t, NewLockManager(false))
testKeyUpgradeCommon(t, newLockManager(true)) testKeyUpgradeCommon(t, NewLockManager(true))
} }
func testKeyUpgradeCommon(t *testing.T, lm *lockManager) { func testKeyUpgradeCommon(t *testing.T, lm *LockManager) {
storage := &logical.InmemStorage{} storage := &logical.InmemStorage{}
p, lock, upserted, err := lm.GetPolicyUpsert(policyRequest{ p, lock, upserted, err := lm.GetPolicyUpsert(PolicyRequest{
storage: storage, Storage: storage,
keyType: keyType_AES256_GCM96, KeyType: KeyType_AES256_GCM96,
name: "test", Name: "test",
}) })
if lock != nil { if lock != nil {
defer lock.RUnlock() defer lock.RUnlock()
@ -45,7 +45,7 @@ func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
p.Key = p.Keys[1].AESKey p.Key = p.Keys[1].AESKey
p.Keys = nil p.Keys = nil
p.migrateKeyToKeysMap() p.MigrateKeyToKeysMap()
if p.Key != nil { if p.Key != nil {
t.Fatal("policy.Key is not nil") t.Fatal("policy.Key is not nil")
} }
@ -58,11 +58,11 @@ func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
} }
func Test_ArchivingUpgrade(t *testing.T) { func Test_ArchivingUpgrade(t *testing.T) {
testArchivingUpgradeCommon(t, newLockManager(false)) testArchivingUpgradeCommon(t, NewLockManager(false))
testArchivingUpgradeCommon(t, newLockManager(true)) testArchivingUpgradeCommon(t, NewLockManager(true))
} }
func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) { func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) {
resetKeysArchive() resetKeysArchive()
// First, we generate a policy and rotate it a number of times. Each time // First, we generate a policy and rotate it a number of times. Each time
@ -71,10 +71,10 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
// zero and latest, respectively // zero and latest, respectively
storage := &logical.InmemStorage{} storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(policyRequest{ p, lock, _, err := lm.GetPolicyUpsert(PolicyRequest{
storage: storage, Storage: storage,
keyType: keyType_AES256_GCM96, KeyType: KeyType_AES256_GCM96,
name: "test", Name: "test",
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -89,7 +89,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
checkKeys(t, p, storage, "initial", 1, 1, 1) checkKeys(t, p, storage, "initial", 1, 1, 1)
for i := 2; i <= 10; i++ { for i := 2; i <= 10; i++ {
err = p.rotate(storage) err = p.Rotate(storage)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -191,11 +191,11 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
} }
func Test_Archiving(t *testing.T) { func Test_Archiving(t *testing.T) {
testArchivingCommon(t, newLockManager(false)) testArchivingCommon(t, NewLockManager(false))
testArchivingCommon(t, newLockManager(true)) testArchivingCommon(t, NewLockManager(true))
} }
func testArchivingCommon(t *testing.T, lm *lockManager) { func testArchivingCommon(t *testing.T, lm *LockManager) {
resetKeysArchive() resetKeysArchive()
// First, we generate a policy and rotate it a number of times. Each time // First, we generate a policy and rotate it a number of times. Each time
// we'll ensure that we have the expected number of keys in the archive and // we'll ensure that we have the expected number of keys in the archive and
@ -203,10 +203,10 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
// zero and latest, respectively // zero and latest, respectively
storage := &logical.InmemStorage{} storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(policyRequest{ p, lock, _, err := lm.GetPolicyUpsert(PolicyRequest{
storage: storage, Storage: storage,
keyType: keyType_AES256_GCM96, KeyType: KeyType_AES256_GCM96,
name: "test", Name: "test",
}) })
if lock != nil { if lock != nil {
defer lock.RUnlock() defer lock.RUnlock()
@ -223,7 +223,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
checkKeys(t, p, storage, "initial", 1, 1, 1) checkKeys(t, p, storage, "initial", 1, 1, 1)
for i := 2; i <= 10; i++ { for i := 2; i <= 10; i++ {
err = p.rotate(storage) err = p.Rotate(storage)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -271,7 +271,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
} }
func checkKeys(t *testing.T, func checkKeys(t *testing.T,
p *policy, p *Policy,
storage logical.Storage, storage logical.Storage,
action string, action string,
archiveVer, latestVer, keysSize int) { archiveVer, latestVer, keysSize int) {
@ -282,7 +282,7 @@ func checkKeys(t *testing.T,
"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive)) "but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
} }
archive, err := p.loadArchive(storage) archive, err := p.LoadArchive(storage)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }


@ -21,6 +21,7 @@ import (
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
credCert "github.com/hashicorp/vault/builtin/credential/cert" credCert "github.com/hashicorp/vault/builtin/credential/cert"
"github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
) )
@ -381,7 +382,7 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, rpc, parallel bool, num uin
secret, err := doResp(resp) secret, err := doResp(resp)
if err != nil { if err != nil {
// This could well happen since the min version is jumping around // This could well happen since the min version is jumping around
if strings.Contains(err.Error(), transit.ErrTooOld) { if strings.Contains(err.Error(), keysutil.ErrTooOld) {
mySuccessfulOps++ mySuccessfulOps++
continue continue
} }


@ -26,6 +26,11 @@ const (
// NoRequestForwardingHeaderName is the name of the header telling Vault // NoRequestForwardingHeaderName is the name of the header telling Vault
// not to use request forwarding // not to use request forwarding
NoRequestForwardingHeaderName = "X-Vault-No-Request-Forwarding" NoRequestForwardingHeaderName = "X-Vault-No-Request-Forwarding"
// MaxRequestSize is the maximum accepted request size. This is to prevent
// a denial of service attack where no Content-Length is provided and the server
// is fed ever more data until it exhausts memory.
MaxRequestSize = 32 * 1024 * 1024
) )
// Handler returns an http.Handler for the API. This can be used on // Handler returns an http.Handler for the API. This can be used on
@ -109,7 +114,10 @@ func stripPrefix(prefix, path string) (string, bool) {
} }
func parseRequest(r *http.Request, out interface{}) error { func parseRequest(r *http.Request, out interface{}) error {
err := jsonutil.DecodeJSONFromReader(r.Body, out) // Limit the maximum number of bytes to MaxRequestSize to protect
// against an indefinite amount of data being read.
limit := &io.LimitedReader{R: r.Body, N: MaxRequestSize}
err := jsonutil.DecodeJSONFromReader(limit, out)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return fmt.Errorf("Failed to parse JSON input: %s", err) return fmt.Errorf("Failed to parse JSON input: %s", err)
} }
@ -245,10 +253,19 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) {
} }
// requestAuth adds the token to the logical.Request if it exists. // requestAuth adds the token to the logical.Request if it exists.
func requestAuth(r *http.Request, req *logical.Request) *logical.Request { func requestAuth(core *vault.Core, r *http.Request, req *logical.Request) *logical.Request {
// Attach the header value if we have it // Attach the header value if we have it
if v := r.Header.Get(AuthHeaderName); v != "" { if v := r.Header.Get(AuthHeaderName); v != "" {
req.ClientToken = v req.ClientToken = v
// Also attach the accessor if we have it. This doesn't fail if it
// doesn't exist because the request may be to an unauthenticated
// endpoint/login endpoint where a bad current token doesn't matter, or
// a token from a Vault version pre-accessors.
te, err := core.LookupToken(v)
if err == nil && te != nil {
req.ClientTokenAccessor = te.Accessor
}
} }
return req return req


@ -27,11 +27,13 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, req *http.Request) {
return return
} }
resp, err := core.HandleRequest(requestAuth(req, &logical.Request{ lreq := requestAuth(core, req, &logical.Request{
Operation: logical.HelpOperation, Operation: logical.HelpOperation,
Path: path, Path: path,
Connection: getConnection(req), Connection: getConnection(req),
})) })
resp, err := core.HandleRequest(lreq)
if err != nil { if err != nil {
respondError(w, http.StatusInternalServerError, err) respondError(w, http.StatusInternalServerError, err)
return return


@ -16,7 +16,7 @@ import (
type PrepareRequestFunc func(*vault.Core, *logical.Request) error type PrepareRequestFunc func(*vault.Core, *logical.Request) error
func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) { func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) {
// Determine the path... // Determine the path...
if !strings.HasPrefix(r.URL.Path, "/v1/") { if !strings.HasPrefix(r.URL.Path, "/v1/") {
return nil, http.StatusNotFound, nil return nil, http.StatusNotFound, nil
@ -26,6 +26,11 @@ func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Reque
return nil, http.StatusNotFound, nil return nil, http.StatusNotFound, nil
} }
// Verify the content length does not exceed the maximum size
if r.ContentLength >= MaxRequestSize {
return nil, http.StatusRequestEntityTooLarge, nil
}
// Determine the operation // Determine the operation
var op logical.Operation var op logical.Operation
switch r.Method { switch r.Method {
@ -72,13 +77,14 @@ func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Reque
return nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err) return nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err)
} }
req := requestAuth(r, &logical.Request{ req := requestAuth(core, r, &logical.Request{
ID: request_id, ID: request_id,
Operation: op, Operation: op,
Path: path, Path: path,
Data: data, Data: data,
Connection: getConnection(r), Connection: getConnection(r),
}) })
req, err = requestWrapTTL(r, req) req, err = requestWrapTTL(r, req)
if err != nil { if err != nil {
return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err) return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
@ -89,7 +95,7 @@ func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Reque
func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback PrepareRequestFunc) http.Handler { func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback PrepareRequestFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(w, r) req, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 { if err != nil || statusCode != 0 {
respondError(w, statusCode, err) respondError(w, statusCode, err)
return return


@ -231,3 +231,16 @@ func TestLogical_RawHTTP(t *testing.T) {
t.Fatalf("Bad: %s", body.Bytes()) t.Fatalf("Bad: %s", body.Bytes())
} }
} }
func TestLogical_RequestSizeLimit(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := TestServer(t, core)
defer ln.Close()
TestServerAuth(t, addr, token)
// Write a very large object, should fail
resp := testHttpPut(t, token, addr+"/v1/secret/foo", map[string]interface{}{
"data": make([]byte, MaxRequestSize),
})
testResponseStatus(t, resp, 413)
}


@ -15,7 +15,7 @@ import (
func handleSysSeal(core *vault.Core) http.Handler { func handleSysSeal(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(w, r) req, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 { if err != nil || statusCode != 0 {
respondError(w, statusCode, err) respondError(w, statusCode, err)
return return
@ -30,8 +30,13 @@ func handleSysSeal(core *vault.Core) http.Handler {
// Seal with the token above // Seal with the token above
if err := core.SealWithRequest(req); err != nil { if err := core.SealWithRequest(req); err != nil {
respondError(w, http.StatusInternalServerError, err) if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
return respondError(w, http.StatusForbidden, err)
return
} else {
respondError(w, http.StatusInternalServerError, err)
return
}
} }
respondOk(w, nil) respondOk(w, nil)
@ -40,7 +45,7 @@ func handleSysSeal(core *vault.Core) http.Handler {
func handleSysStepDown(core *vault.Core) http.Handler { func handleSysStepDown(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(w, r) req, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 { if err != nil || statusCode != 0 {
respondError(w, statusCode, err) respondError(w, statusCode, err)
return return


@ -285,7 +285,7 @@ func TestSysSeal_Permissions(t *testing.T) {
// We expect this to fail since it needs update and sudo // We expect this to fail since it needs update and sudo
httpResp := testHttpPut(t, "child", addr+"/v1/sys/seal", nil) httpResp := testHttpPut(t, "child", addr+"/v1/sys/seal", nil)
testResponseStatus(t, httpResp, 500) testResponseStatus(t, httpResp, 403)
// Now modify to add update capability // Now modify to add update capability
req = &logical.Request{ req = &logical.Request{
@ -306,7 +306,7 @@ func TestSysSeal_Permissions(t *testing.T) {
// We expect this to fail since it needs sudo // We expect this to fail since it needs sudo
httpResp = testHttpPut(t, "child", addr+"/v1/sys/seal", nil) httpResp = testHttpPut(t, "child", addr+"/v1/sys/seal", nil)
testResponseStatus(t, httpResp, 500) testResponseStatus(t, httpResp, 403)
// Now modify to just sudo capability // Now modify to just sudo capability
req = &logical.Request{ req = &logical.Request{
@ -327,7 +327,7 @@ func TestSysSeal_Permissions(t *testing.T) {
// We expect this to fail since it needs update // We expect this to fail since it needs update
httpResp = testHttpPut(t, "child", addr+"/v1/sys/seal", nil) httpResp = testHttpPut(t, "child", addr+"/v1/sys/seal", nil)
testResponseStatus(t, httpResp, 500) testResponseStatus(t, httpResp, 403)
// Now modify to add all needed capabilities // Now modify to add all needed capabilities
req = &logical.Request{ req = &logical.Request{


@ -47,6 +47,10 @@ type Request struct {
// hashed. // hashed.
ClientToken string `json:"client_token" structs:"client_token" mapstructure:"client_token"` ClientToken string `json:"client_token" structs:"client_token" mapstructure:"client_token"`
// ClientTokenAccessor is provided to the core so that the it can get
// logged as part of request audit logging.
ClientTokenAccessor string `json:"client_token_accessor" structs:"client_token_accessor" mapstructure:"client_token_accessor"`
// DisplayName is provided to the logical backend to help associate // DisplayName is provided to the logical backend to help associate
// dynamic secrets with the source entity. This is not a sensitive // dynamic secrets with the source entity. This is not a sensitive
// name, but is useful for operators. // name, but is useful for operators.


@ -9,7 +9,7 @@ import (
const ( const (
// DefaultCacheSize is used if no cache size is specified for NewCache // DefaultCacheSize is used if no cache size is specified for NewCache
DefaultCacheSize = 1024 * 1024 DefaultCacheSize = 32 * 1024
) )
// Cache is used to wrap an underlying physical backend // Cache is used to wrap an underlying physical backend
@ -45,7 +45,9 @@ func (c *Cache) Purge() {
func (c *Cache) Put(entry *Entry) error { func (c *Cache) Put(entry *Entry) error {
err := c.backend.Put(entry) err := c.backend.Put(entry)
c.lru.Add(entry.Key, entry) if err == nil {
c.lru.Add(entry.Key, entry)
}
return err return err
} }
@ -78,7 +80,9 @@ func (c *Cache) Get(key string) (*Entry, error) {
func (c *Cache) Delete(key string) error { func (c *Cache) Delete(key string) error {
err := c.backend.Delete(key) err := c.backend.Delete(key)
c.lru.Remove(key) if err == nil {
c.lru.Remove(key)
}
return err return err
} }


@ -173,6 +173,9 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) {
// Add the % wildcard to the prefix to do the prefix search // Add the % wildcard to the prefix to do the prefix search
likePrefix := prefix + "%" likePrefix := prefix + "%"
rows, err := m.statements["list"].Query(likePrefix) rows, err := m.statements["list"].Query(likePrefix)
if err != nil {
return nil, fmt.Errorf("failed to execute statement: %v", err)
}
var keys []string var keys []string
for rows.Next() { for rows.Next() {


@ -10,7 +10,7 @@ RUN apt-get update -y && apt-get install --no-install-recommends -y -q \
git mercurial bzr \ git mercurial bzr \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
ENV GOVERSION 1.7.1 ENV GOVERSION 1.7.3
RUN mkdir /goroot && mkdir /gopath RUN mkdir /goroot && mkdir /gopath
RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \ RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \
| tar xvzf - -C /goroot --strip-components=1 | tar xvzf - -C /goroot --strip-components=1


@ -29,13 +29,11 @@ func makePolynomial(intercept, degree uint8) (polynomial, error) {
// Ensure the intercept is set // Ensure the intercept is set
p.coefficients[0] = intercept p.coefficients[0] = intercept
// Assign random co-efficients to the polynomial, ensuring // Assign random co-efficients to the polynomial
// the highest order co-efficient is non-zero if _, err := rand.Read(p.coefficients[1:]); err != nil {
for p.coefficients[degree] == 0 { return p, err
if _, err := rand.Read(p.coefficients[1:]); err != nil {
return p, err
}
} }
return p, nil return p, nil
} }


@ -256,8 +256,10 @@ func (c *Core) teardownAudits() error {
c.auditLock.Lock() c.auditLock.Lock()
defer c.auditLock.Unlock() defer c.auditLock.Unlock()
for _, entry := range c.audit.Entries { if c.audit != nil {
c.removeAuditReloadFunc(entry) for _, entry := range c.audit.Entries {
c.removeAuditReloadFunc(entry)
}
} }
c.audit = nil c.audit = nil


@ -350,6 +350,17 @@ func (c *Core) teardownCredentials() error {
c.authLock.Lock() c.authLock.Lock()
defer c.authLock.Unlock() defer c.authLock.Unlock()
if c.auth != nil {
authTable := c.auth.shallowClone()
for _, e := range authTable.Entries {
prefix := e.Path
b, ok := c.router.root.Get(prefix)
if ok {
b.(*routeEntry).backend.Cleanup()
}
}
}
c.auth = nil c.auth = nil
c.tokenStore = nil c.tokenStore = nil
return nil return nil


@@ -316,6 +316,10 @@ func (c *Core) stopClusterListener() {
        return
    }

    if !c.clusterListenersRunning {
        c.logger.Info("core/stopClusterListener: listeners not running")
        return
    }
    c.logger.Info("core/stopClusterListener: stopping listeners")

    // Tell the goroutine managing the listeners to perform the shutdown

@@ -327,6 +331,8 @@ func (c *Core) stopClusterListener() {
    // bind errors. This ensures proper ordering.
    c.logger.Trace("core/stopClusterListener: waiting for success notification")
    <-c.clusterListenerShutdownSuccessCh
    c.clusterListenersRunning = false

    c.logger.Info("core/stopClusterListener: success")
}
@@ -417,21 +423,3 @@ func WrapHandlerForClustering(handler http.Handler, logger log.Logger) func() (h
        return handler, mux
    }
}
-
-// WrapListenersForClustering takes in Vault's cluster addresses and returns a
-// setup function that creates the new listeners
-func WrapListenersForClustering(addrs []string, logger log.Logger) func() ([]net.Listener, error) {
-   return func() ([]net.Listener, error) {
-       ret := make([]net.Listener, 0, len(addrs))
-       // Loop over the existing listeners and start listeners on appropriate ports
-       for _, addr := range addrs {
-           ln, err := net.Listen("tcp", addr)
-           if err != nil {
-               return nil, err
-           }
-           ret = append(ret, ln)
-       }
-
-       return ret, nil
-   }
-}

View file

@@ -16,6 +16,10 @@ import (
    log "github.com/mgutz/logxi/v1"
)

var (
    clusterTestPausePeriod = 2 * time.Second
)

func TestClusterFetching(t *testing.T) {
    c, _, _ := TestCoreUnsealed(t)
@@ -109,16 +113,16 @@ func TestCluster_ListenForRequests(t *testing.T) {
                t.Fatal("%s not a TCP port", tcpAddr.String())
            }

-           conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+1), tlsConfig)
+           conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+10), tlsConfig)
            if err != nil {
                if expectFail {
-                   t.Logf("testing %s:%d unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+1)
+                   t.Logf("testing %s:%d unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+10)
                    continue
                }
                t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[1])
            }
            if expectFail {
-               t.Fatalf("testing %s:%d not unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+1)
+               t.Fatalf("testing %s:%d not unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+10)
            }
            err = conn.Handshake()
            if err != nil {
@@ -131,11 +135,11 @@ func TestCluster_ListenForRequests(t *testing.T) {
            case connState.NegotiatedProtocol != "h2" || !connState.NegotiatedProtocolIsMutual:
                t.Fatal("bad protocol negotiation")
            }
-           t.Logf("testing %s:%d successful", tcpAddr.IP.String(), tcpAddr.Port+1)
+           t.Logf("testing %s:%d successful", tcpAddr.IP.String(), tcpAddr.Port+10)
        }
    }

-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    checkListenersFunc(false)

    err := cores[0].StepDown(&logical.Request{

@@ -149,7 +153,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
    // StepDown doesn't wait during actual preSeal so give time for listeners
    // to close
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    checkListenersFunc(true)

    // After this period it should be active again

@@ -160,7 +164,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
    if err != nil {
        t.Fatal(err)
    }
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)

    // After sealing it should be inactive again
    checkListenersFunc(true)
}
@@ -230,13 +234,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
    if err != nil {
        t.Fatal(err)
    }
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    _ = cores[2].StepDown(&logical.Request{
        Operation:   logical.UpdateOperation,
        Path:        "sys/step-down",
        ClientToken: root,
    })
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    TestWaitActive(t, cores[1].Core)
    testCluster_ForwardRequests(t, cores[0], "core2")
    testCluster_ForwardRequests(t, cores[2], "core2")

@@ -250,13 +254,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
    if err != nil {
        t.Fatal(err)
    }
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    _ = cores[0].StepDown(&logical.Request{
        Operation:   logical.UpdateOperation,
        Path:        "sys/step-down",
        ClientToken: root,
    })
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    TestWaitActive(t, cores[2].Core)
    testCluster_ForwardRequests(t, cores[0], "core3")
    testCluster_ForwardRequests(t, cores[1], "core3")

@@ -270,13 +274,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
    if err != nil {
        t.Fatal(err)
    }
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    _ = cores[1].StepDown(&logical.Request{
        Operation:   logical.UpdateOperation,
        Path:        "sys/step-down",
        ClientToken: root,
    })
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    TestWaitActive(t, cores[0].Core)
    testCluster_ForwardRequests(t, cores[1], "core1")
    testCluster_ForwardRequests(t, cores[2], "core1")

@@ -290,13 +294,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
    if err != nil {
        t.Fatal(err)
    }
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    _ = cores[2].StepDown(&logical.Request{
        Operation:   logical.UpdateOperation,
        Path:        "sys/step-down",
        ClientToken: root,
    })
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    TestWaitActive(t, cores[1].Core)
    testCluster_ForwardRequests(t, cores[0], "core2")
    testCluster_ForwardRequests(t, cores[2], "core2")

@@ -310,13 +314,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
    if err != nil {
        t.Fatal(err)
    }
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    _ = cores[0].StepDown(&logical.Request{
        Operation:   logical.UpdateOperation,
        Path:        "sys/step-down",
        ClientToken: root,
    })
-   time.Sleep(2 * time.Second)
+   time.Sleep(clusterTestPausePeriod)
    TestWaitActive(t, cores[2].Core)
    testCluster_ForwardRequests(t, cores[0], "core3")
    testCluster_ForwardRequests(t, cores[1], "core3")

View file

@@ -13,12 +13,12 @@ import (
    "sync"
    "time"

+   "github.com/armon/go-metrics"
    log "github.com/mgutz/logxi/v1"
    "golang.org/x/net/context"
    "google.golang.org/grpc"

-   "github.com/armon/go-metrics"
    "github.com/hashicorp/errwrap"
    "github.com/hashicorp/go-multierror"
    "github.com/hashicorp/go-uuid"

@@ -269,6 +269,9 @@ type Core struct {
    clusterListenerAddrs []*net.TCPAddr
    // The setup function that gives us the handler to use
    clusterHandlerSetupFunc func() (http.Handler, http.Handler)
    // Tracks whether cluster listeners are running, e.g. it's safe to send a
    // shutdown down the channel
    clusterListenersRunning bool
    // Shutdown channel for the cluster listeners
    clusterListenerShutdownCh chan struct{}
    // Shutdown success channel. We need this to be done serially to ensure
@@ -492,6 +495,23 @@ func (c *Core) Shutdown() error {
    return c.sealInternal()
}

// LookupToken returns the properties of the token from the token store. This
// is particularly useful to fetch the accessor of the client token and get it
// populated in the logical request along with the client token. The accessor
// of the client token can get audit logged.
func (c *Core) LookupToken(token string) (*TokenEntry, error) {
    if token == "" {
        return nil, fmt.Errorf("missing client token")
    }

    // Many tests don't have a token store running
    if c.tokenStore == nil {
        return nil, nil
    }

    return c.tokenStore.Lookup(token)
}

func (c *Core) fetchACLandTokenEntry(req *logical.Request) (*ACL, *TokenEntry, error) {
    defer metrics.MeasureSince([]string{"core", "fetch_acl_and_token"}, time.Now())

View file

@@ -1679,11 +1679,13 @@ func (b *SystemBackend) handleWrappingRewrap(
    // Return response in "response"; wrapping code will detect the rewrap and
    // slot in instead of nesting
-   req.WrapTTL = time.Duration(creationTTL)
    return &logical.Response{
        Data: map[string]interface{}{
            "response": response,
        },
+       WrapInfo: &logical.WrapInfo{
+           TTL: time.Duration(creationTTL),
+       },
    }, nil
}

View file

@@ -68,7 +68,9 @@ func (c *Core) startForwarding() error {
        go func() {
            defer shutdownWg.Done()

-           c.logger.Info("core/startClusterListener: starting listener")
+           if c.logger.IsInfo() {
+               c.logger.Info("core/startClusterListener: starting listener", "listener_address", laddr)
+           }

            // Create a TCP listener. We do this separately and specifically
            // with TCP so that we can set deadlines.

@@ -143,6 +145,10 @@ func (c *Core) startForwarding() error {
    // This is in its own goroutine so that we don't block the main thread, and
    // thus we use atomic and channels to coordinate
    // However, because you can't query the status of a channel, we set a bool
    // here while we have the state lock to know whether to actually send a
    // shutdown (e.g. whether the channel will block). See issue #2083.
    c.clusterListenersRunning = true
    go func() {
        // If we get told to shut down...
        <-c.clusterListenerShutdownCh

View file

@@ -39,7 +39,7 @@ var _ grpc.ClientConn

// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion3
+const _ = grpc.SupportPackageIsVersion4

// Client API for RequestForwarding service

@@ -102,7 +102,7 @@ var _RequestForwarding_serviceDesc = grpc.ServiceDesc{
        },
    },
    Streams: []grpc.StreamDesc{},
-   Metadata: fileDescriptor0,
+   Metadata: "request_forwarding_service.proto",
}

func init() { proto.RegisterFile("request_forwarding_service.proto", fileDescriptor0) }

View file

@@ -186,12 +186,29 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
    // Route the request
    resp, err := c.router.Route(req)
    if resp != nil {
-       // We don't allow backends to specify this, so ensure it's not set
-       resp.WrapInfo = nil
-       if req.WrapTTL != 0 {
+       // If wrapping is used, use the shortest between the request and response
+       var wrapTTL time.Duration
+
+       // Ensure no wrap info information is set other than, possibly, the TTL
+       if resp.WrapInfo != nil {
+           if resp.WrapInfo.TTL > 0 {
+               wrapTTL = resp.WrapInfo.TTL
+           }
+           resp.WrapInfo = nil
+       }
+
+       if req.WrapTTL > 0 {
+           switch {
+           case wrapTTL == 0:
+               wrapTTL = req.WrapTTL
+           case req.WrapTTL < wrapTTL:
+               wrapTTL = req.WrapTTL
+           }
+       }
+
+       if wrapTTL > 0 {
            resp.WrapInfo = &logical.WrapInfo{
-               TTL: req.WrapTTL,
+               TTL: wrapTTL,
            }
        }
    }
@@ -306,14 +323,32 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
    // Route the request
    resp, err := c.router.Route(req)
    if resp != nil {
-       // We don't allow backends to specify this, so ensure it's not set
-       resp.WrapInfo = nil
-       if req.WrapTTL != 0 {
-           resp.WrapInfo = &logical.WrapInfo{
-               TTL: req.WrapTTL,
-           }
-       }
+       // If wrapping is used, use the shortest between the request and response
+       var wrapTTL time.Duration
+
+       // Ensure no wrap info information is set other than, possibly, the TTL
+       if resp.WrapInfo != nil {
+           if resp.WrapInfo.TTL > 0 {
+               wrapTTL = resp.WrapInfo.TTL
+           }
+           resp.WrapInfo = nil
+       }
+
+       if req.WrapTTL > 0 {
+           switch {
+           case wrapTTL == 0:
+               wrapTTL = req.WrapTTL
+           case req.WrapTTL < wrapTTL:
+               wrapTTL = req.WrapTTL
+           }
+       }
+
+       if wrapTTL > 0 {
+           resp.WrapInfo = &logical.WrapInfo{
+               TTL: wrapTTL,
+           }
+       }
    }

    // A login request should never return a secret!

View file

@@ -263,12 +263,7 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
        req.ID = originalReqID
        req.Storage = nil
        req.ClientToken = clientToken
-
-       // Only the rewrap endpoint is allowed to declare a wrap TTL on a
-       // request that did not come from the client
-       if req.Path != "sys/wrapping/rewrap" {
-           req.WrapTTL = originalWrapTTL
-       }
+       req.WrapTTL = originalWrapTTL
    }()

    // Invoke the backend

View file

@@ -584,6 +584,11 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
    }

    // Create three cores with the same physical and different redirect/cluster addrs
    // N.B.: On OSX, instead of random ports, it assigns new ports to new
    // listeners sequentially. Aside from being a bad idea in a security sense,
    // it also broke tests that assumed it was OK to just use the port above
    // the redirect addr. This has now been changed to 10 ports above, but if
    // we ever do more than three nodes in a cluster it may need to be bumped.
    coreConfig := &CoreConfig{
        Physical:   physical.NewInmem(logger),
        HAPhysical: physical.NewInmemHA(logger),

@@ -591,7 +596,7 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
        CredentialBackends: make(map[string]logical.Factory),
        AuditBackends:      make(map[string]audit.Factory),
        RedirectAddr:       fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port),
-       ClusterAddr:        fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+1),
+       ClusterAddr:        fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+10),
        DisableMlock:       true,
    }
@@ -629,7 +634,7 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
    coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port)
    if coreConfig.ClusterAddr != "" {
-       coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+1)
+       coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+10)
    }
    c2, err := NewCore(coreConfig)
    if err != nil {

@@ -638,7 +643,7 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
    coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port)
    if coreConfig.ClusterAddr != "" {
-       coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+1)
+       coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+10)
    }
    c3, err := NewCore(coreConfig)
    if err != nil {

@@ -653,7 +658,7 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
    for i, ln := range lns {
        ret[i] = &net.TCPAddr{
            IP:   ln.Address.IP,
-           Port: ln.Address.Port + 1,
+           Port: ln.Address.Port + 10,
        }
    }
    return ret

View file

@@ -6,6 +6,7 @@ import (
    "errors"
    "fmt"
    "io"
+   "io/ioutil"
    "net/http"
    "net/url"
    "strconv"

@@ -256,6 +257,23 @@ const (
    blobCopyStatusFailed = "failed"
)
// lease constants.
const (
leaseHeaderPrefix = "x-ms-lease-"
leaseID = "x-ms-lease-id"
leaseAction = "x-ms-lease-action"
leaseBreakPeriod = "x-ms-lease-break-period"
leaseDuration = "x-ms-lease-duration"
leaseProposedID = "x-ms-proposed-lease-id"
leaseTime = "x-ms-lease-time"
acquireLease = "acquire"
renewLease = "renew"
changeLease = "change"
releaseLease = "release"
breakLease = "break"
)
// BlockListType is used to filter out types of blocks in a Get Blocks List call
// for a block blob.
//
@@ -284,6 +302,65 @@ const (
    ContainerAccessTypeContainer ContainerAccessType = "container"
)
// ContainerAccessOptions are used when setting ACLs of containers (after creation)
type ContainerAccessOptions struct {
ContainerAccess ContainerAccessType
Timeout int
LeaseID string
}
// AccessPolicyDetails are used for SETTING policies
type AccessPolicyDetails struct {
ID string
StartTime time.Time
ExpiryTime time.Time
CanRead bool
CanWrite bool
CanDelete bool
}
// ContainerPermissions is used when setting permissions and Access Policies for containers.
type ContainerPermissions struct {
AccessOptions ContainerAccessOptions
AccessPolicy AccessPolicyDetails
}
// AccessPolicyDetailsXML has specifics about an access policy
// annotated with XML details.
type AccessPolicyDetailsXML struct {
StartTime time.Time `xml:"Start"`
ExpiryTime time.Time `xml:"Expiry"`
Permission string `xml:"Permission"`
}
// SignedIdentifier is a wrapper for a specific policy
type SignedIdentifier struct {
ID string `xml:"Id"`
AccessPolicy AccessPolicyDetailsXML `xml:"AccessPolicy"`
}
// SignedIdentifiers part of the response from GetPermissions call.
type SignedIdentifiers struct {
SignedIdentifiers []SignedIdentifier `xml:"SignedIdentifier"`
}
// AccessPolicy is the response type from the GetPermissions call.
type AccessPolicy struct {
SignedIdentifiersList SignedIdentifiers `xml:"SignedIdentifiers"`
}
// ContainerAccessResponse is returned for the GetContainerPermissions function.
// This contains both the permission and access policy for the container.
type ContainerAccessResponse struct {
ContainerAccess ContainerAccessType
AccessPolicy SignedIdentifiers
}
// ContainerAccessHeader references header used when setting/getting container ACL
const (
ContainerAccessHeader string = "x-ms-blob-public-access"
)
// Maximum sizes (per REST API) for various concepts
const (
    MaxBlobBlockSize = 4 * 1024 * 1024

@@ -399,7 +476,7 @@ func (b BlobStorageClient) createContainer(name string, access ContainerAccessTy
    headers := b.client.getStandardHeaders()
    if access != "" {
-       headers["x-ms-blob-public-access"] = string(access)
+       headers[ContainerAccessHeader] = string(access)
    }
    return b.client.exec(verb, uri, headers, nil)
}
@@ -421,6 +498,101 @@ func (b BlobStorageClient) ContainerExists(name string) (bool, error) {
    return false, err
}
// SetContainerPermissions sets up container permissions as per https://msdn.microsoft.com/en-us/library/azure/dd179391.aspx
func (b BlobStorageClient) SetContainerPermissions(container string, containerPermissions ContainerPermissions) (err error) {
params := url.Values{
"restype": {"container"},
"comp": {"acl"},
}
if containerPermissions.AccessOptions.Timeout > 0 {
params.Add("timeout", strconv.Itoa(containerPermissions.AccessOptions.Timeout))
}
uri := b.client.getEndpoint(blobServiceName, pathForContainer(container), params)
headers := b.client.getStandardHeaders()
if containerPermissions.AccessOptions.ContainerAccess != "" {
headers[ContainerAccessHeader] = string(containerPermissions.AccessOptions.ContainerAccess)
}
if containerPermissions.AccessOptions.LeaseID != "" {
headers[leaseID] = containerPermissions.AccessOptions.LeaseID
}
// generate the XML for the SharedAccessSignature if required.
accessPolicyXML, err := generateAccessPolicy(containerPermissions.AccessPolicy)
if err != nil {
return err
}
var resp *storageResponse
if accessPolicyXML != "" {
headers["Content-Length"] = strconv.Itoa(len(accessPolicyXML))
resp, err = b.client.exec("PUT", uri, headers, strings.NewReader(accessPolicyXML))
} else {
resp, err = b.client.exec("PUT", uri, headers, nil)
}
if err != nil {
return err
}
if resp != nil {
defer func() {
err = resp.body.Close()
}()
if resp.statusCode != http.StatusOK {
return errors.New("Unable to set permissions")
}
}
return nil
}
// GetContainerPermissions gets the container permissions as per https://msdn.microsoft.com/en-us/library/azure/dd179469.aspx
// If timeout is 0 then it will not be passed to Azure
// leaseID will only be passed to Azure if populated
// Returns permissionResponse which is combined permissions and AccessPolicy
func (b BlobStorageClient) GetContainerPermissions(container string, timeout int, leaseID string) (permissionResponse *ContainerAccessResponse, err error) {
params := url.Values{"restype": {"container"},
"comp": {"acl"}}
if timeout > 0 {
params.Add("timeout", strconv.Itoa(timeout))
}
uri := b.client.getEndpoint(blobServiceName, pathForContainer(container), params)
headers := b.client.getStandardHeaders()
if leaseID != "" {
headers[leaseID] = leaseID
}
resp, err := b.client.exec("GET", uri, headers, nil)
if err != nil {
return nil, err
}
// containerAccess. Blob, Container, empty
containerAccess := resp.headers.Get(http.CanonicalHeaderKey(ContainerAccessHeader))
defer func() {
err = resp.body.Close()
}()
var out AccessPolicy
err = xmlUnmarshal(resp.body, &out.SignedIdentifiersList)
if err != nil {
return nil, err
}
permissionResponse = &ContainerAccessResponse{}
permissionResponse.AccessPolicy = out.SignedIdentifiersList
permissionResponse.ContainerAccess = ContainerAccessType(containerAccess)
return permissionResponse, nil
}
// DeleteContainer deletes the container with given name on the storage
// account. If the container does not exist returns error.
//
@@ -560,6 +732,132 @@ func (b BlobStorageClient) getBlobRange(container, name, bytesRange string, extr
    return resp, err
}
// leaseCommonPut is common PUT code for the various acquire/release/break etc. functions.
func (b BlobStorageClient) leaseCommonPut(container string, name string, headers map[string]string, expectedStatus int) (http.Header, error) {
params := url.Values{"comp": {"lease"}}
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), params)
resp, err := b.client.exec("PUT", uri, headers, nil)
if err != nil {
return nil, err
}
defer resp.body.Close()
if err := checkRespCode(resp.statusCode, []int{expectedStatus}); err != nil {
return nil, err
}
return resp.headers, nil
}
// AcquireLease creates a lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
// returns leaseID acquired
func (b BlobStorageClient) AcquireLease(container string, name string, leaseTimeInSeconds int, proposedLeaseID string) (returnedLeaseID string, err error) {
headers := b.client.getStandardHeaders()
headers[leaseAction] = acquireLease
headers[leaseProposedID] = proposedLeaseID
headers[leaseDuration] = strconv.Itoa(leaseTimeInSeconds)
respHeaders, err := b.leaseCommonPut(container, name, headers, http.StatusCreated)
if err != nil {
return "", err
}
returnedLeaseID = respHeaders.Get(http.CanonicalHeaderKey(leaseID))
if returnedLeaseID != "" {
return returnedLeaseID, nil
}
return "", errors.New("LeaseID not returned")
}
// BreakLease breaks the lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
// Returns the timeout remaining in the lease in seconds
func (b BlobStorageClient) BreakLease(container string, name string) (breakTimeout int, err error) {
headers := b.client.getStandardHeaders()
headers[leaseAction] = breakLease
return b.breakLeaseCommon(container, name, headers)
}
// BreakLeaseWithBreakPeriod breaks the lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
// breakPeriodInSeconds is used to determine how long until new lease can be created.
// Returns the timeout remaining in the lease in seconds
func (b BlobStorageClient) BreakLeaseWithBreakPeriod(container string, name string, breakPeriodInSeconds int) (breakTimeout int, err error) {
headers := b.client.getStandardHeaders()
headers[leaseAction] = breakLease
headers[leaseBreakPeriod] = strconv.Itoa(breakPeriodInSeconds)
return b.breakLeaseCommon(container, name, headers)
}
// breakLeaseCommon is common code for both version of BreakLease (with and without break period)
func (b BlobStorageClient) breakLeaseCommon(container string, name string, headers map[string]string) (breakTimeout int, err error) {
respHeaders, err := b.leaseCommonPut(container, name, headers, http.StatusAccepted)
if err != nil {
return 0, err
}
breakTimeoutStr := respHeaders.Get(http.CanonicalHeaderKey(leaseTime))
if breakTimeoutStr != "" {
breakTimeout, err = strconv.Atoi(breakTimeoutStr)
if err != nil {
return 0, err
}
}
return breakTimeout, nil
}
// ChangeLease changes a lease ID for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
// Returns the new LeaseID acquired
func (b BlobStorageClient) ChangeLease(container string, name string, currentLeaseID string, proposedLeaseID string) (newLeaseID string, err error) {
headers := b.client.getStandardHeaders()
headers[leaseAction] = changeLease
headers[leaseID] = currentLeaseID
headers[leaseProposedID] = proposedLeaseID
respHeaders, err := b.leaseCommonPut(container, name, headers, http.StatusOK)
if err != nil {
return "", err
}
newLeaseID = respHeaders.Get(http.CanonicalHeaderKey(leaseID))
if newLeaseID != "" {
return newLeaseID, nil
}
return "", errors.New("LeaseID not returned")
}
// ReleaseLease releases the lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
func (b BlobStorageClient) ReleaseLease(container string, name string, currentLeaseID string) error {
headers := b.client.getStandardHeaders()
headers[leaseAction] = releaseLease
headers[leaseID] = currentLeaseID
_, err := b.leaseCommonPut(container, name, headers, http.StatusOK)
if err != nil {
return err
}
return nil
}
// RenewLease renews the lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
func (b BlobStorageClient) RenewLease(container string, name string, currentLeaseID string) error {
headers := b.client.getStandardHeaders()
headers[leaseAction] = renewLease
headers[leaseID] = currentLeaseID
_, err := b.leaseCommonPut(container, name, headers, http.StatusOK)
if err != nil {
return err
}
return nil
}
// GetBlobProperties provides various information about the specified
// blob. See https://msdn.microsoft.com/en-us/library/azure/dd179394.aspx
func (b BlobStorageClient) GetBlobProperties(container, name string) (*BlobProperties, error) {

@@ -961,15 +1259,20 @@ func (b BlobStorageClient) AppendBlock(container, name string, chunk []byte, ext
//
// See https://msdn.microsoft.com/en-us/library/azure/dd894037.aspx
func (b BlobStorageClient) CopyBlob(container, name, sourceBlob string) error {
-   copyID, err := b.startBlobCopy(container, name, sourceBlob)
+   copyID, err := b.StartBlobCopy(container, name, sourceBlob)
    if err != nil {
        return err
    }

-   return b.waitForBlobCopy(container, name, copyID)
+   return b.WaitForBlobCopy(container, name, copyID)
}

-func (b BlobStorageClient) startBlobCopy(container, name, sourceBlob string) (string, error) {
+// StartBlobCopy starts a blob copy operation.
+// sourceBlob parameter must be a canonical URL to the blob (can be
+// obtained using GetBlobURL method.)
+//
+// See https://msdn.microsoft.com/en-us/library/azure/dd894037.aspx
+func (b BlobStorageClient) StartBlobCopy(container, name, sourceBlob string) (string, error) {
    uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{})
    headers := b.client.getStandardHeaders()

@@ -992,7 +1295,39 @@ func (b BlobStorageClient) startBlobCopy(container, name, sourceBlob string) (st
    return copyID, nil
}
-func (b BlobStorageClient) waitForBlobCopy(container, name, copyID string) error {
+// AbortBlobCopy aborts a BlobCopy which has already been triggered by the StartBlobCopy function.
// copyID is generated from StartBlobCopy function.
// currentLeaseID is required IF the destination blob has an active lease on it.
// As defined in https://msdn.microsoft.com/en-us/library/azure/jj159098.aspx
func (b BlobStorageClient) AbortBlobCopy(container, name, copyID, currentLeaseID string, timeout int) error {
params := url.Values{"comp": {"copy"}, "copyid": {copyID}}
if timeout > 0 {
params.Add("timeout", strconv.Itoa(timeout))
}
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), params)
headers := b.client.getStandardHeaders()
headers["x-ms-copy-action"] = "abort"
if currentLeaseID != "" {
headers[leaseID] = currentLeaseID
}
resp, err := b.client.exec("PUT", uri, headers, nil)
if err != nil {
return err
}
defer resp.body.Close()
if err := checkRespCode(resp.statusCode, []int{http.StatusNoContent}); err != nil {
return err
}
return nil
}
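The abort request above is just a query string plus two headers. As a standalone sketch (the real code lives on the client; the `leaseID` header constant is assumed to be `x-ms-lease-id`, and the function names here are hypothetical), the construction looks like this:

```go
// Sketch of how AbortBlobCopy assembles its query parameters and headers.
package main

import (
	"fmt"
	"net/url"
	"strconv"
)

// abortCopyParams builds the "comp=copy&copyid=..." query values,
// adding a timeout only when one is requested.
func abortCopyParams(copyID string, timeout int) url.Values {
	params := url.Values{"comp": {"copy"}, "copyid": {copyID}}
	if timeout > 0 {
		params.Add("timeout", strconv.Itoa(timeout))
	}
	return params
}

// abortCopyHeaders sets the abort action and, only when the destination
// blob is leased, the current lease ID (header name assumed).
func abortCopyHeaders(currentLeaseID string) map[string]string {
	headers := map[string]string{"x-ms-copy-action": "abort"}
	if currentLeaseID != "" {
		headers["x-ms-lease-id"] = currentLeaseID
	}
	return headers
}

func main() {
	fmt.Println(abortCopyParams("copy-1", 30).Encode())
	fmt.Println(abortCopyHeaders("lease-7"))
}
```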
// WaitForBlobCopy loops until a BlobCopy operation is completed (or fails with an error).
func (b BlobStorageClient) WaitForBlobCopy(container, name, copyID string) error {
for { for {
props, err := b.GetBlobProperties(container, name) props, err := b.GetBlobProperties(container, name)
if err != nil { if err != nil {
@ -1036,10 +1371,12 @@ func (b BlobStorageClient) DeleteBlob(container, name string, extraHeaders map[s
// See https://msdn.microsoft.com/en-us/library/azure/dd179413.aspx // See https://msdn.microsoft.com/en-us/library/azure/dd179413.aspx
func (b BlobStorageClient) DeleteBlobIfExists(container, name string, extraHeaders map[string]string) (bool, error) { func (b BlobStorageClient) DeleteBlobIfExists(container, name string, extraHeaders map[string]string) (bool, error) {
resp, err := b.deleteBlob(container, name, extraHeaders) resp, err := b.deleteBlob(container, name, extraHeaders)
if resp != nil && (resp.statusCode == http.StatusAccepted || resp.statusCode == http.StatusNotFound) { if resp != nil {
return resp.statusCode == http.StatusAccepted, nil defer resp.body.Close()
if resp.statusCode == http.StatusAccepted || resp.statusCode == http.StatusNotFound {
return resp.statusCode == http.StatusAccepted, nil
}
} }
defer resp.body.Close()
return false, err return false, err
} }
@ -1065,17 +1402,18 @@ func pathForBlob(container, name string) string {
return fmt.Sprintf("/%s/%s", container, name) return fmt.Sprintf("/%s/%s", container, name)
} }
// GetBlobSASURI creates an URL to the specified blob which contains the Shared // GetBlobSASURIWithSignedIPAndProtocol creates a URL to the specified blob which contains the Shared
// Access Signature with specified permissions and expiration time. // Access Signature with specified permissions and expiration time. Also includes signedIPRange and allowed protocols.
// If an old API version is used but no signedIP is passed (i.e. an empty string), this should still work.
// We only populate the signedIP when it is non-empty.
// //
// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx // See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx
func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Time, permissions string) (string, error) { func (b BlobStorageClient) GetBlobSASURIWithSignedIPAndProtocol(container, name string, expiry time.Time, permissions string, signedIPRange string, HTTPSOnly bool) (string, error) {
var ( var (
signedPermissions = permissions signedPermissions = permissions
blobURL = b.GetBlobURL(container, name) blobURL = b.GetBlobURL(container, name)
) )
canonicalizedResource, err := b.client.buildCanonicalizedResource(blobURL) canonicalizedResource, err := b.client.buildCanonicalizedResource(blobURL)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -1087,7 +1425,6 @@ func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Tim
// We need to replace + with %2b first to avoid being treated as a space (which is correct for query strings, but not the path component). // We need to replace + with %2b first to avoid being treated as a space (which is correct for query strings, but not the path component).
canonicalizedResource = strings.Replace(canonicalizedResource, "+", "%2b", -1) canonicalizedResource = strings.Replace(canonicalizedResource, "+", "%2b", -1)
canonicalizedResource, err = url.QueryUnescape(canonicalizedResource) canonicalizedResource, err = url.QueryUnescape(canonicalizedResource)
if err != nil { if err != nil {
return "", err return "", err
@ -1096,7 +1433,11 @@ func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Tim
signedExpiry := expiry.UTC().Format(time.RFC3339) signedExpiry := expiry.UTC().Format(time.RFC3339)
signedResource := "b" signedResource := "b"
stringToSign, err := blobSASStringToSign(b.client.apiVersion, canonicalizedResource, signedExpiry, signedPermissions) protocols := "https,http"
if HTTPSOnly {
protocols = "https"
}
stringToSign, err := blobSASStringToSign(b.client.apiVersion, canonicalizedResource, signedExpiry, signedPermissions, signedIPRange, protocols)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -1110,6 +1451,13 @@ func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Tim
"sig": {sig}, "sig": {sig},
} }
if b.client.apiVersion >= "2015-04-05" {
sasParams.Add("spr", protocols)
if signedIPRange != "" {
sasParams.Add("sip", signedIPRange)
}
}
sasURL, err := url.Parse(blobURL) sasURL, err := url.Parse(blobURL)
if err != nil { if err != nil {
return "", err return "", err
@ -1118,16 +1466,89 @@ func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Tim
return sasURL.String(), nil return sasURL.String(), nil
} }
func blobSASStringToSign(signedVersion, canonicalizedResource, signedExpiry, signedPermissions string) (string, error) { // GetBlobSASURI creates a URL to the specified blob which contains the Shared
// Access Signature with specified permissions and expiration time.
//
// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx
func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Time, permissions string) (string, error) {
return b.GetBlobSASURIWithSignedIPAndProtocol(container, name, expiry, permissions, "", false)
}
func blobSASStringToSign(signedVersion, canonicalizedResource, signedExpiry, signedPermissions string, signedIP string, protocols string) (string, error) {
var signedStart, signedIdentifier, rscc, rscd, rsce, rscl, rsct string var signedStart, signedIdentifier, rscc, rscd, rsce, rscl, rsct string
if signedVersion >= "2015-02-21" { if signedVersion >= "2015-02-21" {
canonicalizedResource = "/blob" + canonicalizedResource canonicalizedResource = "/blob" + canonicalizedResource
} }
// https://msdn.microsoft.com/en-us/library/azure/dn140255.aspx#Anchor_12
if signedVersion >= "2015-04-05" {
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s", signedPermissions, signedStart, signedExpiry, canonicalizedResource, signedIdentifier, signedIP, protocols, signedVersion, rscc, rscd, rsce, rscl, rsct), nil
}
// reference: http://msdn.microsoft.com/en-us/library/azure/dn140255.aspx // reference: http://msdn.microsoft.com/en-us/library/azure/dn140255.aspx
if signedVersion >= "2013-08-15" { if signedVersion >= "2013-08-15" {
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s", signedPermissions, signedStart, signedExpiry, canonicalizedResource, signedIdentifier, signedVersion, rscc, rscd, rsce, rscl, rsct), nil return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s", signedPermissions, signedStart, signedExpiry, canonicalizedResource, signedIdentifier, signedVersion, rscc, rscd, rsce, rscl, rsct), nil
} }
return "", errors.New("storage: not implemented SAS for versions earlier than 2013-08-15") return "", errors.New("storage: not implemented SAS for versions earlier than 2013-08-15")
} }
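The version branching above boils down to two fixed newline-joined field layouts: 13 fields from API version 2015-04-05 onward (inserting signedIP and protocol between the identifier and the version), 11 fields for 2013-08-15 and later. A standalone sketch of just that layout, with field order taken from the hunk rather than from an authoritative spec reading:

```go
// Sketch of the version-dependent SAS string-to-sign layout.
package main

import (
	"fmt"
	"strings"
)

// sasStringToSign joins the SAS fields with newlines. Unused fields
// (signedStart, signedIdentifier, rscc..rsct) are left empty here,
// mirroring the zero-valued variables in the original.
func sasStringToSign(version, resource, expiry, perms, ip, protocols string) (string, error) {
	if version >= "2015-04-05" {
		// 13 fields: perms, start, expiry, resource, identifier, IP, protocol, version, rscc..rsct
		return strings.Join([]string{perms, "", expiry, resource, "", ip, protocols, version, "", "", "", "", ""}, "\n"), nil
	}
	if version >= "2013-08-15" {
		// 11 fields: same layout minus the IP and protocol fields
		return strings.Join([]string{perms, "", expiry, resource, "", version, "", "", "", "", ""}, "\n"), nil
	}
	return "", fmt.Errorf("storage: not implemented SAS for versions earlier than 2013-08-15")
}

// newlineCount counts field separators, so fields = newlineCount + 1.
func newlineCount(s string) int {
	n := 0
	for _, r := range s {
		if r == '\n' {
			n++
		}
	}
	return n
}

func main() {
	s, _ := sasStringToSign("2015-04-05", "/blob/acct/c/b", "2016-12-01T00:00:00Z", "r", "", "https")
	fmt.Println(newlineCount(s) + 1) // number of fields in the new layout
}
```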
func generatePermissions(accessPolicy AccessPolicyDetails) (permissions string) {
// generate the permissions string (rwd).
// still want the end user API to have bool flags.
permissions = ""
if accessPolicy.CanRead {
permissions += "r"
}
if accessPolicy.CanWrite {
permissions += "w"
}
if accessPolicy.CanDelete {
permissions += "d"
}
return permissions
}
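The permission builder above concatenates one letter per granted capability in a fixed read/write/delete order. A minimal self-contained sketch, using plain bool arguments in place of the AccessPolicyDetails struct (which is defined outside this hunk):

```go
// Minimal sketch of the "rwd" permission-string builder.
package main

import "fmt"

// buildPermissions appends one letter per granted capability,
// in the fixed order: read, write, delete.
func buildPermissions(canRead, canWrite, canDelete bool) string {
	p := ""
	if canRead {
		p += "r"
	}
	if canWrite {
		p += "w"
	}
	if canDelete {
		p += "d"
	}
	return p
}

func main() {
	fmt.Println(buildPermissions(true, false, true)) // read + delete
}
```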
// convertAccessPolicyToXMLStructs converts an AccessPolicyDetails struct, which is better suited
// for API usage, to the AccessPolicy struct which will get converted to XML.
func convertAccessPolicyToXMLStructs(accessPolicy AccessPolicyDetails) SignedIdentifiers {
return SignedIdentifiers{
SignedIdentifiers: []SignedIdentifier{
{
ID: accessPolicy.ID,
AccessPolicy: AccessPolicyDetailsXML{
StartTime: accessPolicy.StartTime.UTC().Round(time.Second),
ExpiryTime: accessPolicy.ExpiryTime.UTC().Round(time.Second),
Permission: generatePermissions(accessPolicy),
},
},
},
}
}
// generateAccessPolicy generates the XML access policy used as the payload for SetContainerPermissions.
func generateAccessPolicy(accessPolicy AccessPolicyDetails) (accessPolicyXML string, err error) {
if accessPolicy.ID != "" {
signedIdentifiers := convertAccessPolicyToXMLStructs(accessPolicy)
body, _, err := xmlMarshal(signedIdentifiers)
if err != nil {
return "", err
}
xmlByteArray, err := ioutil.ReadAll(body)
if err != nil {
return "", err
}
accessPolicyXML = string(xmlByteArray)
return accessPolicyXML, nil
}
return "", nil
}
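The struct definitions and XML tags behind `xmlMarshal(signedIdentifiers)` are outside this hunk, so the following is only a guessed shape for the SignedIdentifiers payload, shown to illustrate the marshal-then-read step (the real element and field names may differ):

```go
// Guessed sketch of the access-policy XML payload generation.
package main

import (
	"encoding/xml"
	"fmt"
)

type accessPolicyXML struct {
	StartTime  string `xml:"Start"`      // tag names assumed
	ExpiryTime string `xml:"Expiry"`
	Permission string `xml:"Permission"`
}

type signedIdentifier struct {
	ID           string          `xml:"Id"`
	AccessPolicy accessPolicyXML `xml:"AccessPolicy"`
}

type signedIdentifiers struct {
	XMLName           xml.Name           `xml:"SignedIdentifiers"`
	SignedIdentifiers []signedIdentifier `xml:"SignedIdentifier"`
}

// marshalPolicy mirrors generateAccessPolicy: an empty ID yields an
// empty payload rather than an error.
func marshalPolicy(id, start, expiry, perms string) (string, error) {
	if id == "" {
		return "", nil
	}
	out, err := xml.Marshal(signedIdentifiers{
		SignedIdentifiers: []signedIdentifier{{ID: id, AccessPolicy: accessPolicyXML{start, expiry, perms}}},
	})
	return string(out), err
}

func main() {
	s, _ := marshalPolicy("pol1", "2016-11-01T00:00:00Z", "2016-12-01T00:00:00Z", "rw")
	fmt.Println(s)
}
```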


@ -305,7 +305,7 @@ func (c Client) buildCanonicalizedResourceTable(uri string) (string, error) {
cr := "/" + c.getCanonicalizedAccountName() cr := "/" + c.getCanonicalizedAccountName()
if len(u.Path) > 0 { if len(u.Path) > 0 {
cr += u.Path cr += u.EscapedPath()
} }
return cr, nil return cr, nil


@ -82,6 +82,24 @@ func (p PeekMessagesParameters) getParameters() url.Values {
return out return out
} }
// UpdateMessageParameters is the set of options that can be specified for the Update Message
// operation. A zero struct does not use any preferences for the request.
type UpdateMessageParameters struct {
PopReceipt string
VisibilityTimeout int
}
func (p UpdateMessageParameters) getParameters() url.Values {
out := url.Values{}
if p.PopReceipt != "" {
out.Set("popreceipt", p.PopReceipt)
}
if p.VisibilityTimeout != 0 {
out.Set("visibilitytimeout", strconv.Itoa(p.VisibilityTimeout))
}
return out
}
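The getParameters method above follows the pattern used throughout this client: zero-valued fields are simply left out of the query string. A standalone sketch of the same logic (type and method names here are illustrative, not the vendored ones):

```go
// Sketch of UpdateMessageParameters' query-string construction.
package main

import (
	"fmt"
	"net/url"
	"strconv"
)

type updateMessageParams struct {
	PopReceipt        string
	VisibilityTimeout int
}

// values encodes only the fields that were actually set.
func (p updateMessageParams) values() url.Values {
	out := url.Values{}
	if p.PopReceipt != "" {
		out.Set("popreceipt", p.PopReceipt)
	}
	if p.VisibilityTimeout != 0 {
		out.Set("visibilitytimeout", strconv.Itoa(p.VisibilityTimeout))
	}
	return out
}

func main() {
	fmt.Println(updateMessageParams{PopReceipt: "abc", VisibilityTimeout: 30}.values().Encode())
}
```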
// GetMessagesResponse represents a response returned from Get Messages // GetMessagesResponse represents a response returned from Get Messages
// operation. // operation.
type GetMessagesResponse struct { type GetMessagesResponse struct {
@ -304,3 +322,23 @@ func (c QueueServiceClient) DeleteMessage(queue, messageID, popReceipt string) e
defer resp.body.Close() defer resp.body.Close()
return checkRespCode(resp.statusCode, []int{http.StatusNoContent}) return checkRespCode(resp.statusCode, []int{http.StatusNoContent})
} }
// UpdateMessage operation updates the specified message.
//
// See https://msdn.microsoft.com/en-us/library/azure/hh452234.aspx
func (c QueueServiceClient) UpdateMessage(queue string, messageID string, message string, params UpdateMessageParameters) error {
uri := c.client.getEndpoint(queueServiceName, pathForMessage(queue, messageID), params.getParameters())
req := putMessageRequest{MessageText: message}
body, nn, err := xmlMarshal(req)
if err != nil {
return err
}
headers := c.client.getStandardHeaders()
headers["Content-Length"] = fmt.Sprintf("%d", nn)
resp, err := c.client.exec("PUT", uri, headers, body)
if err != nil {
return err
}
defer resp.body.Close()
return checkRespCode(resp.statusCode, []int{http.StatusNoContent})
}
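putMessageRequest is defined elsewhere in the package; assuming it marshals to a `<QueueMessage><MessageText>` wrapper (the usual Queue service message shape), the body and the byte count fed into the Content-Length header would be derived roughly like this:

```go
// Guessed sketch of the UpdateMessage request body serialization.
package main

import (
	"encoding/xml"
	"fmt"
)

type queueMessage struct {
	XMLName     xml.Name `xml:"QueueMessage"`
	MessageText string   `xml:"MessageText"`
}

// messageBody returns the serialized payload and its length; the length
// is what the client writes into the Content-Length header.
func messageBody(text string) (string, int, error) {
	b, err := xml.Marshal(queueMessage{MessageText: text})
	return string(b), len(b), err
}

func main() {
	body, n, _ := messageBody("hello")
	fmt.Println(body)
	fmt.Println(n)
}
```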


@ -31,8 +31,7 @@ import (
"strings" "strings"
) )
/*--------------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------------
*/
var ( var (
// ErrOutOfBounds - Index out of bounds. // ErrOutOfBounds - Index out of bounds.
@ -63,40 +62,30 @@ var (
ErrInvalidBuffer = errors.New("input buffer contained invalid JSON") ErrInvalidBuffer = errors.New("input buffer contained invalid JSON")
) )
/*--------------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------------
*/
/* // Container - an internal structure that holds a reference to the core interface map of the parsed
Container - an internal structure that holds a reference to the core interface map of the parsed // json. Use this container to move context.
json. Use this container to move context.
*/
type Container struct { type Container struct {
object interface{} object interface{}
} }
/* // Data - Return the contained data as an interface{}.
Data - Return the contained data as an interface{}.
*/
func (g *Container) Data() interface{} { func (g *Container) Data() interface{} {
return g.object return g.object
} }
/*--------------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------------
*/
/* // Path - Search for a value using dot notation.
Path - Search for a value using dot notation.
*/
func (g *Container) Path(path string) *Container { func (g *Container) Path(path string) *Container {
return g.Search(strings.Split(path, ".")...) return g.Search(strings.Split(path, ".")...)
} }
/* // Search - Attempt to find and return an object within the JSON structure by specifying the
Search - Attempt to find and return an object within the JSON structure by specifying the hierarchy // hierarchy of field names to locate the target. If the search encounters an array and has not
of field names to locate the target. If the search encounters an array and has not reached the end // reached the end target then it will iterate each object of the array for the target and return
target then it will iterate each object of the array for the target and return all of the results in // all of the results in a JSON array.
a JSON array.
*/
func (g *Container) Search(hierarchy ...string) *Container { func (g *Container) Search(hierarchy ...string) *Container {
var object interface{} var object interface{}
@ -124,31 +113,22 @@ func (g *Container) Search(hierarchy ...string) *Container {
return &Container{object} return &Container{object}
} }
/* // S - Shorthand method, does the same thing as Search.
S - Shorthand method, does the same thing as Search.
*/
func (g *Container) S(hierarchy ...string) *Container { func (g *Container) S(hierarchy ...string) *Container {
return g.Search(hierarchy...) return g.Search(hierarchy...)
} }
/* // Exists - Checks whether a path exists.
Exists - Checks whether a path exists.
*/
func (g *Container) Exists(hierarchy ...string) bool { func (g *Container) Exists(hierarchy ...string) bool {
return g.Search(hierarchy...).Data() != nil return g.Search(hierarchy...).Data() != nil
} }
/* // ExistsP - Checks whether a dot notation path exists.
ExistsP - Checks whether a dot notation path exists.
*/
func (g *Container) ExistsP(path string) bool { func (g *Container) ExistsP(path string) bool {
return g.Exists(strings.Split(path, ".")...) return g.Exists(strings.Split(path, ".")...)
} }
/* // Index - Attempt to find and return an object within a JSON array by index.
Index - Attempt to find and return an object with a JSON array by specifying the index of the
target.
*/
func (g *Container) Index(index int) *Container { func (g *Container) Index(index int) *Container {
if array, ok := g.Data().([]interface{}); ok { if array, ok := g.Data().([]interface{}); ok {
if index >= len(array) { if index >= len(array) {
@ -159,11 +139,9 @@ func (g *Container) Index(index int) *Container {
return &Container{nil} return &Container{nil}
} }
/* // Children - Return a slice of all the children of the array. This also works for objects, however,
Children - Return a slice of all the children of the array. This also works for objects, however, // the children returned for an object will NOT be in order and you lose the names of the returned
the children returned for an object will NOT be in order and you lose the names of the returned // objects this way.
objects this way.
*/
func (g *Container) Children() ([]*Container, error) { func (g *Container) Children() ([]*Container, error) {
if array, ok := g.Data().([]interface{}); ok { if array, ok := g.Data().([]interface{}); ok {
children := make([]*Container, len(array)) children := make([]*Container, len(array))
@ -182,9 +160,7 @@ func (g *Container) Children() ([]*Container, error) {
return nil, ErrNotObjOrArray return nil, ErrNotObjOrArray
} }
/* // ChildrenMap - Return a map of all the children of an object.
ChildrenMap - Return a map of all the children of an object.
*/
func (g *Container) ChildrenMap() (map[string]*Container, error) { func (g *Container) ChildrenMap() (map[string]*Container, error) {
if mmap, ok := g.Data().(map[string]interface{}); ok { if mmap, ok := g.Data().(map[string]interface{}); ok {
children := map[string]*Container{} children := map[string]*Container{}
@ -196,14 +172,11 @@ func (g *Container) ChildrenMap() (map[string]*Container, error) {
return nil, ErrNotObj return nil, ErrNotObj
} }
/*--------------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------------
*/
/* // Set - Set the value of a field at a JSON path, any parts of the path that do not exist will be
Set - Set the value of a field at a JSON path, any parts of the path that do not exist will be // constructed, and if a collision occurs with a non object type whilst iterating the path an error
constructed, and if a collision occurs with a non object type whilst iterating the path an error is // is returned.
returned.
*/
func (g *Container) Set(value interface{}, path ...string) (*Container, error) { func (g *Container) Set(value interface{}, path ...string) (*Container, error) {
if len(path) == 0 { if len(path) == 0 {
g.object = value g.object = value
@ -229,16 +202,12 @@ func (g *Container) Set(value interface{}, path ...string) (*Container, error) {
return &Container{object}, nil return &Container{object}, nil
} }
/* // SetP - Does the same as Set, but using a dot notation JSON path.
SetP - Does the same as Set, but using a dot notation JSON path.
*/
func (g *Container) SetP(value interface{}, path string) (*Container, error) { func (g *Container) SetP(value interface{}, path string) (*Container, error) {
return g.Set(value, strings.Split(path, ".")...) return g.Set(value, strings.Split(path, ".")...)
} }
/* // SetIndex - Set a value of an array element based on the index.
SetIndex - Set a value of an array element based on the index.
*/
func (g *Container) SetIndex(value interface{}, index int) (*Container, error) { func (g *Container) SetIndex(value interface{}, index int) (*Container, error) {
if array, ok := g.Data().([]interface{}); ok { if array, ok := g.Data().([]interface{}); ok {
if index >= len(array) { if index >= len(array) {
@ -250,80 +219,60 @@ func (g *Container) SetIndex(value interface{}, index int) (*Container, error) {
return &Container{nil}, ErrNotArray return &Container{nil}, ErrNotArray
} }
/* // Object - Create a new JSON object at a path. Returns an error if the path contains a collision
Object - Create a new JSON object at a path. Returns an error if the path contains a collision with // with a non object type.
a non object type.
*/
func (g *Container) Object(path ...string) (*Container, error) { func (g *Container) Object(path ...string) (*Container, error) {
return g.Set(map[string]interface{}{}, path...) return g.Set(map[string]interface{}{}, path...)
} }
/* // ObjectP - Does the same as Object, but using a dot notation JSON path.
ObjectP - Does the same as Object, but using a dot notation JSON path.
*/
func (g *Container) ObjectP(path string) (*Container, error) { func (g *Container) ObjectP(path string) (*Container, error) {
return g.Object(strings.Split(path, ".")...) return g.Object(strings.Split(path, ".")...)
} }
/* // ObjectI - Create a new JSON object at an array index. Returns an error if the object is not an
ObjectI - Create a new JSON object at an array index. Returns an error if the object is not an array // array or the index is out of bounds.
or the index is out of bounds.
*/
func (g *Container) ObjectI(index int) (*Container, error) { func (g *Container) ObjectI(index int) (*Container, error) {
return g.SetIndex(map[string]interface{}{}, index) return g.SetIndex(map[string]interface{}{}, index)
} }
/* // Array - Create a new JSON array at a path. Returns an error if the path contains a collision with
Array - Create a new JSON array at a path. Returns an error if the path contains a collision with // a non object type.
a non object type.
*/
func (g *Container) Array(path ...string) (*Container, error) { func (g *Container) Array(path ...string) (*Container, error) {
return g.Set([]interface{}{}, path...) return g.Set([]interface{}{}, path...)
} }
/* // ArrayP - Does the same as Array, but using a dot notation JSON path.
ArrayP - Does the same as Array, but using a dot notation JSON path.
*/
func (g *Container) ArrayP(path string) (*Container, error) { func (g *Container) ArrayP(path string) (*Container, error) {
return g.Array(strings.Split(path, ".")...) return g.Array(strings.Split(path, ".")...)
} }
/* // ArrayI - Create a new JSON array at an array index. Returns an error if the object is not an
ArrayI - Create a new JSON array at an array index. Returns an error if the object is not an array // array or the index is out of bounds.
or the index is out of bounds.
*/
func (g *Container) ArrayI(index int) (*Container, error) { func (g *Container) ArrayI(index int) (*Container, error) {
return g.SetIndex([]interface{}{}, index) return g.SetIndex([]interface{}{}, index)
} }
/* // ArrayOfSize - Create a new JSON array of a particular size at a path. Returns an error if the
ArrayOfSize - Create a new JSON array of a particular size at a path. Returns an error if the path // path contains a collision with a non object type.
contains a collision with a non object type.
*/
func (g *Container) ArrayOfSize(size int, path ...string) (*Container, error) { func (g *Container) ArrayOfSize(size int, path ...string) (*Container, error) {
a := make([]interface{}, size) a := make([]interface{}, size)
return g.Set(a, path...) return g.Set(a, path...)
} }
/* // ArrayOfSizeP - Does the same as ArrayOfSize, but using a dot notation JSON path.
ArrayOfSizeP - Does the same as ArrayOfSize, but using a dot notation JSON path.
*/
func (g *Container) ArrayOfSizeP(size int, path string) (*Container, error) { func (g *Container) ArrayOfSizeP(size int, path string) (*Container, error) {
return g.ArrayOfSize(size, strings.Split(path, ".")...) return g.ArrayOfSize(size, strings.Split(path, ".")...)
} }
/* // ArrayOfSizeI - Create a new JSON array of a particular size at an array index. Returns an error
ArrayOfSizeI - Create a new JSON array of a particular size at an array index. Returns an error if // if the object is not an array or the index is out of bounds.
the object is not an array or the index is out of bounds.
*/
func (g *Container) ArrayOfSizeI(size, index int) (*Container, error) { func (g *Container) ArrayOfSizeI(size, index int) (*Container, error) {
a := make([]interface{}, size) a := make([]interface{}, size)
return g.SetIndex(a, index) return g.SetIndex(a, index)
} }
/* // Delete - Delete an element at a JSON path, an error is returned if the element does not exist.
Delete - Delete an element at a JSON path, an error is returned if the element does not exist.
*/
func (g *Container) Delete(path ...string) error { func (g *Container) Delete(path ...string) error {
var object interface{} var object interface{}
@ -346,24 +295,19 @@ func (g *Container) Delete(path ...string) error {
return nil return nil
} }
/* // DeleteP - Does the same as Delete, but using a dot notation JSON path.
DeleteP - Does the same as Delete, but using a dot notation JSON path.
*/
func (g *Container) DeleteP(path string) error { func (g *Container) DeleteP(path string) error {
return g.Delete(strings.Split(path, ".")...) return g.Delete(strings.Split(path, ".")...)
} }
/*--------------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------------
*/
/* /*
Array modification/search - Keeping these options simple right now, no need for anything more Array modification/search - Keeping these options simple right now, no need for anything more
complicated since you can just cast to []interface{}, modify and then reassign with Set. complicated since you can just cast to []interface{}, modify and then reassign with Set.
*/ */
/* // ArrayAppend - Append a value onto a JSON array.
ArrayAppend - Append a value onto a JSON array.
*/
func (g *Container) ArrayAppend(value interface{}, path ...string) error { func (g *Container) ArrayAppend(value interface{}, path ...string) error {
array, ok := g.Search(path...).Data().([]interface{}) array, ok := g.Search(path...).Data().([]interface{})
if !ok { if !ok {
@ -374,16 +318,12 @@ func (g *Container) ArrayAppend(value interface{}, path ...string) error {
return err return err
} }
/* // ArrayAppendP - Append a value onto a JSON array using a dot notation JSON path.
ArrayAppendP - Append a value onto a JSON array using a dot notation JSON path.
*/
func (g *Container) ArrayAppendP(value interface{}, path string) error { func (g *Container) ArrayAppendP(value interface{}, path string) error {
return g.ArrayAppend(value, strings.Split(path, ".")...) return g.ArrayAppend(value, strings.Split(path, ".")...)
} }
/* // ArrayRemove - Remove an element from a JSON array.
ArrayRemove - Remove an element from a JSON array.
*/
func (g *Container) ArrayRemove(index int, path ...string) error { func (g *Container) ArrayRemove(index int, path ...string) error {
if index < 0 { if index < 0 {
return ErrOutOfBounds return ErrOutOfBounds
@ -401,16 +341,12 @@ func (g *Container) ArrayRemove(index int, path ...string) error {
return err return err
} }
/* // ArrayRemoveP - Remove an element from a JSON array using a dot notation JSON path.
ArrayRemoveP - Remove an element from a JSON array using a dot notation JSON path.
*/
func (g *Container) ArrayRemoveP(index int, path string) error { func (g *Container) ArrayRemoveP(index int, path string) error {
return g.ArrayRemove(index, strings.Split(path, ".")...) return g.ArrayRemove(index, strings.Split(path, ".")...)
} }
/* // ArrayElement - Access an element from a JSON array.
ArrayElement - Access an element from a JSON array.
*/
func (g *Container) ArrayElement(index int, path ...string) (*Container, error) { func (g *Container) ArrayElement(index int, path ...string) (*Container, error) {
if index < 0 { if index < 0 {
return &Container{nil}, ErrOutOfBounds return &Container{nil}, ErrOutOfBounds
@ -425,16 +361,12 @@ func (g *Container) ArrayElement(index int, path ...string) (*Container, error)
return &Container{nil}, ErrOutOfBounds return &Container{nil}, ErrOutOfBounds
} }
/* // ArrayElementP - Access an element from a JSON array using a dot notation JSON path.
ArrayElementP - Access an element from a JSON array using a dot notation JSON path.
*/
func (g *Container) ArrayElementP(index int, path string) (*Container, error) { func (g *Container) ArrayElementP(index int, path string) (*Container, error) {
return g.ArrayElement(index, strings.Split(path, ".")...) return g.ArrayElement(index, strings.Split(path, ".")...)
} }
/* // ArrayCount - Count the number of elements in a JSON array.
ArrayCount - Count the number of elements in a JSON array.
*/
func (g *Container) ArrayCount(path ...string) (int, error) { func (g *Container) ArrayCount(path ...string) (int, error) {
if array, ok := g.Search(path...).Data().([]interface{}); ok { if array, ok := g.Search(path...).Data().([]interface{}); ok {
return len(array), nil return len(array), nil
@ -442,19 +374,14 @@ func (g *Container) ArrayCount(path ...string) (int, error) {
return 0, ErrNotArray return 0, ErrNotArray
} }
/* // ArrayCountP - Count the number of elements in a JSON array using a dot notation JSON path.
ArrayCountP - Count the number of elements in a JSON array using a dot notation JSON path.
*/
func (g *Container) ArrayCountP(path string) (int, error) { func (g *Container) ArrayCountP(path string) (int, error) {
return g.ArrayCount(strings.Split(path, ".")...) return g.ArrayCount(strings.Split(path, ".")...)
} }
/*--------------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------------
*/
/* // Bytes - Converts the contained object back to a JSON []byte blob.
Bytes - Converts the contained object back to a JSON []byte blob.
*/
func (g *Container) Bytes() []byte { func (g *Container) Bytes() []byte {
if g.object != nil { if g.object != nil {
if bytes, err := json.Marshal(g.object); err == nil { if bytes, err := json.Marshal(g.object); err == nil {
@ -464,9 +391,7 @@ func (g *Container) Bytes() []byte {
return []byte("{}") return []byte("{}")
} }
/* // BytesIndent - Converts the contained object to a JSON []byte blob formatted with prefix, indent.
BytesIndent - Converts the contained object back to a JSON []byte blob formatted with prefix and indent.
*/
func (g *Container) BytesIndent(prefix string, indent string) []byte { func (g *Container) BytesIndent(prefix string, indent string) []byte {
if g.object != nil { if g.object != nil {
if bytes, err := json.MarshalIndent(g.object, prefix, indent); err == nil { if bytes, err := json.MarshalIndent(g.object, prefix, indent); err == nil {
@ -476,37 +401,27 @@ func (g *Container) BytesIndent(prefix string, indent string) []byte {
return []byte("{}") return []byte("{}")
} }
// String - Converts the contained object to a JSON formatted string.
func (g *Container) String() string {
	return string(g.Bytes())
}

// StringIndent - Converts the contained object back to a JSON formatted string with prefix, indent.
func (g *Container) StringIndent(prefix string, indent string) string {
	return string(g.BytesIndent(prefix, indent))
}

// New - Create a new gabs JSON object.
func New() *Container {
	return &Container{map[string]interface{}{}}
}

// Consume - Gobble up an already converted JSON object, or a fresh map[string]interface{} object.
func Consume(root interface{}) (*Container, error) {
	return &Container{root}, nil
}

// ParseJSON - Convert a string into a representation of the parsed JSON.
func ParseJSON(sample []byte) (*Container, error) {
	var gabs Container
@ -517,9 +432,7 @@ func ParseJSON(sample []byte) (*Container, error) {
	return &gabs, nil
}

// ParseJSONDecoder - Convert a json.Decoder into a representation of the parsed JSON.
func ParseJSONDecoder(decoder *json.Decoder) (*Container, error) {
	var gabs Container
@ -530,9 +443,7 @@ func ParseJSONDecoder(decoder *json.Decoder) (*Container, error) {
	return &gabs, nil
}

// ParseJSONFile - Read a file and convert into a representation of the parsed JSON.
func ParseJSONFile(path string) (*Container, error) {
	if len(path) > 0 {
		cBytes, err := ioutil.ReadFile(path)
@ -550,9 +461,7 @@ func ParseJSONFile(path string) (*Container, error) {
	return nil, ErrInvalidPath
}

// ParseJSONBuffer - Read the contents of a buffer into a representation of the parsed JSON.
func ParseJSONBuffer(buffer io.Reader) (*Container, error) {
	var gabs Container
	jsonDecoder := json.NewDecoder(buffer)
@ -563,83 +472,4 @@ func ParseJSONBuffer(buffer io.Reader) (*Container, error) {
	return &gabs, nil
}

//--------------------------------------------------------------------------------------------------
// DEPRECATED METHODS
/*
Push - DEPRECATED: Push a value onto a JSON array.
*/
func (g *Container) Push(target string, value interface{}) error {
if mmap, ok := g.Data().(map[string]interface{}); ok {
arrayTarget := mmap[target]
if array, ok := arrayTarget.([]interface{}); ok {
mmap[target] = append(array, value)
} else {
return ErrNotArray
}
} else {
return ErrNotObj
}
return nil
}
/*
RemoveElement - DEPRECATED: Remove a value from a JSON array.
*/
func (g *Container) RemoveElement(target string, index int) error {
if index < 0 {
return ErrOutOfBounds
}
if mmap, ok := g.Data().(map[string]interface{}); ok {
arrayTarget := mmap[target]
if array, ok := arrayTarget.([]interface{}); ok {
if index < len(array) {
mmap[target] = append(array[:index], array[index+1:]...)
} else {
return ErrOutOfBounds
}
} else {
return ErrNotArray
}
} else {
return ErrNotObj
}
return nil
}
/*
GetElement - DEPRECATED: Get the desired element from a JSON array
*/
func (g *Container) GetElement(target string, index int) *Container {
if index < 0 {
return &Container{nil}
}
if mmap, ok := g.Data().(map[string]interface{}); ok {
arrayTarget := mmap[target]
if array, ok := arrayTarget.([]interface{}); ok {
if index < len(array) {
return &Container{array[index]}
}
}
}
return &Container{nil}
}
/*
CountElements - DEPRECATED: Count the elements of a JSON array, returns -1 if the target is not an
array
*/
func (g *Container) CountElements(target string) int {
if mmap, ok := g.Data().(map[string]interface{}); ok {
arrayTarget := mmap[target]
if array, ok := arrayTarget.([]interface{}); ok {
return len(array)
}
}
return -1
}
/*---------------------------------------------------------------------------------------------------
*/

View file

@ -28,39 +28,42 @@ Examples
Here is an example of using the package:
```go
func SlowMethod() {
	// Profiling the runtime of a method
	defer metrics.MeasureSince([]string{"SlowMethod"}, time.Now())
}

// Configure a statsite sink as the global metrics sink
sink, _ := metrics.NewStatsiteSink("statsite:8125")
metrics.NewGlobal(metrics.DefaultConfig("service-name"), sink)

// Emit a Key/Value pair
metrics.EmitKey([]string{"questions", "meaning of life"}, 42)
```
Here is an example of setting up a signal handler:
```go
// Setup the inmem sink and signal handler
inm := metrics.NewInmemSink(10*time.Second, time.Minute)
sig := metrics.DefaultInmemSignal(inm)
metrics.NewGlobal(metrics.DefaultConfig("service-name"), inm)

// Run some code
inm.SetGauge([]string{"foo"}, 42)
inm.EmitKey([]string{"bar"}, 30)

inm.IncrCounter([]string{"baz"}, 42)
inm.IncrCounter([]string{"baz"}, 1)
inm.IncrCounter([]string{"baz"}, 80)

inm.AddSample([]string{"method", "wow"}, 42)
inm.AddSample([]string{"method", "wow"}, 100)
inm.AddSample([]string{"method", "wow"}, 22)

....
```
When a signal comes in, output like the following will be dumped to stderr:

View file

@ -30,7 +30,15 @@ const (
Latitude string = "^[-+]?([1-8]?\\d(\\.\\d+)?|90(\\.0+)?)$" Latitude string = "^[-+]?([1-8]?\\d(\\.\\d+)?|90(\\.0+)?)$"
Longitude string = "^[-+]?(180(\\.0+)?|((1[0-7]\\d)|([1-9]?\\d))(\\.\\d+)?)$" Longitude string = "^[-+]?(180(\\.0+)?|((1[0-7]\\d)|([1-9]?\\d))(\\.\\d+)?)$"
DNSName string = `^([a-zA-Z0-9]{1}[a-zA-Z0-9_-]{1,62}){1}(\.[a-zA-Z0-9]{1}[a-zA-Z0-9_-]{1,62})*$` DNSName string = `^([a-zA-Z0-9]{1}[a-zA-Z0-9_-]{1,62}){1}(\.[a-zA-Z0-9]{1}[a-zA-Z0-9_-]{1,62})*$`
URL string = `^((ftp|https?):\/\/)?(\S+(:\S*)?@)?((([1-9]\d?|1\d\d|2[01]\d|22[0-3])(\.(1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.([0-9]\d?|1\d\d|2[0-4]\d|25[0-4]))|(([a-zA-Z0-9]([a-zA-Z0-9-]+)?[a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*)|((www\.)?))?(([a-zA-Z\x{00a1}-\x{ffff}0-9]+-?-?)*[a-zA-Z\x{00a1}-\x{ffff}0-9]+)(?:\.([a-zA-Z\x{00a1}-\x{ffff}]{1,}))?))(:(\d{1,5}))?((\/|\?|#)[^\s]*)?$` IP string = `(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))`
URLSchema string = `((ftp|tcp|udp|wss?|https?):\/\/)`
URLUsername string = `(\S+(:\S*)?@)`
Hostname string = ``
URLPath string = `((\/|\?|#)[^\s]*)`
URLPort string = `(:(\d{1,5}))`
URLIP string = `([1-9]\d?|1\d\d|2[01]\d|22[0-3])(\.(1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.([0-9]\d?|1\d\d|2[0-4]\d|25[0-4]))`
URLSubdomain string = `((www\.)|([a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*))`
URL string = `^` + URLSchema + `?` + URLUsername + `?` + `((` + URLIP + `|(\[` + IP + `\])|(([a-zA-Z0-9]([a-zA-Z0-9-]+)?[a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*)|(` + URLSubdomain + `?))?(([a-zA-Z\x{00a1}-\x{ffff}0-9]+-?-?)*[a-zA-Z\x{00a1}-\x{ffff}0-9]+)(?:\.([a-zA-Z\x{00a1}-\x{ffff}]{1,}))?))` + URLPort + `?` + URLPath + `?$`
SSN string = `^\d{3}[- ]?\d{2}[- ]?\d{4}$` SSN string = `^\d{3}[- ]?\d{2}[- ]?\d{4}$`
WinPath string = `^[a-zA-Z]:\\(?:[^\\/:*?"<>|\r\n]+\\)*[^\\/:*?"<>|\r\n]*$` WinPath string = `^[a-zA-Z]:\\(?:[^\\/:*?"<>|\r\n]+\\)*[^\\/:*?"<>|\r\n]*$`
UnixPath string = `^((?:\/[a-zA-Z0-9\.\:]+(?:_[a-zA-Z0-9\:\.]+)*(?:\-[\:a-zA-Z0-9\.]+)*)+\/?)$` UnixPath string = `^((?:\/[a-zA-Z0-9\.\:]+(?:_[a-zA-Z0-9\:\.]+)*(?:\-[\:a-zA-Z0-9\.]+)*)+\/?)$`

View file

@ -496,6 +496,12 @@ func IsIPv6(str string) bool {
return ip != nil && strings.Contains(str, ":") return ip != nil && strings.Contains(str, ":")
} }
// IsCIDR checks if the string is a valid CIDR notation (IPv4 & IPv6)
func IsCIDR(str string) bool {
_, _, err := net.ParseCIDR(str)
return err == nil
}
// IsMAC checks if a string is a valid MAC address.
// Possible MAC formats:
// 01:23:45:67:89:ab
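The new IsCIDR validator is a thin wrapper over the standard library's net.ParseCIDR. A minimal sketch of the same check; `isCIDR` is a local stand-in, not the govalidator function itself:

```go
package main

import (
	"fmt"
	"net"
)

// isCIDR reports whether s parses as CIDR notation, matching the
// behavior of govalidator's new IsCIDR (IPv4 and IPv6 alike).
func isCIDR(s string) bool {
	_, _, err := net.ParseCIDR(s)
	return err == nil
}

func main() {
	fmt.Println(isCIDR("192.0.2.0/24"))  // true
	fmt.Println(isCIDR("2001:db8::/32")) // true
	fmt.Println(isCIDR("192.0.2.0/33"))  // false: prefix too long for IPv4
}
```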

View file

@ -2,7 +2,6 @@ package client
import ( import (
"fmt" "fmt"
"io/ioutil"
"net/http/httputil" "net/http/httputil"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -104,8 +103,7 @@ func logRequest(r *request.Request) {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's // Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP // Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader. // client reader.
r.Body.Seek(r.BodyStart, 0) r.ResetBody()
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
} }
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody))) r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))

View file

@ -137,9 +137,6 @@ type Config struct {
// accelerate enabled. If the bucket is not enabled for accelerate an error // accelerate enabled. If the bucket is not enabled for accelerate an error
// will be returned. The bucket name must be DNS compatible to also work // will be returned. The bucket name must be DNS compatible to also work
// with accelerate. // with accelerate.
//
// Not compatible with UseDualStack requests will fail if both flags are
// specified.
S3UseAccelerate *bool S3UseAccelerate *bool
// Set this to `true` to disable the EC2Metadata client from overriding the // Set this to `true` to disable the EC2Metadata client from overriding the
@ -185,6 +182,19 @@ type Config struct {
// the delay of a request see the aws/client.DefaultRetryer and // the delay of a request see the aws/client.DefaultRetryer and
// aws/request.Retryer. // aws/request.Retryer.
SleepDelay func(time.Duration) SleepDelay func(time.Duration)
// DisableRestProtocolURICleaning will not clean the URL path when making rest protocol requests.
// Will default to false. This would only be used for empty directory names in s3 requests.
//
// Example:
// sess, err := session.NewSession(&aws.Config{DisableRestProtocolURICleaning: aws.Bool(true))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("//foo//bar//moo"),
// })
DisableRestProtocolURICleaning *bool
} }
// NewConfig returns a new Config pointer that can be chained with builder // NewConfig returns a new Config pointer that can be chained with builder
@ -406,6 +416,10 @@ func mergeInConfig(dst *Config, other *Config) {
if other.SleepDelay != nil { if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay dst.SleepDelay = other.SleepDelay
} }
if other.DisableRestProtocolURICleaning != nil {
dst.DisableRestProtocolURICleaning = other.DisableRestProtocolURICleaning
}
} }
// Copy will return a shallow copy of the Config object. If any additional // Copy will return a shallow copy of the Config object. If any additional

View file

@ -10,9 +10,11 @@ import (
"regexp" "regexp"
"runtime" "runtime"
"strconv" "strconv"
"time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
) )
@ -67,6 +69,34 @@ var SDKVersionUserAgentHandler = request.NamedHandler{
var reStatusCode = regexp.MustCompile(`^(\d{3})`) var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// ValidateReqSigHandler is a request handler to ensure that the request's
// signature doesn't expire before it is sent. This can happen when a request
// is built and signed significantly before it is sent, or when significant
// delays occur while retrying requests that would cause the signature to expire.
var ValidateReqSigHandler = request.NamedHandler{
Name: "core.ValidateReqSigHandler",
Fn: func(r *request.Request) {
// Requests using anonymous credentials are never signed
if r.Config.Credentials == credentials.AnonymousCredentials {
return
}
signedTime := r.Time
if !r.LastSignedAt.IsZero() {
signedTime = r.LastSignedAt
}
// 10 minutes to allow for some clock skew/delays in transmission.
// Would be improved with aws/aws-sdk-go#423
if signedTime.Add(10 * time.Minute).After(time.Now()) {
return
}
fmt.Println("request expired, resigning")
r.Sign()
},
}
// SendHandler is a request handler to send service request using HTTP client. // SendHandler is a request handler to send service request using HTTP client.
var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *request.Request) { var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *request.Request) {
var err error var err error

View file

@ -34,7 +34,7 @@ var (
// //
// Example of ChainProvider to be used with an EnvProvider and EC2RoleProvider. // Example of ChainProvider to be used with an EnvProvider and EC2RoleProvider.
// In this example EnvProvider will first check if any credentials are available // In this example EnvProvider will first check if any credentials are available
// via the environment variables. If there are none ChainProvider will check
// the next Provider in the list, EC2RoleProvider in this case. If EC2RoleProvider // the next Provider in the list, EC2RoleProvider in this case. If EC2RoleProvider
// does not return any credentials ChainProvider will return the error // does not return any credentials ChainProvider will return the error
// ErrNoValidProvidersFoundInChain // ErrNoValidProvidersFoundInChain

View file

@ -111,7 +111,7 @@ func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
}, nil }, nil
} }
// An ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses. // request responses.
type ec2RoleCredRespBody struct { type ec2RoleCredRespBody struct {
// Success State // Success State

View file

@ -72,6 +72,7 @@ func Handlers() request.Handlers {
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler) handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
handlers.Build.AfterEachFn = request.HandlerListStopOnError handlers.Build.AfterEachFn = request.HandlerListStopOnError
handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler) handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
handlers.Send.PushBackNamed(corehandlers.ValidateReqSigHandler)
handlers.Send.PushBackNamed(corehandlers.SendHandler) handlers.Send.PushBackNamed(corehandlers.SendHandler)
handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler) handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler) handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler)

View file

@ -3,6 +3,7 @@ package ec2metadata
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"path" "path"
"strings" "strings"
"time" "time"
@ -27,6 +28,27 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
return output.Content, req.Send() return output.Content, req.Send()
} }
// GetUserData returns the userdata that was configured for the service. If
// there is no user-data set up for the EC2 instance a "NotFoundError" error
// code will be returned.
func (c *EC2Metadata) GetUserData() (string, error) {
op := &request.Operation{
Name: "GetUserData",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "user-data"),
}
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
req.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
if r.HTTPResponse.StatusCode == http.StatusNotFound {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})
return output.Content, req.Send()
}
// GetDynamicData uses the path provided to request information from the EC2 // GetDynamicData uses the path provided to request information from the EC2
// instance metadata service for dynamic data. The content will be returned // instance metadata service for dynamic data. The content will be returned
// as a string, or error if the request failed. // as a string, or error if the request failed.
@ -111,7 +133,7 @@ func (c *EC2Metadata) Available() bool {
return true return true
} }
// An EC2IAMInfo provides the shape for unmarshaling
// an IAM info from the metadata API // an IAM info from the metadata API
type EC2IAMInfo struct { type EC2IAMInfo struct {
Code string Code string
@ -120,7 +142,7 @@ type EC2IAMInfo struct {
InstanceProfileID string InstanceProfileID string
} }
// An EC2InstanceIdentityDocument provides the shape for unmarshaling
// an instance identity document // an instance identity document
type EC2InstanceIdentityDocument struct { type EC2InstanceIdentityDocument struct {
DevpayProductCodes []string `json:"devpayProductCodes"` DevpayProductCodes []string `json:"devpayProductCodes"`

View file

@ -9,7 +9,7 @@ import (
// with retrying requests // with retrying requests
type offsetReader struct { type offsetReader struct {
buf io.ReadSeeker buf io.ReadSeeker
lock sync.RWMutex lock sync.Mutex
closed bool closed bool
} }
@ -21,7 +21,8 @@ func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader {
return reader return reader
} }
// Close will close the instance of the offset reader's access to
// the underlying io.ReadSeeker.
func (o *offsetReader) Close() error { func (o *offsetReader) Close() error {
o.lock.Lock() o.lock.Lock()
defer o.lock.Unlock() defer o.lock.Unlock()
@ -29,10 +30,10 @@ func (o *offsetReader) Close() error {
return nil return nil
} }
// Read is a thread-safe read of the underlying io.ReadSeeker
func (o *offsetReader) Read(p []byte) (int, error) { func (o *offsetReader) Read(p []byte) (int, error) {
o.lock.RLock() o.lock.Lock()
defer o.lock.RUnlock() defer o.lock.Unlock()
if o.closed { if o.closed {
return 0, io.EOF return 0, io.EOF
@ -41,6 +42,14 @@ func (o *offsetReader) Read(p []byte) (int, error) {
return o.buf.Read(p) return o.buf.Read(p)
} }
// Seek is a thread-safe seeking operation.
func (o *offsetReader) Seek(offset int64, whence int) (int64, error) {
o.lock.Lock()
defer o.lock.Unlock()
return o.buf.Seek(offset, whence)
}
// CloseAndCopy will return a new offsetReader with a copy of the old buffer // CloseAndCopy will return a new offsetReader with a copy of the old buffer
// and close the old buffer. // and close the old buffer.
func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader { func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader {

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
@ -42,6 +41,12 @@ type Request struct {
LastSignedAt time.Time LastSignedAt time.Time
built bool built bool
// Need to persist an intermediate body between the input Body and HTTP
// request body because the HTTP Client's transport can maintain a reference
// to the HTTP request's body after the client has returned. This value is
// safe to use concurrently and rewraps the input Body for each HTTP request.
safeBody *offsetReader
} }
// An Operation is the service API operation to be made. // An Operation is the service API operation to be made.
@ -135,8 +140,8 @@ func (r *Request) SetStringBody(s string) {
// SetReaderBody will set the request's body reader. // SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) { func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.HTTPRequest.Body = newOffsetReader(reader, 0)
r.Body = reader r.Body = reader
r.ResetBody()
} }
// Presign returns the request's signed URL. Error will be returned // Presign returns the request's signed URL. Error will be returned
@ -220,6 +225,24 @@ func (r *Request) Sign() error {
return r.Error return r.Error
} }
// ResetBody rewinds the request body back to its starting position, and
// sets the HTTP Request body reference. When the body is read prior
// to being sent in the HTTP request it will need to be rewound.
func (r *Request) ResetBody() {
if r.safeBody != nil {
r.safeBody.Close()
}
r.safeBody = newOffsetReader(r.Body, r.BodyStart)
r.HTTPRequest.Body = r.safeBody
}
// GetBody will return an io.ReadSeeker of the Request's underlying
// input body with a concurrency safe wrapper.
func (r *Request) GetBody() io.ReadSeeker {
return r.safeBody
}
// Send will send the request returning error if errors are encountered. // Send will send the request returning error if errors are encountered.
// //
// Send will sign the request prior to sending. All Send Handlers will // Send will sign the request prior to sending. All Send Handlers will
@ -231,6 +254,8 @@ func (r *Request) Sign() error {
// //
// readLoop() and getConn(req *Request, cm connectMethod) // readLoop() and getConn(req *Request, cm connectMethod)
// https://github.com/golang/go/blob/master/src/net/http/transport.go // https://github.com/golang/go/blob/master/src/net/http/transport.go
//
// Send will not close the request.Request's body.
func (r *Request) Send() error { func (r *Request) Send() error {
for { for {
if aws.BoolValue(r.Retryable) { if aws.BoolValue(r.Retryable) {
@ -239,21 +264,15 @@ func (r *Request) Send() error {
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount)) r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
} }
var body io.ReadCloser // The previous http.Request will have a reference to the r.Body
if reader, ok := r.HTTPRequest.Body.(*offsetReader); ok { // and the HTTP Client's Transport may still be reading from
body = reader.CloseAndCopy(r.BodyStart) // the request's body even though the Client's Do returned.
} else { r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
if r.Config.Logger != nil { r.ResetBody()
r.Config.Logger.Log("Request body type has been overwritten. May cause race conditions")
}
r.Body.Seek(r.BodyStart, 0)
body = ioutil.NopCloser(r.Body)
}
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, body) // Closing response body to ensure that no response body is leaked
// between retry attempts.
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil { if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
// Closing response body. Since we are setting a new request to send off, this
// response will get squashed and leaked.
r.HTTPResponse.Body.Close() r.HTTPResponse.Body.Close()
} }
} }
@ -281,7 +300,6 @@ func (r *Request) Send() error {
debugLogReqError(r, "Send Request", true, err) debugLogReqError(r, "Send Request", true, err)
continue continue
} }
r.Handlers.UnmarshalMeta.Run(r) r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r) r.Handlers.ValidateResponse.Run(r)
if r.Error != nil { if r.Error != nil {

View file

@ -66,7 +66,7 @@ through code instead of being driven by environment variables only.
Use NewSessionWithOptions when you want to provide the config profile, or
override the shared config state (AWS_SDK_LOAD_CONFIG).
// Equivalent to session.NewSession()
sess, err := session.NewSessionWithOptions(session.Options{})
// Specify profile to load for the session's config // Specify profile to load for the session's config

View file

@ -2,7 +2,7 @@ package session
import ( import (
"fmt" "fmt"
"os" "io/ioutil"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
@ -105,12 +105,13 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
files := make([]sharedConfigFile, 0, len(filenames)) files := make([]sharedConfigFile, 0, len(filenames))
for _, filename := range filenames { for _, filename := range filenames {
if _, err := os.Stat(filename); os.IsNotExist(err) { b, err := ioutil.ReadFile(filename)
// Trim files from the list that don't exist. if err != nil {
// Skip files which can't be opened and read for whatever reason
continue continue
} }
f, err := ini.Load(filename) f, err := ini.Load(b)
if err != nil { if err != nil {
return nil, SharedConfigLoadError{Filename: filename} return nil, SharedConfigLoadError{Filename: filename}
} }

View file

@ -0,0 +1,24 @@
// +build go1.5
package v4
import (
"net/url"
"strings"
)
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.EscapedPath()
}
if len(uri) == 0 {
uri = "/"
}
return uri
}

View file

@ -0,0 +1,24 @@
// +build !go1.5
package v4
import (
"net/url"
"strings"
)
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.Path
}
if len(uri) == 0 {
uri = "/"
}
return uri
}

View file

@ -2,6 +2,48 @@
// //
// Provides request signing for request that need to be signed with // Provides request signing for request that need to be signed with
// AWS V4 Signatures. // AWS V4 Signatures.
//
// Standalone Signer
//
// Generally using the signer outside of the SDK should not require any additional
// logic when using Go v1.5 or higher. The signer does this by taking advantage
// of the URL.EscapedPath method. If your request URI requires additional escaping
// you may need to use the URL.Opaque to define what the raw URI should be sent
// to the service as.
//
// The signer will first check the URL.Opaque field, and use its value if set.
// The signer does require the URL.Opaque field to be set in the form of:
//
// "//<hostname>/<path>"
//
// // e.g.
// "//example.com/some/path"
//
// The leading "//" and hostname are required or the URL.Opaque escaping will
// not work correctly.
//
// If URL.Opaque is not set the signer will fall back to the URL.EscapedPath()
// method and use the returned value. If you're using Go v1.4 you must set
// URL.Opaque if the URI path needs escaping. If URL.Opaque is not set with
// Go v1.5 the signer will fall back to URL.Path.
//
// AWS v4 signature validation requires that the canonical string's URI path
// element must be the URI escaped form of the HTTP request's path.
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
//
// The Go HTTP client will perform escaping automatically on the request. Some
// of this escaping may cause signature validation errors because the HTTP
// request differs from the URI path or query from which the signature was generated.
// https://golang.org/pkg/net/url/#URL.EscapedPath
//
// Because of this, when using the signer outside of the SDK it is recommended
// to explicitly escape the request prior to signing, which will help
// prevent signature validation errors. This can be done by setting
// the URL.Opaque or URL.RawPath. The SDK will use URL.Opaque first and then
// call URL.EscapedPath() if Opaque is not set.
//
// Test `TestStandaloneSign` provides a complete example of using the signer
// outside of the SDK and pre-escaping the URI path.
package v4 package v4
import ( import (
@ -120,6 +162,15 @@ type Signer struct {
// request's query string. // request's query string.
DisableHeaderHoisting bool DisableHeaderHoisting bool
// Disables the automatic escaping of the URI path of the request for the
// signature's canonical string's path. For services that do not need additional
// escaping, use this to disable the signer escaping the path.
//
// S3 is an example of a service that does not need additional escaping.
//
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool
// currentTimeFn returns the time value which represents the current time. // currentTimeFn returns the time value which represents the current time.
// This value should only be used for testing. If it is nil the default // This value should only be used for testing. If it is nil the default
// time.Now will be used. // time.Now will be used.
@ -151,6 +202,8 @@ type signingCtx struct {
ExpireTime time.Duration ExpireTime time.Duration
SignedHeaderVals http.Header SignedHeaderVals http.Header
DisableURIPathEscaping bool
credValues credentials.Value credValues credentials.Value
isPresign bool isPresign bool
formattedTime string formattedTime string
@ -236,22 +289,18 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
} }
ctx := &signingCtx{ ctx := &signingCtx{
Request: r, Request: r,
Body: body, Body: body,
Query: r.URL.Query(), Query: r.URL.Query(),
Time: signTime, Time: signTime,
ExpireTime: exp, ExpireTime: exp,
isPresign: exp != 0, isPresign: exp != 0,
ServiceName: service, ServiceName: service,
Region: region, Region: region,
DisableURIPathEscaping: v4.DisableURIPathEscaping,
} }
if ctx.isRequestSigned() { if ctx.isRequestSigned() {
if !v4.Credentials.IsExpired() && currentTimeFn().Before(ctx.Time.Add(10*time.Minute)) {
// If the request is already signed, and the credentials have not
// expired, and the request is not too old ignore the signing request.
return ctx.SignedHeaderVals, nil
}
ctx.Time = currentTimeFn() ctx.Time = currentTimeFn()
ctx.handlePresignRemoval() ctx.handlePresignRemoval()
} }
@ -359,6 +408,10 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
v4.Logger = req.Config.Logger v4.Logger = req.Config.Logger
v4.DisableHeaderHoisting = req.NotHoist v4.DisableHeaderHoisting = req.NotHoist
v4.currentTimeFn = curTimeFn v4.currentTimeFn = curTimeFn
if name == "s3" {
// S3 service should not have any escaping applied
v4.DisableURIPathEscaping = true
}
}) })
signingTime := req.Time signingTime := req.Time
@ -366,7 +419,9 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
signingTime = req.LastSignedAt signingTime = req.LastSignedAt
} }
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.Body, name, region, req.ExpireTime, signingTime) signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
name, region, req.ExpireTime, signingTime,
)
if err != nil { if err != nil {
req.Error = err req.Error = err
req.SignedHeaderVals = nil req.SignedHeaderVals = nil
@ -512,18 +567,15 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
} }
func (ctx *signingCtx) buildCanonicalString() { func (ctx *signingCtx) buildCanonicalString() {
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1) query := ctx.Query
uri := ctx.Request.URL.Opaque for key := range query {
if uri != "" { sort.Strings(query[key])
uri = "/" + strings.Join(strings.Split(uri, "/")[3:], "/")
} else {
uri = ctx.Request.URL.Path
}
if uri == "" {
uri = "/"
} }
ctx.Request.URL.RawQuery = strings.Replace(query.Encode(), "+", "%20", -1)
if ctx.ServiceName != "s3" { uri := getURIPath(ctx.Request.URL)
if !ctx.DisableURIPathEscaping {
uri = rest.EscapePath(uri, false) uri = rest.EscapePath(uri, false)
} }

View file

@@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"

// SDKVersion is the version of this SDK
-const SDKVersion = "1.4.14"
+const SDKVersion = "1.5.6"


@@ -1,7 +1,7 @@
// Package endpoints validates regional endpoints for services.
package endpoints

-//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
+//go:generate go run -tags codegen ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate gofmt -s -w endpoints_map.go

import (


@@ -23,6 +23,10 @@
	"us-gov-west-1/ec2metadata": {
		"endpoint": "http://169.254.169.254/latest"
	},
+	"*/budgets": {
+		"endpoint": "budgets.amazonaws.com",
+		"signingRegion": "us-east-1"
+	},
	"*/cloudfront": {
		"endpoint": "cloudfront.amazonaws.com",
		"signingRegion": "us-east-1"


@@ -18,6 +18,10 @@ var endpointsMap = endpointStruct{
	"*/*": {
		Endpoint: "{service}.{region}.amazonaws.com",
	},
+	"*/budgets": {
+		Endpoint:      "budgets.amazonaws.com",
+		SigningRegion: "us-east-1",
+	},
	"*/cloudfront": {
		Endpoint:      "cloudfront.amazonaws.com",
		SigningRegion: "us-east-1",
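The entries above show the lookup pattern: an exact `region/service` key wins, a `*/service` wildcard handles global services (budgets, cloudfront) that live in one region regardless of the client's region, and `*/*` is the template fallback. A simplified sketch of that resolution order (an illustration only; `resolve` is a hypothetical helper, not the SDK's real lookup code):

```go
package main

import (
	"fmt"
	"strings"
)

// resolve tries the most specific key first, then the service
// wildcard, then the global "*/*" template, substituting
// {service} and {region} placeholders in whatever it finds.
func resolve(endpoints map[string]string, region, service string) string {
	for _, key := range []string{region + "/" + service, "*/" + service, "*/*"} {
		if tmpl, ok := endpoints[key]; ok {
			tmpl = strings.Replace(tmpl, "{service}", service, -1)
			return strings.Replace(tmpl, "{region}", region, -1)
		}
	}
	return ""
}

func main() {
	endpoints := map[string]string{
		"*/*":       "{service}.{region}.amazonaws.com",
		"*/budgets": "budgets.amazonaws.com",
	}
	fmt.Println(resolve(endpoints, "us-west-2", "ec2"))
	fmt.Println(resolve(endpoints, "eu-west-1", "budgets"))
}
```

This is why the budgets entry also carries `SigningRegion: "us-east-1"`: the request may originate from any configured region, but the signature must name the region the global endpoint actually lives in.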


@@ -1,7 +1,7 @@
// Package ec2query provides serialization of AWS EC2 requests and responses.
package ec2query

-//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/ec2.json build_test.go
+//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/ec2.json build_test.go

import (
	"net/url"


@@ -1,6 +1,6 @@
package ec2query

-//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/ec2.json unmarshal_test.go
+//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/ec2.json unmarshal_test.go

import (
	"encoding/xml"


@@ -2,8 +2,8 @@
// requests and responses.
package jsonrpc

-//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/json.json build_test.go
-//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go
+//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/json.json build_test.go
+//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go

import (
	"encoding/json"


@@ -1,7 +1,7 @@
// Package query provides serialization of AWS query requests, and responses.
package query

-//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go
+//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go

import (
	"net/url"


@@ -1,6 +1,6 @@
package query

-//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go
+//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go

import (
	"encoding/xml"


@@ -14,6 +14,7 @@ import (
	"strings"
	"time"

+	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
)
@@ -92,7 +93,7 @@ func buildLocationElements(r *request.Request, v reflect.Value) {
	}

	r.HTTPRequest.URL.RawQuery = query.Encode()
-	updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path)
+	updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path, aws.BoolValue(r.Config.DisableRestProtocolURICleaning))
}

func buildBody(r *request.Request, v reflect.Value) {
@@ -193,13 +194,15 @@ func buildQueryString(query url.Values, v reflect.Value, name string) error {
	return nil
}

-func updatePath(url *url.URL, urlPath string) {
+func updatePath(url *url.URL, urlPath string, disableRestProtocolURICleaning bool) {
	scheme, query := url.Scheme, url.RawQuery

	hasSlash := strings.HasSuffix(urlPath, "/")

	// clean up path
-	urlPath = path.Clean(urlPath)
+	if !disableRestProtocolURICleaning {
+		urlPath = path.Clean(urlPath)
+	}

	if hasSlash && !strings.HasSuffix(urlPath, "/") {
		urlPath += "/"
	}
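The new flag exists because `path.Clean` collapses repeated slashes and resolves `.`/`..` segments, which is fine for most REST paths but corrupts object keys (for example on S3) that legitimately contain empty segments. A small self-contained sketch of the gated behavior (an illustration under that assumption; `cleanPath` is a hypothetical stand-in for `updatePath`):

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

// cleanPath mirrors the part of updatePath gated by
// disableRestProtocolURICleaning: path.Clean collapses "//" and
// drops "."/"..", while preserving a trailing slash if one existed.
func cleanPath(p string, disableCleaning bool) string {
	hasSlash := strings.HasSuffix(p, "/")
	if !disableCleaning {
		p = path.Clean(p)
	}
	if hasSlash && !strings.HasSuffix(p, "/") {
		p += "/"
	}
	return p
}

func main() {
	key := "/bucket//key//with//empty//segments"
	fmt.Println(cleanPath(key, false)) // /bucket/key/with/empty/segments
	fmt.Println(cleanPath(key, true))  // /bucket//key//with//empty//segments
}
```

With cleaning enabled, a GET for the first path would address a different object than the key the caller stored, so the opt-out has to be threaded down from client configuration.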


@@ -2,8 +2,8 @@
// requests and responses.
package restxml

-//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/rest-xml.json build_test.go
-//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-xml.json unmarshal_test.go
+//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/rest-xml.json build_test.go
+//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-xml.json unmarshal_test.go

import (
	"bytes"

File diff suppressed because it is too large.

Some files were not shown because too many files have changed in this diff.