Merge remote-tracking branch 'upstream/master'

This commit is contained in:
mwoolsey 2016-11-20 18:31:55 -08:00
commit 3e72e50fa5
320 changed files with 55380 additions and 18209 deletions

View file

@ -1,3 +1,40 @@
## Next (Unreleased)
DEPRECATIONS/CHANGES:
* http: impose a maximum request size of 32MB to prevent a denial of service
with arbitrarily large requests. [GH-2108]
IMPROVEMENTS:
* auth/github: Policies can now be assigned to users as well as to teams
[GH-2079]
* cli: Set the number of retries on 500 down to 0 by default (no retrying). It
can be very confusing to users when there is a pause while the retries
happen if they haven't explicitly set it. With request forwarding the need
for this is lessened anyways. [GH-2093]
* core: Response wrapping is now allowed to be specified by backend responses
(requires backends gaining support) [GH-2088]
* secret/consul: Added listing functionality to roles [GH-2065]
* secret/postgresql: Added `revocation_sql` parameter on the role endpoint to
enable customization of user revocation SQL statements [GH-2033]
* secret/transit: Add listing of keys [GH-1987]
BUG FIXES:
* auth/approle: Creating the index for the role_id properly [GH-2004]
* auth/aws-ec2: Handle the case of multiple upgrade attempts when setting the
instance-profile ARN [GH-2035]
* api/unwrap, command/unwrap: Fix compatibility of `unwrap` command with Vault
0.6.1 and older [GH-2014]
* api/unwrap, command/unwrap: Fix error when no client token exists [GH-2077]
* command/ssh: Use temporary file for identity and ensure its deletion before
the command returns [GH-2016]
* core: Fix bug where a failure to come up as active node (e.g. if an audit
backend failed) could lead to deadlock [GH-2083]
* physical/mysql: Fix potential crash during setup due to a query failure
[GH-2105]
## 0.6.2 (October 5, 2016)
DEPRECATIONS/CHANGES:
@ -140,11 +177,11 @@ DEPRECATIONS/CHANGES:
* Status codes for sealed/uninitialized Vaults have changed to `503`/`501`
respectively. See the [version-specific upgrade
guide](https://www.vaultproject.io/docs/install/upgrade-to-0.6.1.html) for
more details.
more details.
* Root tokens (tokens with the `root` policy) can no longer be created except
by another root token or the `generate-root` endpoint.
* Issued certificates from the `pki` backend against new roles created or
modified after upgrading will contain a set of default key usages.
modified after upgrading will contain a set of default key usages.
* The `dynamodb` physical data store no longer supports HA by default. It has
some non-ideal behavior around failover that was causing confusion. See the
[documentation](https://www.vaultproject.io/docs/config/index.html#ha_enabled)
@ -214,7 +251,7 @@ IMPROVEMENTS:
the request portion of the response. [GH-1650]
* auth/aws-ec2: Added a new constraint `bound_account_id` to the role
[GH-1523]
* auth/aws-ec2: Added a new constraint `bound_iam_role_arn` to the role
* auth/aws-ec2: Added a new constraint `bound_iam_role_arn` to the role
[GH-1522]
* auth/aws-ec2: Added `ttl` field for the role [GH-1703]
* auth/ldap, secret/cassandra, physical/consul: Clients with `tls.Config`
@ -258,7 +295,7 @@ IMPROVEMENTS:
configuration [GH-1581]
* secret/mssql,mysql,postgresql: Reading of connection settings is supported
in all the sql backends [GH-1515]
* secret/mysql: Added optional maximum idle connections value to MySQL
* secret/mysql: Added optional maximum idle connections value to MySQL
connection configuration [GH-1635]
* secret/mysql: Use a combination of the role name and token display name in
generated user names and allow the length to be controlled [GH-1604]
@ -601,7 +638,7 @@ BUG FIXES:
during renewals [GH-1176]
## 0.5.1 (February 25th, 2016)
DEPRECATIONS/CHANGES:
* RSA keys less than 2048 bits are no longer supported in the PKI backend.
@ -631,7 +668,7 @@ IMPROVEMENTS:
* api/health: Add the server's time in UTC to health responses [GH-1117]
* command/rekey and command/generate-root: These now return the status at
attempt initialization time, rather than requiring a separate fetch for the
nonce [GH-1054]
nonce [GH-1054]
* credential/cert: Don't require root/sudo tokens for the `certs/` and `crls/`
paths; use normal ACL behavior instead [GH-468]
* credential/github: The validity of the token used for login will be checked
@ -761,7 +798,7 @@ FEATURES:
documentation](https://vaultproject.io/docs/config/index.html) for details.
[GH-945]
* **STS Support in AWS Secret Backend**: You can now use the AWS secret
backend to fetch STS tokens rather than IAM users. [GH-927]
backend to fetch STS tokens rather than IAM users. [GH-927]
* **Speedups in the transit backend**: The `transit` backend has gained a
cache, and now loads only the working set of keys (e.g. from the
`min_decryption_version` to the current key version) into its working set.

View file

@ -60,4 +60,8 @@ bootstrap:
go get -u $$tool; \
done
proto:
protoc -I helper/forwarding -I vault -I ../../.. vault/request_forwarding_service.proto --go_out=plugins=grpc:vault
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
.PHONY: bin default generate test vet bootstrap

View file

@ -48,7 +48,7 @@ type Config struct {
redirectSetup sync.Once
// MaxRetries controls the maximum number of times to retry when a 5xx error
// occurs. Set to 0 or less to disable retrying.
// occurs. Set to 0 or less to disable retrying. Defaults to 0.
MaxRetries int
}
@ -99,8 +99,6 @@ func DefaultConfig() *Config {
config.Address = v
}
config.MaxRetries = pester.DefaultClient.MaxRetries
return config
}

View file

@ -120,8 +120,12 @@ func (c *Logical) Delete(path string) (*Secret, error) {
func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
var data map[string]interface{}
if wrappingToken != "" {
data = map[string]interface{}{
"token": wrappingToken,
if c.c.Token() == "" {
c.c.SetToken(wrappingToken)
} else if wrappingToken != c.c.Token() {
data = map[string]interface{}{
"token": wrappingToken,
}
}
}
@ -146,7 +150,7 @@ func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
return nil, nil
}
if wrappingToken == "" {
if wrappingToken != "" {
origToken := c.c.Token()
defer c.c.SetToken(origToken)
c.c.SetToken(wrappingToken)

View file

@ -64,9 +64,18 @@ func (f *AuditFormatter) FormatRequest(
if err := Hash(config.Salt, auth); err != nil {
return err
}
// Cache and restore accessor in the request
var clientTokenAccessor string
if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" {
clientTokenAccessor = req.ClientTokenAccessor
}
if err := Hash(config.Salt, req); err != nil {
return err
}
if clientTokenAccessor != "" {
req.ClientTokenAccessor = clientTokenAccessor
}
}
// If auth is nil, make an empty one
@ -89,13 +98,14 @@ func (f *AuditFormatter) FormatRequest(
},
Request: AuditRequest{
ID: req.ID,
ClientToken: req.ClientToken,
Operation: req.Operation,
Path: req.Path,
Data: req.Data,
RemoteAddr: getRemoteAddr(req),
WrapTTL: int(req.WrapTTL / time.Second),
ID: req.ID,
ClientToken: req.ClientToken,
ClientTokenAccessor: req.ClientTokenAccessor,
Operation: req.Operation,
Path: req.Path,
Data: req.Data,
RemoteAddr: getRemoteAddr(req),
WrapTTL: int(req.WrapTTL / time.Second),
},
}
@ -167,9 +177,17 @@ func (f *AuditFormatter) FormatResponse(
auth.Accessor = accessor
}
// Cache and restore accessor in the request
var clientTokenAccessor string
if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" {
clientTokenAccessor = req.ClientTokenAccessor
}
if err := Hash(config.Salt, req); err != nil {
return err
}
if clientTokenAccessor != "" {
req.ClientTokenAccessor = clientTokenAccessor
}
// Cache and restore accessor in the response
accessor = ""
@ -241,13 +259,14 @@ func (f *AuditFormatter) FormatResponse(
},
Request: AuditRequest{
ID: req.ID,
ClientToken: req.ClientToken,
Operation: req.Operation,
Path: req.Path,
Data: req.Data,
RemoteAddr: getRemoteAddr(req),
WrapTTL: int(req.WrapTTL / time.Second),
ID: req.ID,
ClientToken: req.ClientToken,
ClientTokenAccessor: req.ClientTokenAccessor,
Operation: req.Operation,
Path: req.Path,
Data: req.Data,
RemoteAddr: getRemoteAddr(req),
WrapTTL: int(req.WrapTTL / time.Second),
},
Response: AuditResponse{
@ -286,13 +305,14 @@ type AuditResponseEntry struct {
}
type AuditRequest struct {
ID string `json:"id"`
Operation logical.Operation `json:"operation"`
ClientToken string `json:"client_token"`
Path string `json:"path"`
Data map[string]interface{} `json:"data"`
RemoteAddr string `json:"remote_address"`
WrapTTL int `json:"wrap_ttl"`
ID string `json:"id"`
Operation logical.Operation `json:"operation"`
ClientToken string `json:"client_token"`
ClientTokenAccessor string `json:"client_token_accessor"`
Path string `json:"path"`
Data map[string]interface{} `json:"data"`
RemoteAddr string `json:"remote_address"`
WrapTTL int `json:"wrap_ttl"`
}
type AuditResponse struct {

View file

@ -32,7 +32,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) {
},
errors.New("this is an error"),
"",
`<json:object name="auth"><json:string name="accessor"></json:string><json:string name="client_token"></json:string><json:string name="display_name"></json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:null name="data" /><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
`<json:object name="auth"><json:string name="accessor"></json:string><json:string name="client_token"></json:string><json:string name="display_name"></json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:string name="client_token_accessor"></json:string><json:null name="data" /><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
},
}

View file

@ -49,6 +49,10 @@ func Hash(salter *salt.Salt, raw interface{}) error {
s.ClientToken = fn(s.ClientToken)
}
if s.ClientTokenAccessor != "" {
s.ClientTokenAccessor = fn(s.ClientTokenAccessor)
}
data, err := HashStructure(s.Data, fn)
if err != nil {
return err

View file

@ -121,7 +121,7 @@ will expire. Defaults to 0 meaning that the the secret_id is of unlimited use.`,
"secret_id_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Duration in seconds after which the issued SecretID should expire. Defaults
to 0, in which case the value will fall back to the system/mount defaults.`,
to 0, meaning no expiration.`,
},
"token_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
@ -249,7 +249,7 @@ addresses which can perform the login operation`,
"secret_id_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: `Duration in seconds after which the issued SecretID should expire. Defaults
to 0, in which case the value will fall back to the system/mount defaults.`,
to 0, meaning no expiration.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -640,7 +640,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
}
// If previousRoleID is still intact, don't create another one
if previousRoleID != "" {
if previousRoleID != "" && previousRoleID == role.RoleID {
return nil
}

View file

@ -111,6 +111,77 @@ func TestAppRole_RoleConstraints(t *testing.T) {
}
}
func TestAppRole_RoleIDUpdate(t *testing.T) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)
roleData := map[string]interface{}{
"role_id": "role-id-123",
"policies": "a,b",
"secret_id_num_uses": 10,
"secret_id_ttl": 300,
"token_ttl": 400,
"token_max_ttl": 500,
}
roleReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "role/testrole1",
Storage: storage,
Data: roleData,
}
resp, err = b.HandleRequest(roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
roleIDUpdateReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "role/testrole1/role-id",
Storage: storage,
Data: map[string]interface{}{
"role_id": "customroleid",
},
}
resp, err = b.HandleRequest(roleIDUpdateReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
secretIDReq := &logical.Request{
Operation: logical.UpdateOperation,
Storage: storage,
Path: "role/testrole1/secret-id",
}
resp, err = b.HandleRequest(secretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
secretID := resp.Data["secret_id"].(string)
loginData := map[string]interface{}{
"role_id": "customroleid",
"secret_id": secretID,
}
loginReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "login",
Storage: storage,
Data: loginData,
Connection: &logical.Connection{
RemoteAddr: "127.0.0.1",
},
}
resp, err = b.HandleRequest(loginReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Auth == nil {
t.Fatalf("expected a non-nil auth object in the response")
}
}
func TestAppRole_RoleIDUniqueness(t *testing.T) {
var resp *logical.Response
var err error

View file

@ -207,11 +207,6 @@ func (b *backend) nonLockedAWSRole(s logical.Storage, roleName string) (*awsRole
// Check if the value held by role ARN field is actually an instance profile ARN
if result.BoundIamRoleARN != "" && strings.Contains(result.BoundIamRoleARN, ":instance-profile/") {
// For sanity
if result.BoundIamInstanceProfileARN != "" {
return nil, fmt.Errorf("bound_iam_role_arn contains instance profile ARN and bound_iam_instance_profile_arn is non empty")
}
// If yes, move it to the correct field
result.BoundIamInstanceProfileARN = result.BoundIamRoleARN

View file

@ -14,12 +14,22 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Backend() *backend {
var b backend
b.Map = &framework.PolicyMap{
b.TeamMap = &framework.PolicyMap{
PathMap: framework.PathMap{
Name: "teams",
},
DefaultKey: "default",
}
b.UserMap = &framework.PolicyMap{
PathMap: framework.PathMap{
Name: "users",
},
DefaultKey: "default",
}
allPaths := append(b.TeamMap.Paths(), b.UserMap.Paths()...)
b.Backend = &framework.Backend{
Help: backendHelp,
@ -32,7 +42,7 @@ func Backend() *backend {
Paths: append([]*framework.Path{
pathConfig(&b),
pathLogin(&b),
}, b.Map.Paths()...),
}, allPaths...),
AuthRenew: b.pathLoginRenew,
}
@ -43,7 +53,9 @@ func Backend() *backend {
type backend struct {
*framework.Backend
Map *framework.PolicyMap
TeamMap *framework.PolicyMap
UserMap *framework.PolicyMap
}
// Client returns the GitHub client to communicate to GitHub via the

View file

@ -110,17 +110,21 @@ func TestBackend_basic(t *testing.T) {
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, false),
testAccMap(t, "default", "root"),
testAccMap(t, "oWnErs", "root"),
testAccLogin(t, []string{"root"}),
testAccMap(t, "default", "fakepol"),
testAccMap(t, "oWnErs", "fakepol"),
testAccLogin(t, []string{"default", "fakepol"}),
testAccStepConfig(t, true),
testAccMap(t, "default", "root"),
testAccMap(t, "oWnErs", "root"),
testAccLogin(t, []string{"root"}),
testAccMap(t, "default", "fakepol"),
testAccMap(t, "oWnErs", "fakepol"),
testAccLogin(t, []string{"default", "fakepol"}),
testAccStepConfigWithBaseURL(t),
testAccMap(t, "default", "root"),
testAccMap(t, "oWnErs", "root"),
testAccLogin(t, []string{"root"}),
testAccMap(t, "default", "fakepol"),
testAccMap(t, "oWnErs", "fakepol"),
testAccLogin(t, []string{"default", "fakepol"}),
testAccMap(t, "default", "fakepol"),
testAccStepConfig(t, true),
mapUserToPolicy(t, os.Getenv("GITHUB_USER"), "userpolicy"),
testAccLogin(t, []string{"default", "fakepol", "userpolicy"}),
},
})
}
@ -174,7 +178,17 @@ func testAccMap(t *testing.T, k string, v string) logicaltest.TestStep {
}
}
func testAccLogin(t *testing.T, keys []string) logicaltest.TestStep {
func mapUserToPolicy(t *testing.T, k string, v string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "map/users/" + k,
Data: map[string]interface{}{
"value": v,
},
}
}
func testAccLogin(t *testing.T, policies []string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "login",
@ -183,6 +197,6 @@ func testAccLogin(t *testing.T, keys []string) logicaltest.TestStep {
},
Unauthenticated: true,
Check: logicaltest.TestCheckAuth(keys),
Check: logicaltest.TestCheckAuth(policies),
}
}

View file

@ -194,14 +194,22 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
}
}
policiesList, err := b.Map.Policies(req.Storage, teamNames...)
groupPoliciesList, err := b.TeamMap.Policies(req.Storage, teamNames...)
if err != nil {
return nil, nil, err
}
userPoliciesList, err := b.UserMap.Policies(req.Storage, []string{*user.Login}...)
if err != nil {
return nil, nil, err
}
return &verifyCredentialsResp{
User: user,
Org: org,
Policies: policiesList,
Policies: append(groupPoliciesList, userPoliciesList...),
}, nil, nil
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-ldap/ldap"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -158,6 +159,9 @@ func (b *backend) Login(req *logical.Request, username string, password string)
}
}
// Policies from each group may overlap
policies = strutil.RemoveDuplicates(policies)
if len(policies) == 0 {
errStr := "user is not a member of any authorized group"
if len(ldapResponse.Warnings()) > 0 {

View file

@ -100,6 +100,12 @@ Default: cn`,
Default: "tls12",
Description: "Minimum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
},
"tls_max_version": &framework.FieldSchema{
Type: framework.TypeString,
Default: "tls12",
Description: "Maximum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -225,6 +231,19 @@ func (b *backend) newConfigEntry(d *framework.FieldData) (*ConfigEntry, error) {
return nil, fmt.Errorf("invalid 'tls_min_version'")
}
cfg.TLSMaxVersion = d.Get("tls_max_version").(string)
if cfg.TLSMaxVersion == "" {
return nil, fmt.Errorf("failed to get 'tls_max_version' value")
}
_, ok = tlsutil.TLSLookup[cfg.TLSMaxVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_max_version'")
}
if cfg.TLSMaxVersion < cfg.TLSMinVersion {
return nil, fmt.Errorf("'tls_max_version' must be greater than or equal to 'tls_min_version'")
}
startTLS := d.Get("starttls").(bool)
if startTLS {
cfg.StartTLS = startTLS
@ -280,6 +299,7 @@ type ConfigEntry struct {
BindPassword string `json:"bindpass" structs:"bindpass" mapstructure:"bindpass"`
DiscoverDN bool `json:"discoverdn" structs:"discoverdn" mapstructure:"discoverdn"`
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
TLSMaxVersion string `json:"tls_max_version" structs:"tls_max_version" mapstructure:"tls_max_version"`
}
func (c *ConfigEntry) GetTLSConfig(host string) (*tls.Config, error) {
@ -295,6 +315,14 @@ func (c *ConfigEntry) GetTLSConfig(host string) (*tls.Config, error) {
tlsConfig.MinVersion = tlsMinVersion
}
if c.TLSMaxVersion != "" {
tlsMaxVersion, ok := tlsutil.TLSLookup[c.TLSMaxVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_max_version' in config")
}
tlsConfig.MaxVersion = tlsMaxVersion
}
if c.InsecureTLS {
tlsConfig.InsecureSkipVerify = true
}

View file

@ -14,6 +14,7 @@ func Backend() *backend {
b.Backend = &framework.Backend{
Paths: []*framework.Path{
pathConfigAccess(),
pathListRoles(&b),
pathRoles(),
pathToken(&b),
},

View file

@ -293,7 +293,10 @@ func TestBackend_crud(t *testing.T) {
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test", testPolicy, ""),
testAccStepWritePolicy(t, "test2", testPolicy, ""),
testAccStepWritePolicy(t, "test3", testPolicy, ""),
testAccStepReadPolicy(t, "test", testPolicy, 0),
testAccStepListPolicy(t, []string{"test", "test2", "test3"}),
testAccStepDeletePolicy(t, "test"),
},
})
@ -443,6 +446,20 @@ func testAccStepReadPolicy(t *testing.T, name string, policy string, lease time.
}
}
func testAccStepListPolicy(t *testing.T, names []string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ListOperation,
Path: "roles/",
Check: func(resp *logical.Response) error {
respKeys := resp.Data["keys"].([]string)
if !reflect.DeepEqual(respKeys, names) {
return fmt.Errorf("mismatch: %#v %#v", respKeys, names)
}
return nil
},
}
}
func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,

View file

@ -9,6 +9,16 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathListRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList,
},
}
}
func pathRoles() *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
@ -47,6 +57,16 @@ Defaults to 'client'.`,
}
}
func (b *backend) pathRoleList(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("policy/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func pathRolesRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)

View file

@ -55,8 +55,9 @@ func (b *backend) pathTokenRead(
return logical.ErrorResponse(err.Error()), nil
}
// Generate a random name for the token
tokenName := fmt.Sprintf("Vault %s %d", req.DisplayName, time.Now().Unix())
// Generate a name for the token
tokenName := fmt.Sprintf("Vault %s %s %d", name, req.DisplayName, time.Now().UnixNano())
// Create it
token, _, err := c.ACL().Create(&api.ACLEntry{
Name: tokenName,

View file

@ -37,8 +37,11 @@ func pathRoles(b *backend) *framework.Path {
},
"revocation_sql": {
Type: framework.TypeString,
Description: "SQL string to revoke a user. This is in beta; use with caution.",
Type: framework.TypeString,
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
string, a base64-encoded semicolon-separated string, a serialized JSON string
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
will be substituted.`,
},
},
@ -90,19 +93,12 @@ func (b *backend) pathRoleRead(
return nil, nil
}
resp := &logical.Response{
return &logical.Response{
Data: map[string]interface{}{
"sql": role.SQL,
"sql": role.SQL,
"revocation_sql": role.RevocationSQL,
},
}
// TODO: This is separate because this is in beta in 0.6.2, so we don't
// want it to show up in the general case.
if role.RevocationSQL != "" {
resp.Data["revocation_sql"] = role.RevocationSQL
}
return resp, nil
}, nil
}
func (b *backend) pathRoleList(
@ -193,4 +189,12 @@ Example of a decent SQL query to use:
Note the above user would be able to access everything in schema public.
For more complex GRANT clauses, see the PostgreSQL manual.
The "revocation_sql" parameter customizes the SQL string used to revoke a user.
Example of a decent revocation SQL query to use:
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
REVOKE USAGE ON SCHEMA public FROM {{name}};
DROP ROLE IF EXISTS {{name}};
`

View file

@ -7,8 +7,8 @@ const (
#!/bin/bash
#
# This is a default script which installs or uninstalls an RSA public key to/from
# authoried_keys file in a typical linux machine.
#
# authorized_keys file in a typical linux machine.
#
# If the platform differs or if the binaries used in this script are not available
# in target machine, use the 'install_script' parameter with 'roles/' endpoint to
# register a custom script (applicable for Dynamic type only).
@ -51,7 +51,7 @@ fi
# Create the .ssh directory and authorized_keys file if it does not exist
SSH_DIR=$(dirname $AUTH_KEYS_FILE)
sudo mkdir -p "$SSH_DIR"
sudo mkdir -p "$SSH_DIR"
sudo touch "$AUTH_KEYS_FILE"
# Remove the key from authorized_keys file if it is already present.

View file

@ -72,7 +72,7 @@ func (b *backend) pathConfigZeroAddressWrite(req *logical.Request, d *framework.
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("Role [%s] does not exist", item)), nil
return logical.ErrorResponse(fmt.Sprintf("Role %q does not exist", item)), nil
}
}

View file

@ -55,10 +55,10 @@ func (b *backend) pathCredsCreateWrite(
role, err := b.getRole(req.Storage, roleName)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
return nil, fmt.Errorf("error retrieving role: %v", err)
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("Role '%s' not found", roleName)), nil
return logical.ErrorResponse(fmt.Sprintf("Role %q not found", roleName)), nil
}
// username is an optional parameter.
@ -89,7 +89,7 @@ func (b *backend) pathCredsCreateWrite(
// Validate the IP address
ipAddr := net.ParseIP(ipRaw)
if ipAddr == nil {
return logical.ErrorResponse(fmt.Sprintf("Invalid IP '%s'", ipRaw)), nil
return logical.ErrorResponse(fmt.Sprintf("Invalid IP %q", ipRaw)), nil
}
// Check if the IP belongs to the registered list of CIDR blocks under the role
@ -97,7 +97,7 @@ func (b *backend) pathCredsCreateWrite(
zeroAddressEntry, err := b.getZeroAddressRoles(req.Storage)
if err != nil {
return nil, fmt.Errorf("error retrieving zero-address roles: %s", err)
return nil, fmt.Errorf("error retrieving zero-address roles: %v", err)
}
var zeroAddressRoles []string
if zeroAddressEntry != nil {
@ -106,7 +106,7 @@ func (b *backend) pathCredsCreateWrite(
err = validateIP(ip, roleName, role.CIDRList, role.ExcludeCIDRList, zeroAddressRoles)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error validating IP: %s", err)), nil
return logical.ErrorResponse(fmt.Sprintf("Error validating IP: %v", err)), nil
}
var result *logical.Response
@ -171,22 +171,22 @@ func (b *backend) GenerateDynamicCredential(req *logical.Request, role *sshRole,
// Fetch the host key to be used for dynamic key installation
keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", role.KeyName))
if err != nil {
return "", "", fmt.Errorf("key '%s' not found. err:%s", role.KeyName, err)
return "", "", fmt.Errorf("key %q not found. err: %v", role.KeyName, err)
}
if keyEntry == nil {
return "", "", fmt.Errorf("key '%s' not found", role.KeyName)
return "", "", fmt.Errorf("key %q not found", role.KeyName)
}
var hostKey sshHostKey
if err := keyEntry.DecodeJSON(&hostKey); err != nil {
return "", "", fmt.Errorf("error reading the host key: %s", err)
return "", "", fmt.Errorf("error reading the host key: %v", err)
}
// Generate a new RSA key pair with the given key length.
dynamicPublicKey, dynamicPrivateKey, err := generateRSAKeys(role.KeyBits)
if err != nil {
return "", "", fmt.Errorf("error generating key: %s", err)
return "", "", fmt.Errorf("error generating key: %v", err)
}
if len(role.KeyOptionSpecs) != 0 {
@ -196,7 +196,7 @@ func (b *backend) GenerateDynamicCredential(req *logical.Request, role *sshRole,
// Add the public key to authorized_keys file in target machine
err = b.installPublicKeyInTarget(role.AdminUser, username, ip, role.Port, hostKey.Key, dynamicPublicKey, role.InstallScript, true)
if err != nil {
return "", "", fmt.Errorf("error adding public key to authorized_keys file in target")
return "", "", fmt.Errorf("failed to add public key to authorized_keys file in target: %v", err)
}
return dynamicPublicKey, dynamicPrivateKey, nil
}

View file

@ -32,7 +32,7 @@ func (b *backend) pathLookupWrite(req *logical.Request, d *framework.FieldData)
}
ip := net.ParseIP(ipAddr)
if ip == nil {
return logical.ErrorResponse(fmt.Sprintf("Invalid IP '%s'", ip.String())), nil
return logical.ErrorResponse(fmt.Sprintf("Invalid IP %q", ip.String())), nil
}
// Get all the roles created in the backend.

View file

@ -232,7 +232,7 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
}
keyEntry, err := req.Storage.Get(fmt.Sprintf("keys/%s", keyName))
if err != nil || keyEntry == nil {
return logical.ErrorResponse(fmt.Sprintf("invalid 'key': '%s'", keyName)), nil
return logical.ErrorResponse(fmt.Sprintf("invalid 'key': %q", keyName)), nil
}
installScript := d.Get("install_script").(string)

View file

@ -55,10 +55,10 @@ func (b *backend) secretDynamicKeyRevoke(req *logical.Request, d *framework.Fiel
// Fetch the host key using the key name
hostKey, err := b.getKey(req.Storage, intSec.HostKeyName)
if err != nil {
return nil, fmt.Errorf("key '%s' not found error:%s", intSec.HostKeyName, err)
return nil, fmt.Errorf("key %q not found error: %v", intSec.HostKeyName, err)
}
if hostKey == nil {
return nil, fmt.Errorf("key '%s' not found", intSec.HostKeyName)
return nil, fmt.Errorf("key %q not found", intSec.HostKeyName)
}
// Remove the public key from authorized_keys file in target machine

View file

@ -23,7 +23,7 @@ import (
func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, err error) {
privateKey, err := rsa.GenerateKey(rand.Reader, keyBits)
if err != nil {
return "", "", fmt.Errorf("error generating RSA key-pair: %s", err)
return "", "", fmt.Errorf("error generating RSA key-pair: %v", err)
}
privateKeyRsa = string(pem.EncodeToMemory(&pem.Block{
@ -33,7 +33,7 @@ func generateRSAKeys(keyBits int) (publicKeyRsa string, privateKeyRsa string, er
sshPublicKey, err := ssh.NewPublicKey(privateKey.Public())
if err != nil {
return "", "", fmt.Errorf("error generating RSA key-pair: %s", err)
return "", "", fmt.Errorf("error generating RSA key-pair: %v", err)
}
publicKeyRsa = "ssh-rsa " + base64.StdEncoding.EncodeToString(sshPublicKey.Marshal())
return
@ -61,7 +61,7 @@ func (b *backend) installPublicKeyInTarget(adminUser, username, ip string, port
err = comm.Upload(publicKeyFileName, bytes.NewBufferString(dynamicPublicKey), nil)
if err != nil {
return fmt.Errorf("error uploading public key: %s", err)
return fmt.Errorf("error uploading public key: %v", err)
}
// Transfer the script required to install or uninstall the key to the remote
@ -70,14 +70,14 @@ func (b *backend) installPublicKeyInTarget(adminUser, username, ip string, port
scriptFileName := fmt.Sprintf("%s.sh", publicKeyFileName)
err = comm.Upload(scriptFileName, bytes.NewBufferString(installScript), nil)
if err != nil {
return fmt.Errorf("error uploading install script: %s", err)
return fmt.Errorf("error uploading install script: %v", err)
}
// Create a session to run remote command that triggers the script to install
// or uninstall the key.
session, err := comm.NewSession()
if err != nil {
return fmt.Errorf("unable to create SSH Session using public keys: %s", err)
return fmt.Errorf("unable to create SSH Session using public keys: %v", err)
}
if session == nil {
return fmt.Errorf("invalid session object")
@ -116,15 +116,15 @@ func roleContainsIP(s logical.Storage, roleName string, ip string) (bool, error)
roleEntry, err := s.Get(fmt.Sprintf("roles/%s", roleName))
if err != nil {
return false, fmt.Errorf("error retrieving role '%s'", err)
return false, fmt.Errorf("error retrieving role %v", err)
}
if roleEntry == nil {
return false, fmt.Errorf("role '%s' not found", roleName)
return false, fmt.Errorf("role %q not found", roleName)
}
var role sshRole
if err := roleEntry.DecodeJSON(&role); err != nil {
return false, fmt.Errorf("error decoding role '%s'", roleName)
return false, fmt.Errorf("error decoding role %q", roleName)
}
if matched, err := cidrListContainsIP(ip, role.CIDRList); err != nil {
@ -143,7 +143,7 @@ func cidrListContainsIP(ip, cidrList string) (bool, error) {
for _, item := range strings.Split(cidrList, ",") {
_, cidrIPNet, err := net.ParseCIDR(item)
if err != nil {
return false, fmt.Errorf("invalid CIDR entry '%s'", item)
return false, fmt.Errorf("invalid CIDR entry %q", item)
}
if cidrIPNet.Contains(net.ParseIP(ip)) {
return true, nil

View file

@ -1,6 +1,7 @@
package transit
import (
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -25,6 +26,7 @@ func Backend(conf *logical.BackendConfig) *backend {
b.pathRotate(),
b.pathRewrap(),
b.pathKeys(),
b.pathListKeys(),
b.pathEncrypt(),
b.pathDecrypt(),
b.pathDatakey(),
@ -38,12 +40,12 @@ func Backend(conf *logical.BackendConfig) *backend {
Secrets: []*framework.Secret{},
}
b.lm = newLockManager(conf.System.CachingDisabled())
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
return &b
}
type backend struct {
*framework.Backend
lm *lockManager
lm *keysutil.LockManager
}

View file

@ -12,6 +12,7 @@ import (
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
logicaltest "github.com/hashicorp/vault/logical/testing"
@ -27,7 +28,9 @@ func TestBackend_basic(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
@ -53,7 +56,9 @@ func TestBackend_upsert(t *testing.T) {
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepReadPolicy(t, "test", true, false),
testAccStepListPolicy(t, "test", true),
testAccStepEncryptUpsert(t, "test", testPlaintext, decryptData),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
},
@ -65,7 +70,9 @@ func TestBackend_datakey(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
testAccStepWriteDatakey(t, "test", false, 256, dataKeyInfo),
testAccStepDecryptDatakey(t, "test", dataKeyInfo),
@ -80,7 +87,9 @@ func TestBackend_rotation(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", false),
testAccStepListPolicy(t, "test", false),
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory),
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 1, encryptHistory),
testAccStepRotate(t, "test"), // now v2
@ -128,6 +137,7 @@ func TestBackend_rotation(t *testing.T) {
testAccStepEnableDeletion(t, "test"),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true, false),
testAccStepListPolicy(t, "test", true),
},
})
}
@ -137,7 +147,9 @@ func TestBackend_basic_derived(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepListPolicy(t, "test", true),
testAccStepWritePolicy(t, "test", true),
testAccStepListPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, true),
testAccStepEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
@ -158,6 +170,42 @@ func testAccStepWritePolicy(t *testing.T, name string, derived bool) logicaltest
}
}
func testAccStepListPolicy(t *testing.T, name string, expectNone bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ListOperation,
Path: "keys",
Check: func(resp *logical.Response) error {
if resp == nil {
return fmt.Errorf("missing response")
}
if expectNone {
keysRaw, ok := resp.Data["keys"]
if ok || keysRaw != nil {
return fmt.Errorf("response data when expecting none")
}
return nil
}
if len(resp.Data) == 0 {
return fmt.Errorf("no data returned")
}
var d struct {
Keys []string `mapstructure:"keys"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if len(d.Keys) > 0 && d.Keys[0] != name {
return fmt.Errorf("bad name: %#v", d)
}
if len(d.Keys) != 1 {
return fmt.Errorf("only 1 key expected, %d returned", len(d.Keys))
}
return nil
},
}
}
func testAccStepAdjustPolicy(t *testing.T, name string, minVer int) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
@ -242,7 +290,7 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool)
if d.Name != name {
return fmt.Errorf("bad name: %#v", d)
}
if d.Type != KeyType(keyType_AES256_GCM96).String() {
if d.Type != keysutil.KeyType(keysutil.KeyType_AES256_GCM96).String() {
return fmt.Errorf("bad key type: %#v", d)
}
// Should NOT get a key back
@ -536,13 +584,13 @@ func testAccStepDecryptDatakey(t *testing.T, name string,
func TestKeyUpgrade(t *testing.T) {
key, _ := uuid.GenerateRandomBytes(32)
p := &policy{
p := &keysutil.Policy{
Name: "test",
Key: key,
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
}
p.migrateKeyToKeysMap()
p.MigrateKeyToKeysMap()
if p.Key != nil ||
p.Keys == nil ||
@ -557,18 +605,18 @@ func TestDerivedKeyUpgrade(t *testing.T) {
key, _ := uuid.GenerateRandomBytes(32)
context, _ := uuid.GenerateRandomBytes(32)
p := &policy{
p := &keysutil.Policy{
Name: "test",
Key: key,
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
Derived: true,
}
p.migrateKeyToKeysMap()
p.upgrade(storage) // Need to run the upgrade code to make the migration stick
p.MigrateKeyToKeysMap()
p.Upgrade(storage) // Need to run the upgrade code to make the migration stick
if p.KDF != kdf_hmac_sha256_counter {
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", kdf_hmac_sha256_counter, p.KDF, *p)
if p.KDF != keysutil.Kdf_hmac_sha256_counter {
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", keysutil.Kdf_hmac_sha256_counter, p.KDF, *p)
}
derBytesOld, err := p.DeriveKey(context, 1)
@ -585,8 +633,8 @@ func TestDerivedKeyUpgrade(t *testing.T) {
t.Fatal("mismatch of same context alg")
}
p.KDF = kdf_hkdf_sha256
if p.needsUpgrade() {
p.KDF = keysutil.Kdf_hkdf_sha256
if p.NeedsUpgrade() {
t.Fatal("expected no upgrade needed")
}
@ -645,15 +693,15 @@ func testConvergentEncryptionCommon(t *testing.T, ver int) {
t.Fatalf("bad: expected error response, got %#v", *resp)
}
p := &policy{
p := &keysutil.Policy{
Name: "testkey",
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
Derived: true,
ConvergentEncryption: true,
ConvergentVersion: ver,
}
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
@ -929,7 +977,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {
resp, err := be.pathDecryptWrite(req, fd)
if err != nil {
// This could well happen since the min version is jumping around
if resp.Data["error"].(string) == ErrTooOld {
if resp.Data["error"].(string) == keysutil.ErrTooOld {
continue
}
t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id)

View file

@ -6,6 +6,7 @@ import (
"sync"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -116,7 +117,7 @@ func (b *backend) pathEncryptWrite(
}
// Get the policy
var p *policy
var p *keysutil.Policy
var lock *sync.RWMutex
var upserted bool
if req.Operation == logical.CreateOperation {
@ -125,17 +126,17 @@ func (b *backend) pathEncryptWrite(
return logical.ErrorResponse("convergent encryption requires derivation to be enabled, so context is required"), nil
}
polReq := policyRequest{
storage: req.Storage,
name: name,
derived: len(context) != 0,
convergent: convergent,
polReq := keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
Derived: len(context) != 0,
Convergent: convergent,
}
keyType := d.Get("type").(string)
switch keyType {
case "aes256-gcm96":
polReq.keyType = keyType_AES256_GCM96
polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "ecdsa-p256":
return logical.ErrorResponse(fmt.Sprintf("key type %v not supported for this operation", keyType)), logical.ErrInvalidRequest
default:

View file

@ -124,7 +124,7 @@ func TestTransit_HMAC(t *testing.T) {
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
// Rotate
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}

View file

@ -5,10 +5,24 @@ import (
"fmt"
"strconv"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func (b *backend) pathListKeys() *framework.Path {
return &framework.Path{
Pattern: "keys/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathKeysList,
},
HelpSynopsis: pathPolicyHelpSyn,
HelpDescription: pathPolicyHelpDesc,
}
}
func (b *backend) pathKeys() *framework.Path {
return &framework.Path{
Pattern: "keys/" + framework.GenericNameRegex("name"),
@ -61,6 +75,16 @@ impact the ciphertext's security.`,
}
}
func (b *backend) pathKeysList(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("policy/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func (b *backend) pathPolicyWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
@ -72,17 +96,17 @@ func (b *backend) pathPolicyWrite(
return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
}
polReq := policyRequest{
storage: req.Storage,
name: name,
derived: derived,
convergent: convergent,
polReq := keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
Derived: derived,
Convergent: convergent,
}
switch keyType {
case "aes256-gcm96":
polReq.keyType = keyType_AES256_GCM96
polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "ecdsa-p256":
polReq.keyType = keyType_ECDSA_P256
polReq.KeyType = keysutil.KeyType_ECDSA_P256
default:
return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest
}
@ -135,10 +159,10 @@ func (b *backend) pathPolicyRead(
if p.Derived {
switch p.KDF {
case kdf_hmac_sha256_counter:
case keysutil.Kdf_hmac_sha256_counter:
resp.Data["kdf"] = "hmac-sha256-counter"
resp.Data["kdf_mode"] = "hmac-sha256-counter"
case kdf_hkdf_sha256:
case keysutil.Kdf_hkdf_sha256:
resp.Data["kdf"] = "hkdf_sha256"
}
resp.Data["convergent_encryption"] = p.ConvergentEncryption
@ -148,14 +172,14 @@ func (b *backend) pathPolicyRead(
}
switch p.Type {
case keyType_AES256_GCM96:
case keysutil.KeyType_AES256_GCM96:
retKeys := map[string]int64{}
for k, v := range p.Keys {
retKeys[strconv.Itoa(k)] = v.CreationTime
}
resp.Data["keys"] = retKeys
case keyType_ECDSA_P256:
case keysutil.KeyType_ECDSA_P256:
type ecdsaKey struct {
Name string `json:"name"`
PublicKey string `json:"public_key"`

View file

@ -41,7 +41,7 @@ func (b *backend) pathRotateWrite(
}
// Rotate the policy
err = p.rotate(req.Storage)
err = p.Rotate(req.Storage)
return nil, err
}

View file

@ -177,11 +177,11 @@ func TestTransit_SignVerify(t *testing.T) {
signRequest(req, true, "")
// Rotate and set min decryption version
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}

View file

@ -820,6 +820,8 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
cfg.CheckManager.Check.ForceMetricActivation = telConfig.CirconusCheckForceMetricActivation
cfg.CheckManager.Check.InstanceID = telConfig.CirconusCheckInstanceID
cfg.CheckManager.Check.SearchTag = telConfig.CirconusCheckSearchTag
cfg.CheckManager.Check.DisplayName = telConfig.CirconusCheckDisplayName
cfg.CheckManager.Check.Tags = telConfig.CirconusCheckTags
cfg.CheckManager.Broker.ID = telConfig.CirconusBrokerID
cfg.CheckManager.Broker.SelectTag = telConfig.CirconusBrokerSelectTag

View file

@ -149,6 +149,13 @@ type Telemetry struct {
// narrow down the search results when neither a Submission URL or Check ID is provided.
// Default: service:app (e.g. service:consul)
CirconusCheckSearchTag string `hcl:"circonus_check_search_tag"`
// CirconusCheckTags is a comma separated list of tags to apply to the check. Note that
// the value of CirconusCheckSearchTag will always be added to the check.
// Default: none
CirconusCheckTags string `mapstructure:"circonus_check_tags"`
// CirconusCheckDisplayName is the name for the check which will be displayed in the Circonus UI.
// Default: value of CirconusCheckInstanceID
CirconusCheckDisplayName string `mapstructure:"circonus_check_display_name"`
// CirconusBrokerID is an explicit broker to use when creating a new check. The numeric portion
// of broker._cid. If metric management is enabled and neither a Submission URL nor Check ID
// is provided, an attempt will be made to search for an existing check using Instance ID and
@ -597,6 +604,8 @@ func parseTelemetry(result *Config, list *ast.ObjectList) error {
"circonus_check_force_metric_activation",
"circonus_check_instance_id",
"circonus_check_search_tag",
"circonus_check_display_name",
"circonus_check_tags",
"circonus_broker_id",
"circonus_broker_select_tag",
"disable_hostname",

View file

@ -122,6 +122,8 @@ func TestLoadConfigFile_json(t *testing.T) {
CirconusCheckForceMetricActivation: "",
CirconusCheckInstanceID: "",
CirconusCheckSearchTag: "",
CirconusCheckDisplayName: "",
CirconusCheckTags: "",
CirconusBrokerID: "",
CirconusBrokerSelectTag: "",
},
@ -191,6 +193,8 @@ func TestLoadConfigFile_json2(t *testing.T) {
CirconusCheckForceMetricActivation: "true",
CirconusCheckInstanceID: "node1:vault",
CirconusCheckSearchTag: "service:vault",
CirconusCheckDisplayName: "node1:vault",
CirconusCheckTags: "cat1:tag1,cat2:tag2",
CirconusBrokerID: "0",
CirconusBrokerSelectTag: "dc:sfo",
},

View file

@ -3,18 +3,26 @@ package server
import (
"io"
"net"
"strings"
"time"
"github.com/hashicorp/vault/vault"
)
func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) {
bind_proto := "tcp"
addr, ok := config["address"]
if !ok {
addr = "127.0.0.1:8200"
}
ln, err := net.Listen("tcp", addr)
// If they've passed 0.0.0.0, we only want to bind on IPv4
// rather than golang's dual stack default
if strings.HasPrefix(addr, "0.0.0.0:") {
bind_proto = "tcp4"
}
ln, err := net.Listen(bind_proto, addr)
if err != nil {
return nil, nil, nil, err
}

View file

@ -36,6 +36,8 @@
"circonus_check_force_metric_activation": "true",
"circonus_check_instance_id": "node1:vault",
"circonus_check_search_tag": "service:vault",
"circonus_check_display_name": "node1:vault",
"circonus_check_tags": "cat1:tag1,cat2:tag2",
"circonus_broker_id": "0",
"circonus_broker_select_tag": "dc:sfo"
}

View file

@ -34,7 +34,6 @@ func (c *SSHCommand) Run(args []string) int {
var role, mountPoint, format, userKnownHostsFile, strictHostKeyChecking string
var noExec bool
var sshCmdArgs []string
var sshDynamicKeyFileName string
flags := c.Meta.FlagSet("ssh", meta.FlagSetDefault)
flags.StringVar(&strictHostKeyChecking, "strict-host-key-checking", "", "")
flags.StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "")
@ -76,7 +75,7 @@ func (c *SSHCommand) Run(args []string) int {
client, err := c.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf("Error initializing client: %s", err))
c.Ui.Error(fmt.Sprintf("Error initializing client: %v", err))
return 1
}
@ -92,7 +91,7 @@ func (c *SSHCommand) Run(args []string) int {
if len(input) == 1 {
u, err := user.Current()
if err != nil {
c.Ui.Error(fmt.Sprintf("Error fetching username: %s", err))
c.Ui.Error(fmt.Sprintf("Error fetching username: %v", err))
return 1
}
username = u.Username
@ -101,7 +100,7 @@ func (c *SSHCommand) Run(args []string) int {
username = input[0]
ipAddr = input[1]
} else {
c.Ui.Error(fmt.Sprintf("Invalid parameter: %s", args[0]))
c.Ui.Error(fmt.Sprintf("Invalid parameter: %q", args[0]))
return 1
}
@ -109,7 +108,7 @@ func (c *SSHCommand) Run(args []string) int {
// Vault only deals with IP addresses.
ip, err := net.ResolveIPAddr("ip", ipAddr)
if err != nil {
c.Ui.Error(fmt.Sprintf("Error resolving IP Address: %s", err))
c.Ui.Error(fmt.Sprintf("Error resolving IP Address: %v", err))
return 1
}
@ -120,14 +119,14 @@ func (c *SSHCommand) Run(args []string) int {
if role == "" {
role, err = c.defaultRole(mountPoint, ip.String())
if err != nil {
c.Ui.Error(fmt.Sprintf("Error choosing role: %s", err))
c.Ui.Error(fmt.Sprintf("Error choosing role: %v", err))
return 1
}
// Print the default role chosen so that user knows the role name
// if something doesn't work. If the role chosen is not allowed to
// be used by the user (ACL enforcement), then user should see an
// error message accordingly.
c.Ui.Output(fmt.Sprintf("Vault SSH: Role: %s", role))
c.Ui.Output(fmt.Sprintf("Vault SSH: Role: %q", role))
}
data := map[string]interface{}{
@ -137,7 +136,7 @@ func (c *SSHCommand) Run(args []string) int {
keySecret, err := client.SSHWithMountPoint(mountPoint).Credential(role, data)
if err != nil {
c.Ui.Error(fmt.Sprintf("Error getting key for SSH session:%s", err))
c.Ui.Error(fmt.Sprintf("Error getting key for SSH session: %v", err))
return 1
}
@ -152,7 +151,7 @@ func (c *SSHCommand) Run(args []string) int {
}
var resp SSHCredentialResp
if err := mapstructure.Decode(keySecret.Data, &resp); err != nil {
c.Ui.Error(fmt.Sprintf("Error parsing the credential response:%s", err))
c.Ui.Error(fmt.Sprintf("Error parsing the credential response: %v", err))
return 1
}
@ -161,9 +160,21 @@ func (c *SSHCommand) Run(args []string) int {
c.Ui.Error(fmt.Sprintf("Invalid key"))
return 1
}
sshDynamicKeyFileName = fmt.Sprintf("vault_ssh_%s_%s", username, ip.String())
err = ioutil.WriteFile(sshDynamicKeyFileName, []byte(resp.Key), 0600)
sshCmdArgs = append(sshCmdArgs, []string{"-i", sshDynamicKeyFileName}...)
sshDynamicKeyFile, err := ioutil.TempFile("", fmt.Sprintf("vault_ssh_%s_%s_", username, ip.String()))
if err != nil {
c.Ui.Error(fmt.Sprintf("Error creating temporary file: %v", err))
return 1
}
// Ensure that we delete the temporary file
defer os.Remove(sshDynamicKeyFile.Name())
if err = ioutil.WriteFile(sshDynamicKeyFile.Name(),
[]byte(resp.Key), 0600); err != nil {
c.Ui.Error(fmt.Sprintf("Error storing the dynamic key into the temporary file: %v", err))
return 1
}
sshCmdArgs = append(sshCmdArgs, []string{"-i", sshDynamicKeyFile.Name()}...)
} else if resp.KeyType == ssh.KeyTypeOTP {
// Check if the application 'sshpass' is installed in the client machine.
@ -182,7 +193,7 @@ func (c *SSHCommand) Run(args []string) int {
sshCmd.Stdout = os.Stdout
err = sshCmd.Run()
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to establish SSH connection:%s", err))
c.Ui.Error(fmt.Sprintf("Failed to establish SSH connection: %q", err))
}
return 0
}
@ -204,15 +215,7 @@ func (c *SSHCommand) Run(args []string) int {
// to establish an independent session like this.
err = sshCmd.Run()
if err != nil {
c.Ui.Error(fmt.Sprintf("Error while running ssh command:%s", err))
}
// Delete the temporary key file generated by the command.
if resp.KeyType == ssh.KeyTypeDynamic {
// Ignoring the error from the below call since it is not a security
// issue if the deletion of file is not successful. User is authorized
// to have this secret.
os.Remove(sshDynamicKeyFileName)
c.Ui.Error(fmt.Sprintf("Error while running ssh command: %q", err))
}
// If the session established was longer than the lease expiry, the secret
@ -222,7 +225,7 @@ func (c *SSHCommand) Run(args []string) int {
// is run, a fresh credential is generated anyways.
err = client.Sys().Revoke(keySecret.LeaseID)
if err != nil {
c.Ui.Error(fmt.Sprintf("Error revoking the key: %s", err))
c.Ui.Error(fmt.Sprintf("Error revoking the key: %q", err))
}
return 0
@ -241,15 +244,15 @@ func (c *SSHCommand) defaultRole(mountPoint, ip string) (string, error) {
}
secret, err := client.Logical().Write(mountPoint+"/lookup", data)
if err != nil {
return "", fmt.Errorf("Error finding roles for IP %s: %s", ip, err)
return "", fmt.Errorf("Error finding roles for IP %q: %q", ip, err)
}
if secret == nil {
return "", fmt.Errorf("Error finding roles for IP %s: %s", ip, err)
return "", fmt.Errorf("Error finding roles for IP %q: %q", ip, err)
}
if secret.Data["roles"] == nil {
return "", fmt.Errorf("No matching roles found for IP %s", ip)
return "", fmt.Errorf("No matching roles found for IP %q", ip)
}
if len(secret.Data["roles"].([]interface{})) == 1 {
@ -260,7 +263,7 @@ func (c *SSHCommand) defaultRole(mountPoint, ip string) (string, error) {
roleNames += item.(string) + ", "
}
roleNames = strings.TrimRight(roleNames, ", ")
return "", fmt.Errorf("Roles:[%s]"+`
return "", fmt.Errorf("Roles:%q. "+`
Multiple roles are registered for this IP.
Select a role using '-role' option.
Note that all roles may not be permitted, based on ACLs.`, roleNames)

View file

@ -49,6 +49,13 @@ func (m *Request) String() string { return proto.CompactTextString(m)
func (*Request) ProtoMessage() {}
func (*Request) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Request) GetMethod() string {
if m != nil {
return m.Method
}
return ""
}
func (m *Request) GetUrl() *URL {
if m != nil {
return m.Url
@ -63,6 +70,34 @@ func (m *Request) GetHeaderEntries() map[string]*HeaderEntry {
return nil
}
func (m *Request) GetBody() []byte {
if m != nil {
return m.Body
}
return nil
}
func (m *Request) GetHost() string {
if m != nil {
return m.Host
}
return ""
}
func (m *Request) GetRemoteAddr() string {
if m != nil {
return m.RemoteAddr
}
return ""
}
func (m *Request) GetPeerCertificates() [][]byte {
if m != nil {
return m.PeerCertificates
}
return nil
}
type URL struct {
Scheme string `protobuf:"bytes,1,opt,name=scheme" json:"scheme,omitempty"`
Opaque string `protobuf:"bytes,2,opt,name=opaque" json:"opaque,omitempty"`
@ -83,6 +118,55 @@ func (m *URL) String() string { return proto.CompactTextString(m) }
func (*URL) ProtoMessage() {}
func (*URL) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *URL) GetScheme() string {
if m != nil {
return m.Scheme
}
return ""
}
func (m *URL) GetOpaque() string {
if m != nil {
return m.Opaque
}
return ""
}
func (m *URL) GetHost() string {
if m != nil {
return m.Host
}
return ""
}
func (m *URL) GetPath() string {
if m != nil {
return m.Path
}
return ""
}
func (m *URL) GetRawPath() string {
if m != nil {
return m.RawPath
}
return ""
}
func (m *URL) GetRawQuery() string {
if m != nil {
return m.RawQuery
}
return ""
}
func (m *URL) GetFragment() string {
if m != nil {
return m.Fragment
}
return ""
}
type HeaderEntry struct {
Values []string `protobuf:"bytes,1,rep,name=values" json:"values,omitempty"`
}
@ -92,6 +176,13 @@ func (m *HeaderEntry) String() string { return proto.CompactTextStrin
func (*HeaderEntry) ProtoMessage() {}
func (*HeaderEntry) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *HeaderEntry) GetValues() []string {
if m != nil {
return m.Values
}
return nil
}
type Response struct {
// Not used right now but reserving in case it turns out that streaming
// makes things more economical on the gRPC side
@ -108,6 +199,20 @@ func (m *Response) String() string { return proto.CompactTextString(m
func (*Response) ProtoMessage() {}
func (*Response) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *Response) GetStatusCode() uint32 {
if m != nil {
return m.StatusCode
}
return 0
}
func (m *Response) GetBody() []byte {
if m != nil {
return m.Body
}
return nil
}
func (m *Response) GetHeaderEntries() map[string]*HeaderEntry {
if m != nil {
return m.HeaderEntries

View file

@ -1,4 +1,4 @@
package transit
package keysutil
import (
"errors"
@ -18,29 +18,29 @@ var (
errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
)
// policyRequest holds values used when requesting a policy. Most values are
// PolicyRequest holds values used when requesting a policy. Most values are
// only used during an upsert.
type policyRequest struct {
type PolicyRequest struct {
// The storage to use
storage logical.Storage
Storage logical.Storage
// The name of the policy
name string
Name string
// The key type
keyType KeyType
KeyType KeyType
// Whether it should be derived
derived bool
Derived bool
// Whether to enable convergent encryption
convergent bool
Convergent bool
// Whether to upsert
upsert bool
Upsert bool
}
type lockManager struct {
type LockManager struct {
// A lock for each named key
locks map[string]*sync.RWMutex
@ -48,27 +48,27 @@ type lockManager struct {
locksMutex sync.RWMutex
// If caching is enabled, the map of name to in-memory policy cache
cache map[string]*policy
cache map[string]*Policy
// Used for global locking, and as the cache map mutex
cacheMutex sync.RWMutex
}
func newLockManager(cacheDisabled bool) *lockManager {
lm := &lockManager{
func NewLockManager(cacheDisabled bool) *LockManager {
lm := &LockManager{
locks: map[string]*sync.RWMutex{},
}
if !cacheDisabled {
lm.cache = map[string]*policy{}
lm.cache = map[string]*Policy{}
}
return lm
}
func (lm *lockManager) CacheActive() bool {
func (lm *LockManager) CacheActive() bool {
return lm.cache != nil
}
func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
lm.locksMutex.RLock()
lock := lm.locks[name]
if lock != nil {
@ -115,7 +115,7 @@ func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
return lock
}
func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
func (lm *LockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
if lockType == exclusive {
lock.Unlock()
} else {
@ -126,10 +126,10 @@ func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
// Get the policy with a read lock. If we get an error saying an exclusive lock
// is needed (for instance, for an upgrade/migration), give up the read lock,
// call again with an exclusive lock, then swap back out for a read lock.
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(policyRequest{
storage: storage,
name: name,
func (lm *LockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(PolicyRequest{
Storage: storage,
Name: name,
}, shared)
if err == nil ||
(err != nil && err != errNeedExclusiveLock) {
@ -137,9 +137,9 @@ func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*p
}
// Try again while asking for an exlusive lock
p, lock, _, err = lm.getPolicyCommon(policyRequest{
storage: storage,
name: name,
p, lock, _, err = lm.getPolicyCommon(PolicyRequest{
Storage: storage,
Name: name,
}, exclusive)
if err != nil || p == nil || lock == nil {
return p, lock, err
@ -147,18 +147,18 @@ func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*p
lock.Unlock()
p, lock, _, err = lm.getPolicyCommon(policyRequest{
storage: storage,
name: name,
p, lock, _, err = lm.getPolicyCommon(PolicyRequest{
Storage: storage,
Name: name,
}, shared)
return p, lock, err
}
// Get the policy with an exclusive lock
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(policyRequest{
storage: storage,
name: name,
func (lm *LockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(PolicyRequest{
Storage: storage,
Name: name,
}, exclusive)
return p, lock, err
}
@ -166,8 +166,8 @@ func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string)
// Get the policy with a read lock; if it returns that an exclusive lock is
// needed, retry. If successful, call one more time to get a read lock and
// return the value.
func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMutex, bool, error) {
req.upsert = true
func (lm *LockManager) GetPolicyUpsert(req PolicyRequest) (*Policy, *sync.RWMutex, bool, error) {
req.Upsert = true
p, lock, _, err := lm.getPolicyCommon(req, shared)
if err == nil ||
@ -182,7 +182,7 @@ func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMute
}
lock.Unlock()
req.upsert = false
req.Upsert = false
// Now get a shared lock for the return, but preserve the value of upserted
p, lock, _, err = lm.getPolicyCommon(req, shared)
@ -191,16 +191,16 @@ func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMute
// When the function returns, a lock will be held on the policy if err == nil.
// It is the caller's responsibility to unlock.
func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*policy, *sync.RWMutex, bool, error) {
lock := lm.policyLock(req.name, lockType)
func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
lock := lm.policyLock(req.Name, lockType)
var p *policy
var p *Policy
var err error
// Check if it's in our cache. If so, return right away.
if lm.CacheActive() {
lm.cacheMutex.RLock()
p = lm.cache[req.name]
p = lm.cache[req.Name]
if p != nil {
lm.cacheMutex.RUnlock()
return p, lock, false, nil
@ -209,7 +209,7 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
}
// Load it from storage
p, err = lm.getStoredPolicy(req.storage, req.name)
p, err = lm.getStoredPolicy(req.Storage, req.Name)
if err != nil {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
@ -218,7 +218,7 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
if p == nil {
// This is the only place we upsert a new policy, so if upsert is not
// specified, or the lock type is wrong, unlock before returning
if !req.upsert {
if !req.Upsert {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, nil
}
@ -228,33 +228,33 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
return nil, nil, false, errNeedExclusiveLock
}
switch req.keyType {
case keyType_AES256_GCM96:
if req.convergent && !req.derived {
switch req.KeyType {
case KeyType_AES256_GCM96:
if req.Convergent && !req.Derived {
return nil, nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
}
case keyType_ECDSA_P256:
if req.derived || req.convergent {
return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %s", keyType_ECDSA_P256)
case KeyType_ECDSA_P256:
if req.Derived || req.Convergent {
return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %s", KeyType_ECDSA_P256)
}
default:
return nil, nil, false, fmt.Errorf("unsupported key type %v", req.keyType)
return nil, nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
}
p = &policy{
Name: req.name,
Type: req.keyType,
Derived: req.derived,
p = &Policy{
Name: req.Name,
Type: req.KeyType,
Derived: req.Derived,
}
if req.derived {
p.KDF = kdf_hkdf_sha256
p.ConvergentEncryption = req.convergent
if req.Derived {
p.KDF = Kdf_hkdf_sha256
p.ConvergentEncryption = req.Convergent
p.ConvergentVersion = 2
}
err = p.rotate(req.storage)
err = p.Rotate(req.Storage)
if err != nil {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
@ -267,12 +267,12 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
defer lm.cacheMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so assume it's good and return that
exp := lm.cache[req.name]
exp := lm.cache[req.Name]
if exp != nil {
return exp, lock, false, nil
}
if err == nil {
lm.cache[req.name] = p
lm.cache[req.Name] = p
}
}
@ -280,13 +280,13 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
return p, lock, true, nil
}
if p.needsUpgrade() {
if p.NeedsUpgrade() {
if lockType == shared {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, errNeedExclusiveLock
}
err = p.upgrade(req.storage)
err = p.Upgrade(req.Storage)
if err != nil {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
@ -300,25 +300,25 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
defer lm.cacheMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so assume it's good and return that
exp := lm.cache[req.name]
exp := lm.cache[req.Name]
if exp != nil {
return exp, lock, false, nil
}
if err == nil {
lm.cache[req.name] = p
lm.cache[req.Name] = p
}
}
return p, lock, false, nil
}
func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error {
func (lm *LockManager) DeletePolicy(storage logical.Storage, name string) error {
lm.cacheMutex.Lock()
lock := lm.policyLock(name, exclusive)
defer lock.Unlock()
defer lm.cacheMutex.Unlock()
var p *policy
var p *Policy
var err error
if lm.CacheActive() {
@ -355,7 +355,7 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error
return nil
}
func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*policy, error) {
func (lm *LockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) {
// Check if the policy already exists
raw, err := storage.Get("policy/" + name)
if err != nil {
@ -366,7 +366,7 @@ func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*p
}
// Decode the policy
policy := &policy{
policy := &Policy{
Keys: keyEntryMap{},
}
err = jsonutil.DecodeJSON(raw.Value, policy)

View file

@ -1,4 +1,4 @@
package transit
package keysutil
import (
"bytes"
@ -33,14 +33,14 @@ import (
// Careful with iota; don't put anything before it in this const block because
// we need the default of zero to be the old-style KDF
const (
kdf_hmac_sha256_counter = iota // built-in helper
kdf_hkdf_sha256 // golang.org/x/crypto/hkdf
Kdf_hmac_sha256_counter = iota // built-in helper
Kdf_hkdf_sha256 // golang.org/x/crypto/hkdf
)
// Or this one...we need the default of zero to be the original AES256-GCM96
const (
keyType_AES256_GCM96 = iota
keyType_ECDSA_P256
KeyType_AES256_GCM96 = iota
KeyType_ECDSA_P256
)
const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"
@ -53,7 +53,7 @@ type KeyType int
func (kt KeyType) EncryptionSupported() bool {
switch kt {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
return true
}
return false
@ -61,7 +61,7 @@ func (kt KeyType) EncryptionSupported() bool {
func (kt KeyType) DecryptionSupported() bool {
switch kt {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
return true
}
return false
@ -69,7 +69,7 @@ func (kt KeyType) DecryptionSupported() bool {
func (kt KeyType) SigningSupported() bool {
switch kt {
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
return true
}
return false
@ -77,7 +77,7 @@ func (kt KeyType) SigningSupported() bool {
func (kt KeyType) DerivationSupported() bool {
switch kt {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
return true
}
return false
@ -85,17 +85,17 @@ func (kt KeyType) DerivationSupported() bool {
func (kt KeyType) String() string {
switch kt {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
return "aes256-gcm96"
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
return "ecdsa-p256"
}
return "[unknown]"
}
// keyEntry stores the key and metadata
type keyEntry struct {
// KeyEntry stores the key and metadata
type KeyEntry struct {
AESKey []byte `json:"key"`
HMACKey []byte `json:"hmac_key"`
CreationTime int64 `json:"creation_time"`
@ -106,11 +106,11 @@ type keyEntry struct {
}
// keyEntryMap is used to allow JSON marshal/unmarshal
type keyEntryMap map[int]keyEntry
type keyEntryMap map[int]KeyEntry
// MarshalJSON implements JSON marshaling
func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
intermediate := map[string]keyEntry{}
intermediate := map[string]KeyEntry{}
for k, v := range kem {
intermediate[strconv.Itoa(k)] = v
}
@ -119,7 +119,7 @@ func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
// MarshalJSON implements JSON unmarshaling
func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
intermediate := map[string]keyEntry{}
intermediate := map[string]KeyEntry{}
if err := jsonutil.DecodeJSON(data, &intermediate); err != nil {
return err
}
@ -135,7 +135,7 @@ func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
}
// Policy is the struct used to store metadata
type policy struct {
type Policy struct {
Name string `json:"name"`
Key []byte `json:"key,omitempty"` //DEPRECATED
Keys keyEntryMap `json:"keys"`
@ -171,10 +171,10 @@ type policy struct {
// ArchivedKeys stores old keys. This is used to keep the key loading time sane
// when there are huge numbers of rotations.
type archivedKeys struct {
Keys []keyEntry `json:"keys"`
Keys []KeyEntry `json:"keys"`
}
func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
func (p *Policy) LoadArchive(storage logical.Storage) (*archivedKeys, error) {
archive := &archivedKeys{}
raw, err := storage.Get("archive/" + p.Name)
@ -182,7 +182,7 @@ func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
return nil, err
}
if raw == nil {
archive.Keys = make([]keyEntry, 0)
archive.Keys = make([]KeyEntry, 0)
return archive, nil
}
@ -193,7 +193,7 @@ func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
return archive, nil
}
func (p *policy) storeArchive(archive *archivedKeys, storage logical.Storage) error {
func (p *Policy) storeArchive(archive *archivedKeys, storage logical.Storage) error {
// Encode the policy
buf, err := json.Marshal(archive)
if err != nil {
@ -215,7 +215,7 @@ func (p *policy) storeArchive(archive *archivedKeys, storage logical.Storage) er
// handleArchiving manages the movement of keys to and from the policy archive.
// This should *ONLY* be called from Persist() since it assumes that the policy
// will be persisted afterwards.
func (p *policy) handleArchiving(storage logical.Storage) error {
func (p *Policy) handleArchiving(storage logical.Storage) error {
// We need to move keys that are no longer accessible to archivedKeys, and keys
// that now need to be accessible back here.
//
@ -241,7 +241,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
p.MinDecryptionVersion, p.LatestVersion)
}
archive, err := p.loadArchive(storage)
archive, err := p.LoadArchive(storage)
if err != nil {
return err
}
@ -263,7 +263,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
// key version
if len(archive.Keys) < p.LatestVersion+1 {
// Increase the size of the archive slice
newKeys := make([]keyEntry, p.LatestVersion+1)
newKeys := make([]KeyEntry, p.LatestVersion+1)
copy(newKeys, archive.Keys)
archive.Keys = newKeys
}
@ -289,7 +289,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
return nil
}
func (p *policy) Persist(storage logical.Storage) error {
func (p *Policy) Persist(storage logical.Storage) error {
err := p.handleArchiving(storage)
if err != nil {
return err
@ -313,11 +313,11 @@ func (p *policy) Persist(storage logical.Storage) error {
return nil
}
func (p *policy) Serialize() ([]byte, error) {
func (p *Policy) Serialize() ([]byte, error) {
return json.Marshal(p)
}
func (p *policy) needsUpgrade() bool {
func (p *Policy) NeedsUpgrade() bool {
// Ensure we've moved from Key -> Keys
if p.Key != nil && len(p.Key) > 0 {
return true
@ -352,11 +352,11 @@ func (p *policy) needsUpgrade() bool {
return false
}
func (p *policy) upgrade(storage logical.Storage) error {
func (p *Policy) Upgrade(storage logical.Storage) error {
persistNeeded := false
// Ensure we've moved from Key -> Keys
if p.Key != nil && len(p.Key) > 0 {
p.migrateKeyToKeysMap()
p.MigrateKeyToKeysMap()
persistNeeded = true
}
@ -409,7 +409,7 @@ func (p *policy) upgrade(storage logical.Storage) error {
// on the policy. If derivation is disabled the raw key is used and no context
// is required, otherwise the KDF mode is used with the context to derive the
// proper key.
func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
if !p.Type.DerivationSupported() {
return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
}
@ -433,11 +433,11 @@ func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
}
switch p.KDF {
case kdf_hmac_sha256_counter:
case Kdf_hmac_sha256_counter:
prf := kdf.HMACSHA256PRF
prfLen := kdf.HMACSHA256PRFLen
return kdf.CounterMode(prf, prfLen, p.Keys[ver].AESKey, context, 256)
case kdf_hkdf_sha256:
case Kdf_hkdf_sha256:
reader := hkdf.New(sha256.New, p.Keys[ver].AESKey, nil, context)
derBytes := bytes.NewBuffer(nil)
derBytes.Grow(32)
@ -458,14 +458,14 @@ func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
}
}
func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
func (p *Policy) Encrypt(context, nonce []byte, value string) (string, error) {
if !p.Type.EncryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
}
// Guard against a potentially invalid key type
switch p.Type {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
}
@ -484,7 +484,7 @@ func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
// Guard against a potentially invalid key type
switch p.Type {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
}
@ -539,7 +539,7 @@ func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
return encoded, nil
}
func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
if !p.Type.DecryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
}
@ -585,7 +585,7 @@ func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
// Guard against a potentially invalid key type
switch p.Type {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
}
@ -626,7 +626,7 @@ func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
return base64.StdEncoding.EncodeToString(plain), nil
}
func (p *policy) HMACKey(version int) ([]byte, error) {
func (p *Policy) HMACKey(version int) ([]byte, error) {
if version < p.MinDecryptionVersion {
return nil, fmt.Errorf("key version disallowed by policy (minimum is %d)", p.MinDecryptionVersion)
}
@ -642,14 +642,14 @@ func (p *policy) HMACKey(version int) ([]byte, error) {
return p.Keys[version].HMACKey, nil
}
func (p *policy) Sign(hashedInput []byte) (string, error) {
func (p *Policy) Sign(hashedInput []byte) (string, error) {
if !p.Type.SigningSupported() {
return "", fmt.Errorf("message signing not supported for key type %v", p.Type)
}
var sig []byte
switch p.Type {
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
keyParams := p.Keys[p.LatestVersion]
key := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
@ -685,7 +685,7 @@ func (p *policy) Sign(hashedInput []byte) (string, error) {
return encoded, nil
}
func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
func (p *Policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
if !p.Type.SigningSupported() {
return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
}
@ -714,7 +714,7 @@ func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
}
switch p.Type {
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
asn1Sig, err := base64.StdEncoding.DecodeString(splitVerSig[1])
if err != nil {
return false, errutil.UserError{Err: "invalid base64 signature value"}
@ -744,7 +744,7 @@ func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
return false, errutil.InternalError{Err: "no valid key type found"}
}
func (p *policy) rotate(storage logical.Storage) error {
func (p *Policy) Rotate(storage logical.Storage) error {
if p.Keys == nil {
// This is an initial key rotation when generating a new policy. We
// don't need to call migrate here because if we've called getPolicy to
@ -753,7 +753,7 @@ func (p *policy) rotate(storage logical.Storage) error {
}
p.LatestVersion += 1
entry := keyEntry{
entry := KeyEntry{
CreationTime: time.Now().Unix(),
}
@ -764,7 +764,7 @@ func (p *policy) rotate(storage logical.Storage) error {
entry.HMACKey = hmacKey
switch p.Type {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
// Generate a 256bit key
newKey, err := uuid.GenerateRandomBytes(32)
if err != nil {
@ -772,7 +772,7 @@ func (p *policy) rotate(storage logical.Storage) error {
}
entry.AESKey = newKey
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
@ -807,9 +807,9 @@ func (p *policy) rotate(storage logical.Storage) error {
return p.Persist(storage)
}
func (p *policy) migrateKeyToKeysMap() {
func (p *Policy) MigrateKeyToKeysMap() {
p.Keys = keyEntryMap{
1: keyEntry{
1: KeyEntry{
AESKey: p.Key,
CreationTime: time.Now().Unix(),
},

View file

@ -1,4 +1,4 @@
package transit
package keysutil
import (
"reflect"
@ -8,24 +8,24 @@ import (
)
var (
keysArchive []keyEntry
keysArchive []KeyEntry
)
func resetKeysArchive() {
keysArchive = []keyEntry{keyEntry{}}
keysArchive = []KeyEntry{KeyEntry{}}
}
func Test_KeyUpgrade(t *testing.T) {
testKeyUpgradeCommon(t, newLockManager(false))
testKeyUpgradeCommon(t, newLockManager(true))
testKeyUpgradeCommon(t, NewLockManager(false))
testKeyUpgradeCommon(t, NewLockManager(true))
}
func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
func testKeyUpgradeCommon(t *testing.T, lm *LockManager) {
storage := &logical.InmemStorage{}
p, lock, upserted, err := lm.GetPolicyUpsert(policyRequest{
storage: storage,
keyType: keyType_AES256_GCM96,
name: "test",
p, lock, upserted, err := lm.GetPolicyUpsert(PolicyRequest{
Storage: storage,
KeyType: KeyType_AES256_GCM96,
Name: "test",
})
if lock != nil {
defer lock.RUnlock()
@ -45,7 +45,7 @@ func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
p.Key = p.Keys[1].AESKey
p.Keys = nil
p.migrateKeyToKeysMap()
p.MigrateKeyToKeysMap()
if p.Key != nil {
t.Fatal("policy.Key is not nil")
}
@ -58,11 +58,11 @@ func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
}
func Test_ArchivingUpgrade(t *testing.T) {
testArchivingUpgradeCommon(t, newLockManager(false))
testArchivingUpgradeCommon(t, newLockManager(true))
testArchivingUpgradeCommon(t, NewLockManager(false))
testArchivingUpgradeCommon(t, NewLockManager(true))
}
func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) {
resetKeysArchive()
// First, we generate a policy and rotate it a number of times. Each time
@ -71,10 +71,10 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
// zero and latest, respectively
storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(policyRequest{
storage: storage,
keyType: keyType_AES256_GCM96,
name: "test",
p, lock, _, err := lm.GetPolicyUpsert(PolicyRequest{
Storage: storage,
KeyType: KeyType_AES256_GCM96,
Name: "test",
})
if err != nil {
t.Fatal(err)
@ -89,7 +89,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
checkKeys(t, p, storage, "initial", 1, 1, 1)
for i := 2; i <= 10; i++ {
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
@ -191,11 +191,11 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
}
func Test_Archiving(t *testing.T) {
testArchivingCommon(t, newLockManager(false))
testArchivingCommon(t, newLockManager(true))
testArchivingCommon(t, NewLockManager(false))
testArchivingCommon(t, NewLockManager(true))
}
func testArchivingCommon(t *testing.T, lm *lockManager) {
func testArchivingCommon(t *testing.T, lm *LockManager) {
resetKeysArchive()
// First, we generate a policy and rotate it a number of times. Each time // we'll ensure that we have the expected number of keys in the archive and
@ -203,10 +203,10 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
// zero and latest, respectively
storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(policyRequest{
storage: storage,
keyType: keyType_AES256_GCM96,
name: "test",
p, lock, _, err := lm.GetPolicyUpsert(PolicyRequest{
Storage: storage,
KeyType: KeyType_AES256_GCM96,
Name: "test",
})
if lock != nil {
defer lock.RUnlock()
@ -223,7 +223,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
checkKeys(t, p, storage, "initial", 1, 1, 1)
for i := 2; i <= 10; i++ {
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
@ -271,7 +271,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
}
func checkKeys(t *testing.T,
p *policy,
p *Policy,
storage logical.Storage,
action string,
archiveVer, latestVer, keysSize int) {
@ -282,7 +282,7 @@ func checkKeys(t *testing.T,
"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
}
archive, err := p.loadArchive(storage)
archive, err := p.LoadArchive(storage)
if err != nil {
t.Fatal(err)
}

View file

@ -21,6 +21,7 @@ import (
"github.com/hashicorp/vault/api"
credCert "github.com/hashicorp/vault/builtin/credential/cert"
"github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
@ -381,7 +382,7 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, rpc, parallel bool, num uin
secret, err := doResp(resp)
if err != nil {
// This could well happen since the min version is jumping around
if strings.Contains(err.Error(), transit.ErrTooOld) {
if strings.Contains(err.Error(), keysutil.ErrTooOld) {
mySuccessfulOps++
continue
}

View file

@ -26,6 +26,11 @@ const (
// NoRequestForwardingHeaderName is the name of the header telling Vault
// not to use request forwarding
NoRequestForwardingHeaderName = "X-Vault-No-Request-Forwarding"
// MaxRequestSize is the maximum accepted request size. This is to prevent
// a denial of service attack where no Content-Length is provided and the server
// is fed ever more data until it exhausts memory.
MaxRequestSize = 32 * 1024 * 1024
)
// Handler returns an http.Handler for the API. This can be used on
@ -109,7 +114,10 @@ func stripPrefix(prefix, path string) (string, bool) {
}
func parseRequest(r *http.Request, out interface{}) error {
err := jsonutil.DecodeJSONFromReader(r.Body, out)
// Limit the maximum number of bytes to MaxRequestSize to protect
// against an indefinite amount of data being read.
limit := &io.LimitedReader{R: r.Body, N: MaxRequestSize}
err := jsonutil.DecodeJSONFromReader(limit, out)
if err != nil && err != io.EOF {
return fmt.Errorf("Failed to parse JSON input: %s", err)
}
@ -245,10 +253,19 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) {
}
// requestAuth adds the token to the logical.Request if it exists.
func requestAuth(r *http.Request, req *logical.Request) *logical.Request {
func requestAuth(core *vault.Core, r *http.Request, req *logical.Request) *logical.Request {
// Attach the header value if we have it
if v := r.Header.Get(AuthHeaderName); v != "" {
req.ClientToken = v
// Also attach the accessor if we have it. This doesn't fail if it
// doesn't exist because the request may be to an unauthenticated
// endpoint/login endpoint where a bad current token doesn't matter, or
// a token from a Vault version pre-accessors.
te, err := core.LookupToken(v)
if err == nil && te != nil {
req.ClientTokenAccessor = te.Accessor
}
}
return req

View file

@ -27,11 +27,13 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, req *http.Request) {
return
}
resp, err := core.HandleRequest(requestAuth(req, &logical.Request{
lreq := requestAuth(core, req, &logical.Request{
Operation: logical.HelpOperation,
Path: path,
Connection: getConnection(req),
}))
})
resp, err := core.HandleRequest(lreq)
if err != nil {
respondError(w, http.StatusInternalServerError, err)
return

View file

@ -16,7 +16,7 @@ import (
type PrepareRequestFunc func(*vault.Core, *logical.Request) error
func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) {
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) {
// Determine the path...
if !strings.HasPrefix(r.URL.Path, "/v1/") {
return nil, http.StatusNotFound, nil
@ -26,6 +26,11 @@ func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Reque
return nil, http.StatusNotFound, nil
}
// Verify the content length does not exceed the maximum size
if r.ContentLength >= MaxRequestSize {
return nil, http.StatusRequestEntityTooLarge, nil
}
// Determine the operation
var op logical.Operation
switch r.Method {
@ -72,13 +77,14 @@ func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Reque
return nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err)
}
req := requestAuth(r, &logical.Request{
req := requestAuth(core, r, &logical.Request{
ID: request_id,
Operation: op,
Path: path,
Data: data,
Connection: getConnection(r),
})
req, err = requestWrapTTL(r, req)
if err != nil {
return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
@ -89,7 +95,7 @@ func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Reque
func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback PrepareRequestFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(w, r)
req, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return

View file

@ -231,3 +231,16 @@ func TestLogical_RawHTTP(t *testing.T) {
t.Fatalf("Bad: %s", body.Bytes())
}
}
func TestLogical_RequestSizeLimit(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := TestServer(t, core)
defer ln.Close()
TestServerAuth(t, addr, token)
// Write a very large object, should fail
resp := testHttpPut(t, token, addr+"/v1/secret/foo", map[string]interface{}{
"data": make([]byte, MaxRequestSize),
})
testResponseStatus(t, resp, 413)
}

View file

@ -15,7 +15,7 @@ import (
func handleSysSeal(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(w, r)
req, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return
@ -30,8 +30,13 @@ func handleSysSeal(core *vault.Core) http.Handler {
// Seal with the token above
if err := core.SealWithRequest(req); err != nil {
respondError(w, http.StatusInternalServerError, err)
return
if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
respondError(w, http.StatusForbidden, err)
return
} else {
respondError(w, http.StatusInternalServerError, err)
return
}
}
respondOk(w, nil)
@ -40,7 +45,7 @@ func handleSysSeal(core *vault.Core) http.Handler {
func handleSysStepDown(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(w, r)
req, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return

View file

@ -285,7 +285,7 @@ func TestSysSeal_Permissions(t *testing.T) {
// We expect this to fail since it needs update and sudo
httpResp := testHttpPut(t, "child", addr+"/v1/sys/seal", nil)
testResponseStatus(t, httpResp, 500)
testResponseStatus(t, httpResp, 403)
// Now modify to add update capability
req = &logical.Request{
@ -306,7 +306,7 @@ func TestSysSeal_Permissions(t *testing.T) {
// We expect this to fail since it needs sudo
httpResp = testHttpPut(t, "child", addr+"/v1/sys/seal", nil)
testResponseStatus(t, httpResp, 500)
testResponseStatus(t, httpResp, 403)
// Now modify to just sudo capability
req = &logical.Request{
@ -327,7 +327,7 @@ func TestSysSeal_Permissions(t *testing.T) {
// We expect this to fail since it needs update
httpResp = testHttpPut(t, "child", addr+"/v1/sys/seal", nil)
testResponseStatus(t, httpResp, 500)
testResponseStatus(t, httpResp, 403)
// Now modify to add all needed capabilities
req = &logical.Request{

View file

@ -47,6 +47,10 @@ type Request struct {
// hashed.
ClientToken string `json:"client_token" structs:"client_token" mapstructure:"client_token"`
// ClientTokenAccessor is provided to the core so that the it can get
// logged as part of request audit logging.
ClientTokenAccessor string `json:"client_token_accessor" structs:"client_token_accessor" mapstructure:"client_token_accessor"`
// DisplayName is provided to the logical backend to help associate
// dynamic secrets with the source entity. This is not a sensitive
// name, but is useful for operators.

View file

@ -9,7 +9,7 @@ import (
const (
// DefaultCacheSize is used if no cache size is specified for NewCache
DefaultCacheSize = 1024 * 1024
DefaultCacheSize = 32 * 1024
)
// Cache is used to wrap an underlying physical backend
@ -45,7 +45,9 @@ func (c *Cache) Purge() {
func (c *Cache) Put(entry *Entry) error {
err := c.backend.Put(entry)
c.lru.Add(entry.Key, entry)
if err == nil {
c.lru.Add(entry.Key, entry)
}
return err
}
@ -78,7 +80,9 @@ func (c *Cache) Get(key string) (*Entry, error) {
func (c *Cache) Delete(key string) error {
err := c.backend.Delete(key)
c.lru.Remove(key)
if err == nil {
c.lru.Remove(key)
}
return err
}

View file

@ -173,6 +173,9 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) {
// Add the % wildcard to the prefix to do the prefix search
likePrefix := prefix + "%"
rows, err := m.statements["list"].Query(likePrefix)
if err != nil {
return nil, fmt.Errorf("failed to execute statement: %v", err)
}
var keys []string
for rows.Next() {

View file

@ -10,7 +10,7 @@ RUN apt-get update -y && apt-get install --no-install-recommends -y -q \
git mercurial bzr \
&& rm -rf /var/lib/apt/lists/*
ENV GOVERSION 1.7.1
ENV GOVERSION 1.7.3
RUN mkdir /goroot && mkdir /gopath
RUN curl https://storage.googleapis.com/golang/go${GOVERSION}.linux-amd64.tar.gz \
| tar xvzf - -C /goroot --strip-components=1

View file

@ -29,13 +29,11 @@ func makePolynomial(intercept, degree uint8) (polynomial, error) {
// Ensure the intercept is set
p.coefficients[0] = intercept
// Assign random co-efficients to the polynomial, ensuring
// the highest order co-efficient is non-zero
for p.coefficients[degree] == 0 {
if _, err := rand.Read(p.coefficients[1:]); err != nil {
return p, err
}
// Assign random co-efficients to the polynomial
if _, err := rand.Read(p.coefficients[1:]); err != nil {
return p, err
}
return p, nil
}

View file

@ -256,8 +256,10 @@ func (c *Core) teardownAudits() error {
c.auditLock.Lock()
defer c.auditLock.Unlock()
for _, entry := range c.audit.Entries {
c.removeAuditReloadFunc(entry)
if c.audit != nil {
for _, entry := range c.audit.Entries {
c.removeAuditReloadFunc(entry)
}
}
c.audit = nil

View file

@ -350,6 +350,17 @@ func (c *Core) teardownCredentials() error {
c.authLock.Lock()
defer c.authLock.Unlock()
if c.auth != nil {
authTable := c.auth.shallowClone()
for _, e := range authTable.Entries {
prefix := e.Path
b, ok := c.router.root.Get(prefix)
if ok {
b.(*routeEntry).backend.Cleanup()
}
}
}
c.auth = nil
c.tokenStore = nil
return nil

View file

@ -316,6 +316,10 @@ func (c *Core) stopClusterListener() {
return
}
if !c.clusterListenersRunning {
c.logger.Info("core/stopClusterListener: listeners not running")
return
}
c.logger.Info("core/stopClusterListener: stopping listeners")
// Tell the goroutine managing the listeners to perform the shutdown
@ -327,6 +331,8 @@ func (c *Core) stopClusterListener() {
// bind errors. This ensures proper ordering.
c.logger.Trace("core/stopClusterListener: waiting for success notification")
<-c.clusterListenerShutdownSuccessCh
c.clusterListenersRunning = false
c.logger.Info("core/stopClusterListener: success")
}
@ -417,21 +423,3 @@ func WrapHandlerForClustering(handler http.Handler, logger log.Logger) func() (h
return handler, mux
}
}
// WrapListenersForClustering takes in Vault's cluster addresses and returns a
// setup function that creates the new listeners
func WrapListenersForClustering(addrs []string, logger log.Logger) func() ([]net.Listener, error) {
return func() ([]net.Listener, error) {
ret := make([]net.Listener, 0, len(addrs))
// Loop over the existing listeners and start listeners on appropriate ports
for _, addr := range addrs {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
ret = append(ret, ln)
}
return ret, nil
}
}

View file

@ -16,6 +16,10 @@ import (
log "github.com/mgutz/logxi/v1"
)
var (
clusterTestPausePeriod = 2 * time.Second
)
func TestClusterFetching(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
@ -109,16 +113,16 @@ func TestCluster_ListenForRequests(t *testing.T) {
t.Fatal("%s not a TCP port", tcpAddr.String())
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+1), tlsConfig)
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+10), tlsConfig)
if err != nil {
if expectFail {
t.Logf("testing %s:%d unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+1)
t.Logf("testing %s:%d unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+10)
continue
}
t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[1])
}
if expectFail {
t.Fatalf("testing %s:%d not unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+1)
t.Fatalf("testing %s:%d not unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+10)
}
err = conn.Handshake()
if err != nil {
@ -131,11 +135,11 @@ func TestCluster_ListenForRequests(t *testing.T) {
case connState.NegotiatedProtocol != "h2" || !connState.NegotiatedProtocolIsMutual:
t.Fatal("bad protocol negotiation")
}
t.Logf("testing %s:%d successful", tcpAddr.IP.String(), tcpAddr.Port+1)
t.Logf("testing %s:%d successful", tcpAddr.IP.String(), tcpAddr.Port+10)
}
}
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
checkListenersFunc(false)
err := cores[0].StepDown(&logical.Request{
@ -149,7 +153,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
// StepDown doesn't wait during actual preSeal so give time for listeners
// to close
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
checkListenersFunc(true)
// After this period it should be active again
@ -160,7 +164,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
// After sealing it should be inactive again
checkListenersFunc(true)
}
@ -230,13 +234,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
_ = cores[2].StepDown(&logical.Request{
Operation: logical.UpdateOperation,
Path: "sys/step-down",
ClientToken: root,
})
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
TestWaitActive(t, cores[1].Core)
testCluster_ForwardRequests(t, cores[0], "core2")
testCluster_ForwardRequests(t, cores[2], "core2")
@ -250,13 +254,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
_ = cores[0].StepDown(&logical.Request{
Operation: logical.UpdateOperation,
Path: "sys/step-down",
ClientToken: root,
})
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
TestWaitActive(t, cores[2].Core)
testCluster_ForwardRequests(t, cores[0], "core3")
testCluster_ForwardRequests(t, cores[1], "core3")
@ -270,13 +274,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
_ = cores[1].StepDown(&logical.Request{
Operation: logical.UpdateOperation,
Path: "sys/step-down",
ClientToken: root,
})
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
TestWaitActive(t, cores[0].Core)
testCluster_ForwardRequests(t, cores[1], "core1")
testCluster_ForwardRequests(t, cores[2], "core1")
@ -290,13 +294,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
_ = cores[2].StepDown(&logical.Request{
Operation: logical.UpdateOperation,
Path: "sys/step-down",
ClientToken: root,
})
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
TestWaitActive(t, cores[1].Core)
testCluster_ForwardRequests(t, cores[0], "core2")
testCluster_ForwardRequests(t, cores[2], "core2")
@ -310,13 +314,13 @@ func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
_ = cores[0].StepDown(&logical.Request{
Operation: logical.UpdateOperation,
Path: "sys/step-down",
ClientToken: root,
})
time.Sleep(2 * time.Second)
time.Sleep(clusterTestPausePeriod)
TestWaitActive(t, cores[2].Core)
testCluster_ForwardRequests(t, cores[0], "core3")
testCluster_ForwardRequests(t, cores[1], "core3")

View file

@ -13,12 +13,12 @@ import (
"sync"
"time"
"github.com/armon/go-metrics"
log "github.com/mgutz/logxi/v1"
"golang.org/x/net/context"
"google.golang.org/grpc"
"github.com/armon/go-metrics"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-uuid"
@ -269,6 +269,9 @@ type Core struct {
clusterListenerAddrs []*net.TCPAddr
// The setup function that gives us the handler to use
clusterHandlerSetupFunc func() (http.Handler, http.Handler)
// Tracks whether cluster listeners are running, e.g. it's safe to send a
// shutdown down the channel
clusterListenersRunning bool
// Shutdown channel for the cluster listeners
clusterListenerShutdownCh chan struct{}
// Shutdown success channel. We need this to be done serially to ensure
@ -492,6 +495,23 @@ func (c *Core) Shutdown() error {
return c.sealInternal()
}
// LookupToken returns the properties of the token from the token store. This
// is particularly useful to fetch the accessor of the client token and get it
// populated in the logical request along with the client token. The accessor
// of the client token can get audit logged.
func (c *Core) LookupToken(token string) (*TokenEntry, error) {
if token == "" {
return nil, fmt.Errorf("missing client token")
}
// Many tests don't have a token store running
if c.tokenStore == nil {
return nil, nil
}
return c.tokenStore.Lookup(token)
}
func (c *Core) fetchACLandTokenEntry(req *logical.Request) (*ACL, *TokenEntry, error) {
defer metrics.MeasureSince([]string{"core", "fetch_acl_and_token"}, time.Now())

View file

@ -1679,11 +1679,13 @@ func (b *SystemBackend) handleWrappingRewrap(
// Return response in "response"; wrapping code will detect the rewrap and
// slot in instead of nesting
req.WrapTTL = time.Duration(creationTTL)
return &logical.Response{
Data: map[string]interface{}{
"response": response,
},
WrapInfo: &logical.WrapInfo{
TTL: time.Duration(creationTTL),
},
}, nil
}

View file

@ -68,7 +68,9 @@ func (c *Core) startForwarding() error {
go func() {
defer shutdownWg.Done()
c.logger.Info("core/startClusterListener: starting listener")
if c.logger.IsInfo() {
c.logger.Info("core/startClusterListener: starting listener", "listener_address", laddr)
}
// Create a TCP listener. We do this separately and specifically
// with TCP so that we can set deadlines.
@ -143,6 +145,10 @@ func (c *Core) startForwarding() error {
// This is in its own goroutine so that we don't block the main thread, and
// thus we use atomic and channels to coordinate
// However, because you can't query the status of a channel, we set a bool
// here while we have the state lock to know whether to actually send a
// shutdown (e.g. whether the channel will block). See issue #2083.
c.clusterListenersRunning = true
go func() {
// If we get told to shut down...
<-c.clusterListenerShutdownCh

View file

@ -39,7 +39,7 @@ var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion3
const _ = grpc.SupportPackageIsVersion4
// Client API for RequestForwarding service
@ -102,7 +102,7 @@ var _RequestForwarding_serviceDesc = grpc.ServiceDesc{
},
},
Streams: []grpc.StreamDesc{},
Metadata: fileDescriptor0,
Metadata: "request_forwarding_service.proto",
}
func init() { proto.RegisterFile("request_forwarding_service.proto", fileDescriptor0) }

View file

@ -186,12 +186,29 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
// Route the request
resp, err := c.router.Route(req)
if resp != nil {
// We don't allow backends to specify this, so ensure it's not set
resp.WrapInfo = nil
// If wrapping is used, use the shortest between the request and response
var wrapTTL time.Duration
if req.WrapTTL != 0 {
// Ensure no wrap info information is set other than, possibly, the TTL
if resp.WrapInfo != nil {
if resp.WrapInfo.TTL > 0 {
wrapTTL = resp.WrapInfo.TTL
}
resp.WrapInfo = nil
}
if req.WrapTTL > 0 {
switch {
case wrapTTL == 0:
wrapTTL = req.WrapTTL
case req.WrapTTL < wrapTTL:
wrapTTL = req.WrapTTL
}
}
if wrapTTL > 0 {
resp.WrapInfo = &logical.WrapInfo{
TTL: req.WrapTTL,
TTL: wrapTTL,
}
}
}
@ -306,14 +323,32 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
// Route the request
resp, err := c.router.Route(req)
if resp != nil {
// We don't allow backends to specify this, so ensure it's not set
resp.WrapInfo = nil
// If wrapping is used, use the shortest between the request and response
var wrapTTL time.Duration
if req.WrapTTL != 0 {
resp.WrapInfo = &logical.WrapInfo{
TTL: req.WrapTTL,
// Ensure no wrap info information is set other than, possibly, the TTL
if resp.WrapInfo != nil {
if resp.WrapInfo.TTL > 0 {
wrapTTL = resp.WrapInfo.TTL
}
resp.WrapInfo = nil
}
if req.WrapTTL > 0 {
switch {
case wrapTTL == 0:
wrapTTL = req.WrapTTL
case req.WrapTTL < wrapTTL:
wrapTTL = req.WrapTTL
}
}
if wrapTTL > 0 {
resp.WrapInfo = &logical.WrapInfo{
TTL: wrapTTL,
}
}
}
// A login request should never return a secret!

View file

@ -263,12 +263,7 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
req.ID = originalReqID
req.Storage = nil
req.ClientToken = clientToken
// Only the rewrap endpoint is allowed to declare a wrap TTL on a
// request that did not come from the client
if req.Path != "sys/wrapping/rewrap" {
req.WrapTTL = originalWrapTTL
}
req.WrapTTL = originalWrapTTL
}()
// Invoke the backend

View file

@ -584,6 +584,11 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
}
// Create three cores with the same physical and different redirect/cluster addrs
// N.B.: On OSX, instead of random ports, it assigns new ports to new
// listeners sequentially. Aside from being a bad idea in a security sense,
// it also broke tests that assumed it was OK to just use the port above
// the redirect addr. This has now been changed to 10 ports above, but if
// we ever do more than three nodes in a cluster it may need to be bumped.
coreConfig := &CoreConfig{
Physical: physical.NewInmem(logger),
HAPhysical: physical.NewInmemHA(logger),
@ -591,7 +596,7 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
CredentialBackends: make(map[string]logical.Factory),
AuditBackends: make(map[string]audit.Factory),
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port),
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+1),
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+10),
DisableMlock: true,
}
@ -629,7 +634,7 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port)
if coreConfig.ClusterAddr != "" {
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+1)
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+10)
}
c2, err := NewCore(coreConfig)
if err != nil {
@ -638,7 +643,7 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port)
if coreConfig.ClusterAddr != "" {
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+1)
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+10)
}
c3, err := NewCore(coreConfig)
if err != nil {
@ -653,7 +658,7 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
for i, ln := range lns {
ret[i] = &net.TCPAddr{
IP: ln.Address.IP,
Port: ln.Address.Port + 1,
Port: ln.Address.Port + 10,
}
}
return ret

View file

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
@ -256,6 +257,23 @@ const (
blobCopyStatusFailed = "failed"
)
// lease constants.
const (
leaseHeaderPrefix = "x-ms-lease-"
leaseID = "x-ms-lease-id"
leaseAction = "x-ms-lease-action"
leaseBreakPeriod = "x-ms-lease-break-period"
leaseDuration = "x-ms-lease-duration"
leaseProposedID = "x-ms-proposed-lease-id"
leaseTime = "x-ms-lease-time"
acquireLease = "acquire"
renewLease = "renew"
changeLease = "change"
releaseLease = "release"
breakLease = "break"
)
// BlockListType is used to filter out types of blocks in a Get Blocks List call
// for a block blob.
//
@ -284,6 +302,65 @@ const (
ContainerAccessTypeContainer ContainerAccessType = "container"
)
// ContainerAccessOptions are used when setting ACLs of containers (after creation)
type ContainerAccessOptions struct {
ContainerAccess ContainerAccessType
Timeout int
LeaseID string
}
// AccessPolicyDetails are used for SETTING policies
type AccessPolicyDetails struct {
ID string
StartTime time.Time
ExpiryTime time.Time
CanRead bool
CanWrite bool
CanDelete bool
}
// ContainerPermissions is used when setting permissions and Access Policies for containers.
type ContainerPermissions struct {
AccessOptions ContainerAccessOptions
AccessPolicy AccessPolicyDetails
}
// AccessPolicyDetailsXML has specifics about an access policy
// annotated with XML details.
type AccessPolicyDetailsXML struct {
StartTime time.Time `xml:"Start"`
ExpiryTime time.Time `xml:"Expiry"`
Permission string `xml:"Permission"`
}
// SignedIdentifier is a wrapper for a specific policy
type SignedIdentifier struct {
ID string `xml:"Id"`
AccessPolicy AccessPolicyDetailsXML `xml:"AccessPolicy"`
}
// SignedIdentifiers part of the response from GetPermissions call.
type SignedIdentifiers struct {
SignedIdentifiers []SignedIdentifier `xml:"SignedIdentifier"`
}
// AccessPolicy is the response type from the GetPermissions call.
type AccessPolicy struct {
SignedIdentifiersList SignedIdentifiers `xml:"SignedIdentifiers"`
}
// ContainerAccessResponse is returned for the GetContainerPermissions function.
// This contains both the permission and access policy for the container.
type ContainerAccessResponse struct {
ContainerAccess ContainerAccessType
AccessPolicy SignedIdentifiers
}
// ContainerAccessHeader references header used when setting/getting container ACL
const (
ContainerAccessHeader string = "x-ms-blob-public-access"
)
// Maximum sizes (per REST API) for various concepts
const (
MaxBlobBlockSize = 4 * 1024 * 1024
@ -399,7 +476,7 @@ func (b BlobStorageClient) createContainer(name string, access ContainerAccessTy
headers := b.client.getStandardHeaders()
if access != "" {
headers["x-ms-blob-public-access"] = string(access)
headers[ContainerAccessHeader] = string(access)
}
return b.client.exec(verb, uri, headers, nil)
}
@ -421,6 +498,101 @@ func (b BlobStorageClient) ContainerExists(name string) (bool, error) {
return false, err
}
// SetContainerPermissions sets up container permissions as per https://msdn.microsoft.com/en-us/library/azure/dd179391.aspx
func (b BlobStorageClient) SetContainerPermissions(container string, containerPermissions ContainerPermissions) (err error) {
params := url.Values{
"restype": {"container"},
"comp": {"acl"},
}
if containerPermissions.AccessOptions.Timeout > 0 {
params.Add("timeout", strconv.Itoa(containerPermissions.AccessOptions.Timeout))
}
uri := b.client.getEndpoint(blobServiceName, pathForContainer(container), params)
headers := b.client.getStandardHeaders()
if containerPermissions.AccessOptions.ContainerAccess != "" {
headers[ContainerAccessHeader] = string(containerPermissions.AccessOptions.ContainerAccess)
}
if containerPermissions.AccessOptions.LeaseID != "" {
headers[leaseID] = containerPermissions.AccessOptions.LeaseID
}
// generate the XML for the SharedAccessSignature if required.
accessPolicyXML, err := generateAccessPolicy(containerPermissions.AccessPolicy)
if err != nil {
return err
}
var resp *storageResponse
if accessPolicyXML != "" {
headers["Content-Length"] = strconv.Itoa(len(accessPolicyXML))
resp, err = b.client.exec("PUT", uri, headers, strings.NewReader(accessPolicyXML))
} else {
resp, err = b.client.exec("PUT", uri, headers, nil)
}
if err != nil {
return err
}
if resp != nil {
defer func() {
err = resp.body.Close()
}()
if resp.statusCode != http.StatusOK {
return errors.New("Unable to set permissions")
}
}
return nil
}
// GetContainerPermissions gets the container permissions as per https://msdn.microsoft.com/en-us/library/azure/dd179469.aspx
// If timeout is 0 then it will not be passed to Azure
// leaseID will only be passed to Azure if populated
// Returns permissionResponse which is combined permissions and AccessPolicy
func (b BlobStorageClient) GetContainerPermissions(container string, timeout int, leaseID string) (permissionResponse *ContainerAccessResponse, err error) {
params := url.Values{"restype": {"container"},
"comp": {"acl"}}
if timeout > 0 {
params.Add("timeout", strconv.Itoa(timeout))
}
uri := b.client.getEndpoint(blobServiceName, pathForContainer(container), params)
headers := b.client.getStandardHeaders()
if leaseID != "" {
headers[leaseID] = leaseID
}
resp, err := b.client.exec("GET", uri, headers, nil)
if err != nil {
return nil, err
}
// containerAccess. Blob, Container, empty
containerAccess := resp.headers.Get(http.CanonicalHeaderKey(ContainerAccessHeader))
defer func() {
err = resp.body.Close()
}()
var out AccessPolicy
err = xmlUnmarshal(resp.body, &out.SignedIdentifiersList)
if err != nil {
return nil, err
}
permissionResponse = &ContainerAccessResponse{}
permissionResponse.AccessPolicy = out.SignedIdentifiersList
permissionResponse.ContainerAccess = ContainerAccessType(containerAccess)
return permissionResponse, nil
}
// DeleteContainer deletes the container with given name on the storage
// account. If the container does not exist returns error.
//
@ -560,6 +732,132 @@ func (b BlobStorageClient) getBlobRange(container, name, bytesRange string, extr
return resp, err
}
// leasePut is common PUT code for the various aquire/release/break etc functions.
func (b BlobStorageClient) leaseCommonPut(container string, name string, headers map[string]string, expectedStatus int) (http.Header, error) {
params := url.Values{"comp": {"lease"}}
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), params)
resp, err := b.client.exec("PUT", uri, headers, nil)
if err != nil {
return nil, err
}
defer resp.body.Close()
if err := checkRespCode(resp.statusCode, []int{expectedStatus}); err != nil {
return nil, err
}
return resp.headers, nil
}
// AcquireLease creates a lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
// returns leaseID acquired
func (b BlobStorageClient) AcquireLease(container string, name string, leaseTimeInSeconds int, proposedLeaseID string) (returnedLeaseID string, err error) {
headers := b.client.getStandardHeaders()
headers[leaseAction] = acquireLease
headers[leaseProposedID] = proposedLeaseID
headers[leaseDuration] = strconv.Itoa(leaseTimeInSeconds)
respHeaders, err := b.leaseCommonPut(container, name, headers, http.StatusCreated)
if err != nil {
return "", err
}
returnedLeaseID = respHeaders.Get(http.CanonicalHeaderKey(leaseID))
if returnedLeaseID != "" {
return returnedLeaseID, nil
}
return "", errors.New("LeaseID not returned")
}
// BreakLease breaks the lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
// Returns the timeout remaining in the lease in seconds
func (b BlobStorageClient) BreakLease(container string, name string) (breakTimeout int, err error) {
headers := b.client.getStandardHeaders()
headers[leaseAction] = breakLease
return b.breakLeaseCommon(container, name, headers)
}
// BreakLeaseWithBreakPeriod breaks the lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
// breakPeriodInSeconds is used to determine how long until new lease can be created.
// Returns the timeout remaining in the lease in seconds
func (b BlobStorageClient) BreakLeaseWithBreakPeriod(container string, name string, breakPeriodInSeconds int) (breakTimeout int, err error) {
headers := b.client.getStandardHeaders()
headers[leaseAction] = breakLease
headers[leaseBreakPeriod] = strconv.Itoa(breakPeriodInSeconds)
return b.breakLeaseCommon(container, name, headers)
}
// breakLeaseCommon is common code for both version of BreakLease (with and without break period)
func (b BlobStorageClient) breakLeaseCommon(container string, name string, headers map[string]string) (breakTimeout int, err error) {
respHeaders, err := b.leaseCommonPut(container, name, headers, http.StatusAccepted)
if err != nil {
return 0, err
}
breakTimeoutStr := respHeaders.Get(http.CanonicalHeaderKey(leaseTime))
if breakTimeoutStr != "" {
breakTimeout, err = strconv.Atoi(breakTimeoutStr)
if err != nil {
return 0, err
}
}
return breakTimeout, nil
}
// ChangeLease changes a lease ID for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
// Returns the new LeaseID acquired
func (b BlobStorageClient) ChangeLease(container string, name string, currentLeaseID string, proposedLeaseID string) (newLeaseID string, err error) {
headers := b.client.getStandardHeaders()
headers[leaseAction] = changeLease
headers[leaseID] = currentLeaseID
headers[leaseProposedID] = proposedLeaseID
respHeaders, err := b.leaseCommonPut(container, name, headers, http.StatusOK)
if err != nil {
return "", err
}
newLeaseID = respHeaders.Get(http.CanonicalHeaderKey(leaseID))
if newLeaseID != "" {
return newLeaseID, nil
}
return "", errors.New("LeaseID not returned")
}
// ReleaseLease releases the lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
func (b BlobStorageClient) ReleaseLease(container string, name string, currentLeaseID string) error {
headers := b.client.getStandardHeaders()
headers[leaseAction] = releaseLease
headers[leaseID] = currentLeaseID
_, err := b.leaseCommonPut(container, name, headers, http.StatusOK)
if err != nil {
return err
}
return nil
}
// RenewLease renews the lease for a blob as per https://msdn.microsoft.com/en-us/library/azure/ee691972.aspx
func (b BlobStorageClient) RenewLease(container string, name string, currentLeaseID string) error {
headers := b.client.getStandardHeaders()
headers[leaseAction] = renewLease
headers[leaseID] = currentLeaseID
_, err := b.leaseCommonPut(container, name, headers, http.StatusOK)
if err != nil {
return err
}
return nil
}
// GetBlobProperties provides various information about the specified
// blob. See https://msdn.microsoft.com/en-us/library/azure/dd179394.aspx
func (b BlobStorageClient) GetBlobProperties(container, name string) (*BlobProperties, error) {
@ -961,15 +1259,20 @@ func (b BlobStorageClient) AppendBlock(container, name string, chunk []byte, ext
//
// See https://msdn.microsoft.com/en-us/library/azure/dd894037.aspx
func (b BlobStorageClient) CopyBlob(container, name, sourceBlob string) error {
copyID, err := b.startBlobCopy(container, name, sourceBlob)
copyID, err := b.StartBlobCopy(container, name, sourceBlob)
if err != nil {
return err
}
return b.waitForBlobCopy(container, name, copyID)
return b.WaitForBlobCopy(container, name, copyID)
}
func (b BlobStorageClient) startBlobCopy(container, name, sourceBlob string) (string, error) {
// StartBlobCopy starts a blob copy operation.
// sourceBlob parameter must be a canonical URL to the blob (can be
// obtained using GetBlobURL method.)
//
// See https://msdn.microsoft.com/en-us/library/azure/dd894037.aspx
func (b BlobStorageClient) StartBlobCopy(container, name, sourceBlob string) (string, error) {
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{})
headers := b.client.getStandardHeaders()
@ -992,7 +1295,39 @@ func (b BlobStorageClient) startBlobCopy(container, name, sourceBlob string) (st
return copyID, nil
}
func (b BlobStorageClient) waitForBlobCopy(container, name, copyID string) error {
// AbortBlobCopy aborts a BlobCopy which has already been triggered by the StartBlobCopy function.
// copyID is generated from StartBlobCopy function.
// currentLeaseID is required IF the destination blob has an active lease on it.
// As defined in https://msdn.microsoft.com/en-us/library/azure/jj159098.aspx
func (b BlobStorageClient) AbortBlobCopy(container, name, copyID, currentLeaseID string, timeout int) error {
params := url.Values{"comp": {"copy"}, "copyid": {copyID}}
if timeout > 0 {
params.Add("timeout", strconv.Itoa(timeout))
}
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), params)
headers := b.client.getStandardHeaders()
headers["x-ms-copy-action"] = "abort"
if currentLeaseID != "" {
headers[leaseID] = currentLeaseID
}
resp, err := b.client.exec("PUT", uri, headers, nil)
if err != nil {
return err
}
defer resp.body.Close()
if err := checkRespCode(resp.statusCode, []int{http.StatusNoContent}); err != nil {
return err
}
return nil
}
// WaitForBlobCopy loops until a BlobCopy operation is completed (or fails with error)
func (b BlobStorageClient) WaitForBlobCopy(container, name, copyID string) error {
for {
props, err := b.GetBlobProperties(container, name)
if err != nil {
@ -1036,10 +1371,12 @@ func (b BlobStorageClient) DeleteBlob(container, name string, extraHeaders map[s
// See https://msdn.microsoft.com/en-us/library/azure/dd179413.aspx
func (b BlobStorageClient) DeleteBlobIfExists(container, name string, extraHeaders map[string]string) (bool, error) {
resp, err := b.deleteBlob(container, name, extraHeaders)
if resp != nil && (resp.statusCode == http.StatusAccepted || resp.statusCode == http.StatusNotFound) {
return resp.statusCode == http.StatusAccepted, nil
if resp != nil {
defer resp.body.Close()
if resp.statusCode == http.StatusAccepted || resp.statusCode == http.StatusNotFound {
return resp.statusCode == http.StatusAccepted, nil
}
}
defer resp.body.Close()
return false, err
}
@ -1065,17 +1402,18 @@ func pathForBlob(container, name string) string {
return fmt.Sprintf("/%s/%s", container, name)
}
// GetBlobSASURI creates an URL to the specified blob which contains the Shared
// Access Signature with specified permissions and expiration time.
// GetBlobSASURIWithSignedIPAndProtocol creates an URL to the specified blob which contains the Shared
// Access Signature with specified permissions and expiration time. Also includes signedIPRange and allowed procotols.
// If old API version is used but no signedIP is passed (ie empty string) then this should still work.
// We only populate the signedIP when it non-empty.
//
// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx
func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Time, permissions string) (string, error) {
func (b BlobStorageClient) GetBlobSASURIWithSignedIPAndProtocol(container, name string, expiry time.Time, permissions string, signedIPRange string, HTTPSOnly bool) (string, error) {
var (
signedPermissions = permissions
blobURL = b.GetBlobURL(container, name)
)
canonicalizedResource, err := b.client.buildCanonicalizedResource(blobURL)
if err != nil {
return "", err
}
@ -1087,7 +1425,6 @@ func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Tim
// We need to replace + with %2b first to avoid being treated as a space (which is correct for query strings, but not the path component).
canonicalizedResource = strings.Replace(canonicalizedResource, "+", "%2b", -1)
canonicalizedResource, err = url.QueryUnescape(canonicalizedResource)
if err != nil {
return "", err
@ -1096,7 +1433,11 @@ func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Tim
signedExpiry := expiry.UTC().Format(time.RFC3339)
signedResource := "b"
stringToSign, err := blobSASStringToSign(b.client.apiVersion, canonicalizedResource, signedExpiry, signedPermissions)
protocols := "https,http"
if HTTPSOnly {
protocols = "https"
}
stringToSign, err := blobSASStringToSign(b.client.apiVersion, canonicalizedResource, signedExpiry, signedPermissions, signedIPRange, protocols)
if err != nil {
return "", err
}
@ -1110,6 +1451,13 @@ func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Tim
"sig": {sig},
}
if b.client.apiVersion >= "2015-04-05" {
sasParams.Add("spr", protocols)
if signedIPRange != "" {
sasParams.Add("sip", signedIPRange)
}
}
sasURL, err := url.Parse(blobURL)
if err != nil {
return "", err
@ -1118,16 +1466,89 @@ func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Tim
return sasURL.String(), nil
}
func blobSASStringToSign(signedVersion, canonicalizedResource, signedExpiry, signedPermissions string) (string, error) {
// GetBlobSASURI creates an URL to the specified blob which contains the Shared
// Access Signature with specified permissions and expiration time.
//
// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx
func (b BlobStorageClient) GetBlobSASURI(container, name string, expiry time.Time, permissions string) (string, error) {
url, err := b.GetBlobSASURIWithSignedIPAndProtocol(container, name, expiry, permissions, "", false)
return url, err
}
func blobSASStringToSign(signedVersion, canonicalizedResource, signedExpiry, signedPermissions string, signedIP string, protocols string) (string, error) {
var signedStart, signedIdentifier, rscc, rscd, rsce, rscl, rsct string
if signedVersion >= "2015-02-21" {
canonicalizedResource = "/blob" + canonicalizedResource
}
// https://msdn.microsoft.com/en-us/library/azure/dn140255.aspx#Anchor_12
if signedVersion >= "2015-04-05" {
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s", signedPermissions, signedStart, signedExpiry, canonicalizedResource, signedIdentifier, signedIP, protocols, signedVersion, rscc, rscd, rsce, rscl, rsct), nil
}
// reference: http://msdn.microsoft.com/en-us/library/azure/dn140255.aspx
if signedVersion >= "2013-08-15" {
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s\n%s", signedPermissions, signedStart, signedExpiry, canonicalizedResource, signedIdentifier, signedVersion, rscc, rscd, rsce, rscl, rsct), nil
}
return "", errors.New("storage: not implemented SAS for versions earlier than 2013-08-15")
}
func generatePermissions(accessPolicy AccessPolicyDetails) (permissions string) {
// generate the permissions string (rwd).
// still want the end user API to have bool flags.
permissions = ""
if accessPolicy.CanRead {
permissions += "r"
}
if accessPolicy.CanWrite {
permissions += "w"
}
if accessPolicy.CanDelete {
permissions += "d"
}
return permissions
}
// convertAccessPolicyToXMLStructs converts between AccessPolicyDetails which is a struct better for API usage to the
// AccessPolicy struct which will get converted to XML.
func convertAccessPolicyToXMLStructs(accessPolicy AccessPolicyDetails) SignedIdentifiers {
return SignedIdentifiers{
SignedIdentifiers: []SignedIdentifier{
{
ID: accessPolicy.ID,
AccessPolicy: AccessPolicyDetailsXML{
StartTime: accessPolicy.StartTime.UTC().Round(time.Second),
ExpiryTime: accessPolicy.ExpiryTime.UTC().Round(time.Second),
Permission: generatePermissions(accessPolicy),
},
},
},
}
}
// generateAccessPolicy generates the XML access policy used as the payload for SetContainerPermissions.
func generateAccessPolicy(accessPolicy AccessPolicyDetails) (accessPolicyXML string, err error) {
if accessPolicy.ID != "" {
signedIdentifiers := convertAccessPolicyToXMLStructs(accessPolicy)
body, _, err := xmlMarshal(signedIdentifiers)
if err != nil {
return "", err
}
xmlByteArray, err := ioutil.ReadAll(body)
if err != nil {
return "", err
}
accessPolicyXML = string(xmlByteArray)
return accessPolicyXML, nil
}
return "", nil
}

View file

@ -305,7 +305,7 @@ func (c Client) buildCanonicalizedResourceTable(uri string) (string, error) {
cr := "/" + c.getCanonicalizedAccountName()
if len(u.Path) > 0 {
cr += u.Path
cr += u.EscapedPath()
}
return cr, nil

View file

@ -82,6 +82,24 @@ func (p PeekMessagesParameters) getParameters() url.Values {
return out
}
// UpdateMessageParameters is the set of options can be specified for Update Messsage
// operation. A zero struct does not use any preferences for the request.
type UpdateMessageParameters struct {
PopReceipt string
VisibilityTimeout int
}
func (p UpdateMessageParameters) getParameters() url.Values {
out := url.Values{}
if p.PopReceipt != "" {
out.Set("popreceipt", p.PopReceipt)
}
if p.VisibilityTimeout != 0 {
out.Set("visibilitytimeout", strconv.Itoa(p.VisibilityTimeout))
}
return out
}
// GetMessagesResponse represents a response returned from Get Messages
// operation.
type GetMessagesResponse struct {
@ -304,3 +322,23 @@ func (c QueueServiceClient) DeleteMessage(queue, messageID, popReceipt string) e
defer resp.body.Close()
return checkRespCode(resp.statusCode, []int{http.StatusNoContent})
}
// UpdateMessage operation deletes the specified message.
//
// See https://msdn.microsoft.com/en-us/library/azure/hh452234.aspx
func (c QueueServiceClient) UpdateMessage(queue string, messageID string, message string, params UpdateMessageParameters) error {
uri := c.client.getEndpoint(queueServiceName, pathForMessage(queue, messageID), params.getParameters())
req := putMessageRequest{MessageText: message}
body, nn, err := xmlMarshal(req)
if err != nil {
return err
}
headers := c.client.getStandardHeaders()
headers["Content-Length"] = fmt.Sprintf("%d", nn)
resp, err := c.client.exec("PUT", uri, headers, body)
if err != nil {
return err
}
defer resp.body.Close()
return checkRespCode(resp.statusCode, []int{http.StatusNoContent})
}

View file

@ -31,8 +31,7 @@ import (
"strings"
)
/*---------------------------------------------------------------------------------------------------
*/
//--------------------------------------------------------------------------------------------------
var (
// ErrOutOfBounds - Index out of bounds.
@ -63,40 +62,30 @@ var (
ErrInvalidBuffer = errors.New("input buffer contained invalid JSON")
)
/*---------------------------------------------------------------------------------------------------
*/
//--------------------------------------------------------------------------------------------------
/*
Container - an internal structure that holds a reference to the core interface map of the parsed
json. Use this container to move context.
*/
// Container - an internal structure that holds a reference to the core interface map of the parsed
// json. Use this container to move context.
type Container struct {
object interface{}
}
/*
Data - Return the contained data as an interface{}.
*/
// Data - Return the contained data as an interface{}.
func (g *Container) Data() interface{} {
return g.object
}
/*---------------------------------------------------------------------------------------------------
*/
//--------------------------------------------------------------------------------------------------
/*
Path - Search for a value using dot notation.
*/
// Path - Search for a value using dot notation.
func (g *Container) Path(path string) *Container {
return g.Search(strings.Split(path, ".")...)
}
/*
Search - Attempt to find and return an object within the JSON structure by specifying the hierarchy
of field names to locate the target. If the search encounters an array and has not reached the end
target then it will iterate each object of the array for the target and return all of the results in
a JSON array.
*/
// Search - Attempt to find and return an object within the JSON structure by specifying the
// hierarchy of field names to locate the target. If the search encounters an array and has not
// reached the end target then it will iterate each object of the array for the target and return
// all of the results in a JSON array.
func (g *Container) Search(hierarchy ...string) *Container {
var object interface{}
@ -124,31 +113,22 @@ func (g *Container) Search(hierarchy ...string) *Container {
return &Container{object}
}
/*
S - Shorthand method, does the same thing as Search.
*/
// S - Shorthand method, does the same thing as Search.
func (g *Container) S(hierarchy ...string) *Container {
return g.Search(hierarchy...)
}
/*
Exists - Checks whether a path exists.
*/
// Exists - Checks whether a path exists.
func (g *Container) Exists(hierarchy ...string) bool {
return g.Search(hierarchy...).Data() != nil
}
/*
ExistsP - Checks whether a dot notation path exists.
*/
// ExistsP - Checks whether a dot notation path exists.
func (g *Container) ExistsP(path string) bool {
return g.Exists(strings.Split(path, ".")...)
}
/*
Index - Attempt to find and return an object with a JSON array by specifying the index of the
target.
*/
// Index - Attempt to find and return an object within a JSON array by index.
func (g *Container) Index(index int) *Container {
if array, ok := g.Data().([]interface{}); ok {
if index >= len(array) {
@ -159,11 +139,9 @@ func (g *Container) Index(index int) *Container {
return &Container{nil}
}
/*
Children - Return a slice of all the children of the array. This also works for objects, however,
the children returned for an object will NOT be in order and you lose the names of the returned
objects this way.
*/
// Children - Return a slice of all the children of the array. This also works for objects, however,
// the children returned for an object will NOT be in order and you lose the names of the returned
// objects this way.
func (g *Container) Children() ([]*Container, error) {
if array, ok := g.Data().([]interface{}); ok {
children := make([]*Container, len(array))
@ -182,9 +160,7 @@ func (g *Container) Children() ([]*Container, error) {
return nil, ErrNotObjOrArray
}
/*
ChildrenMap - Return a map of all the children of an object.
*/
// ChildrenMap - Return a map of all the children of an object.
func (g *Container) ChildrenMap() (map[string]*Container, error) {
if mmap, ok := g.Data().(map[string]interface{}); ok {
children := map[string]*Container{}
@ -196,14 +172,11 @@ func (g *Container) ChildrenMap() (map[string]*Container, error) {
return nil, ErrNotObj
}
/*---------------------------------------------------------------------------------------------------
*/
//--------------------------------------------------------------------------------------------------
/*
Set - Set the value of a field at a JSON path, any parts of the path that do not exist will be
constructed, and if a collision occurs with a non object type whilst iterating the path an error is
returned.
*/
// Set - Set the value of a field at a JSON path, any parts of the path that do not exist will be
// constructed, and if a collision occurs with a non object type whilst iterating the path an error
// is returned.
func (g *Container) Set(value interface{}, path ...string) (*Container, error) {
if len(path) == 0 {
g.object = value
@ -229,16 +202,12 @@ func (g *Container) Set(value interface{}, path ...string) (*Container, error) {
return &Container{object}, nil
}
/*
SetP - Does the same as Set, but using a dot notation JSON path.
*/
// SetP - Does the same as Set, but using a dot notation JSON path.
func (g *Container) SetP(value interface{}, path string) (*Container, error) {
return g.Set(value, strings.Split(path, ".")...)
}
/*
SetIndex - Set a value of an array element based on the index.
*/
// SetIndex - Set a value of an array element based on the index.
func (g *Container) SetIndex(value interface{}, index int) (*Container, error) {
if array, ok := g.Data().([]interface{}); ok {
if index >= len(array) {
@ -250,80 +219,60 @@ func (g *Container) SetIndex(value interface{}, index int) (*Container, error) {
return &Container{nil}, ErrNotArray
}
/*
Object - Create a new JSON object at a path. Returns an error if the path contains a collision with
a non object type.
*/
// Object - Create a new JSON object at a path. Returns an error if the path contains a collision
// with a non object type.
func (g *Container) Object(path ...string) (*Container, error) {
return g.Set(map[string]interface{}{}, path...)
}
/*
ObjectP - Does the same as Object, but using a dot notation JSON path.
*/
// ObjectP - Does the same as Object, but using a dot notation JSON path.
func (g *Container) ObjectP(path string) (*Container, error) {
return g.Object(strings.Split(path, ".")...)
}
/*
ObjectI - Create a new JSON object at an array index. Returns an error if the object is not an array
or the index is out of bounds.
*/
// ObjectI - Create a new JSON object at an array index. Returns an error if the object is not an
// array or the index is out of bounds.
func (g *Container) ObjectI(index int) (*Container, error) {
return g.SetIndex(map[string]interface{}{}, index)
}
/*
Array - Create a new JSON array at a path. Returns an error if the path contains a collision with
a non object type.
*/
// Array - Create a new JSON array at a path. Returns an error if the path contains a collision with
// a non object type.
func (g *Container) Array(path ...string) (*Container, error) {
return g.Set([]interface{}{}, path...)
}
/*
ArrayP - Does the same as Array, but using a dot notation JSON path.
*/
// ArrayP - Does the same as Array, but using a dot notation JSON path.
func (g *Container) ArrayP(path string) (*Container, error) {
return g.Array(strings.Split(path, ".")...)
}
/*
ArrayI - Create a new JSON array at an array index. Returns an error if the object is not an array
or the index is out of bounds.
*/
// ArrayI - Create a new JSON array at an array index. Returns an error if the object is not an
// array or the index is out of bounds.
func (g *Container) ArrayI(index int) (*Container, error) {
return g.SetIndex([]interface{}{}, index)
}
/*
ArrayOfSize - Create a new JSON array of a particular size at a path. Returns an error if the path
contains a collision with a non object type.
*/
// ArrayOfSize - Create a new JSON array of a particular size at a path. Returns an error if the
// path contains a collision with a non object type.
func (g *Container) ArrayOfSize(size int, path ...string) (*Container, error) {
a := make([]interface{}, size)
return g.Set(a, path...)
}
/*
ArrayOfSizeP - Does the same as ArrayOfSize, but using a dot notation JSON path.
*/
// ArrayOfSizeP - Does the same as ArrayOfSize, but using a dot notation JSON path.
func (g *Container) ArrayOfSizeP(size int, path string) (*Container, error) {
return g.ArrayOfSize(size, strings.Split(path, ".")...)
}
/*
ArrayOfSizeI - Create a new JSON array of a particular size at an array index. Returns an error if
the object is not an array or the index is out of bounds.
*/
// ArrayOfSizeI - Create a new JSON array of a particular size at an array index. Returns an error
// if the object is not an array or the index is out of bounds.
func (g *Container) ArrayOfSizeI(size, index int) (*Container, error) {
a := make([]interface{}, size)
return g.SetIndex(a, index)
}
/*
Delete - Delete an element at a JSON path, an error is returned if the element does not exist.
*/
// Delete - Delete an element at a JSON path, an error is returned if the element does not exist.
func (g *Container) Delete(path ...string) error {
var object interface{}
@ -346,24 +295,19 @@ func (g *Container) Delete(path ...string) error {
return nil
}
/*
DeleteP - Does the same as Delete, but using a dot notation JSON path.
*/
// DeleteP - Does the same as Delete, but using a dot notation JSON path.
func (g *Container) DeleteP(path string) error {
return g.Delete(strings.Split(path, ".")...)
}
/*---------------------------------------------------------------------------------------------------
*/
//--------------------------------------------------------------------------------------------------
/*
Array modification/search - Keeping these options simple right now, no need for anything more
complicated since you can just cast to []interface{}, modify and then reassign with Set.
*/
/*
ArrayAppend - Append a value onto a JSON array.
*/
// ArrayAppend - Append a value onto a JSON array.
func (g *Container) ArrayAppend(value interface{}, path ...string) error {
array, ok := g.Search(path...).Data().([]interface{})
if !ok {
@ -374,16 +318,12 @@ func (g *Container) ArrayAppend(value interface{}, path ...string) error {
return err
}
/*
ArrayAppendP - Append a value onto a JSON array using a dot notation JSON path.
*/
// ArrayAppendP - Append a value onto a JSON array using a dot notation JSON path.
func (g *Container) ArrayAppendP(value interface{}, path string) error {
return g.ArrayAppend(value, strings.Split(path, ".")...)
}
/*
ArrayRemove - Remove an element from a JSON array.
*/
// ArrayRemove - Remove an element from a JSON array.
func (g *Container) ArrayRemove(index int, path ...string) error {
if index < 0 {
return ErrOutOfBounds
@ -401,16 +341,12 @@ func (g *Container) ArrayRemove(index int, path ...string) error {
return err
}
/*
ArrayRemoveP - Remove an element from a JSON array using a dot notation JSON path.
*/
// ArrayRemoveP - Remove an element from a JSON array using a dot notation JSON path.
func (g *Container) ArrayRemoveP(index int, path string) error {
return g.ArrayRemove(index, strings.Split(path, ".")...)
}
/*
ArrayElement - Access an element from a JSON array.
*/
// ArrayElement - Access an element from a JSON array.
func (g *Container) ArrayElement(index int, path ...string) (*Container, error) {
if index < 0 {
return &Container{nil}, ErrOutOfBounds
@ -425,16 +361,12 @@ func (g *Container) ArrayElement(index int, path ...string) (*Container, error)
return &Container{nil}, ErrOutOfBounds
}
/*
ArrayElementP - Access an element from a JSON array using a dot notation JSON path.
*/
// ArrayElementP - Access an element from a JSON array using a dot notation JSON path.
func (g *Container) ArrayElementP(index int, path string) (*Container, error) {
return g.ArrayElement(index, strings.Split(path, ".")...)
}
/*
ArrayCount - Count the number of elements in a JSON array.
*/
// ArrayCount - Count the number of elements in a JSON array.
func (g *Container) ArrayCount(path ...string) (int, error) {
if array, ok := g.Search(path...).Data().([]interface{}); ok {
return len(array), nil
@ -442,19 +374,14 @@ func (g *Container) ArrayCount(path ...string) (int, error) {
return 0, ErrNotArray
}
/*
ArrayCountP - Count the number of elements in a JSON array using a dot notation JSON path.
*/
// ArrayCountP - Count the number of elements in a JSON array using a dot notation JSON path.
func (g *Container) ArrayCountP(path string) (int, error) {
return g.ArrayCount(strings.Split(path, ".")...)
}
/*---------------------------------------------------------------------------------------------------
*/
//--------------------------------------------------------------------------------------------------
/*
Bytes - Converts the contained object back to a JSON []byte blob.
*/
// Bytes - Converts the contained object back to a JSON []byte blob.
func (g *Container) Bytes() []byte {
if g.object != nil {
if bytes, err := json.Marshal(g.object); err == nil {
@ -464,9 +391,7 @@ func (g *Container) Bytes() []byte {
return []byte("{}")
}
/*
BytesIndent - Converts the contained object back to a JSON []byte blob formatted with prefix and indent.
*/
// BytesIndent - Converts the contained object to a JSON []byte blob formatted with prefix, indent.
func (g *Container) BytesIndent(prefix string, indent string) []byte {
if g.object != nil {
if bytes, err := json.MarshalIndent(g.object, prefix, indent); err == nil {
@ -476,37 +401,27 @@ func (g *Container) BytesIndent(prefix string, indent string) []byte {
return []byte("{}")
}
/*
String - Converts the contained object back to a JSON formatted string.
*/
// String - Converts the contained object to a JSON formatted string.
func (g *Container) String() string {
return string(g.Bytes())
}
/*
StringIndent - Converts the contained object back to a JSON formatted string with prefix and indent.
*/
// StringIndent - Converts the contained object back to a JSON formatted string with prefix, indent.
func (g *Container) StringIndent(prefix string, indent string) string {
return string(g.BytesIndent(prefix, indent))
}
/*
New - Create a new gabs JSON object.
*/
// New - Create a new gabs JSON object.
func New() *Container {
return &Container{map[string]interface{}{}}
}
/*
Consume - Gobble up an already converted JSON object, or a fresh map[string]interface{} object.
*/
// Consume - Gobble up an already converted JSON object, or a fresh map[string]interface{} object.
func Consume(root interface{}) (*Container, error) {
return &Container{root}, nil
}
/*
ParseJSON - Convert a string into a representation of the parsed JSON.
*/
// ParseJSON - Convert a string into a representation of the parsed JSON.
func ParseJSON(sample []byte) (*Container, error) {
var gabs Container
@ -517,9 +432,7 @@ func ParseJSON(sample []byte) (*Container, error) {
return &gabs, nil
}
/*
ParseJSONDecoder - Convert a json.Decoder into a representation of the parsed JSON.
*/
// ParseJSONDecoder - Convert a json.Decoder into a representation of the parsed JSON.
func ParseJSONDecoder(decoder *json.Decoder) (*Container, error) {
var gabs Container
@ -530,9 +443,7 @@ func ParseJSONDecoder(decoder *json.Decoder) (*Container, error) {
return &gabs, nil
}
/*
ParseJSONFile - Read a file and convert into a representation of the parsed JSON.
*/
// ParseJSONFile - Read a file and convert into a representation of the parsed JSON.
func ParseJSONFile(path string) (*Container, error) {
if len(path) > 0 {
cBytes, err := ioutil.ReadFile(path)
@ -550,9 +461,7 @@ func ParseJSONFile(path string) (*Container, error) {
return nil, ErrInvalidPath
}
/*
ParseJSONBuffer - Read the contents of a buffer into a representation of the parsed JSON.
*/
// ParseJSONBuffer - Read the contents of a buffer into a representation of the parsed JSON.
func ParseJSONBuffer(buffer io.Reader) (*Container, error) {
var gabs Container
jsonDecoder := json.NewDecoder(buffer)
@ -563,83 +472,4 @@ func ParseJSONBuffer(buffer io.Reader) (*Container, error) {
return &gabs, nil
}
/*---------------------------------------------------------------------------------------------------
*/
// DEPRECATED METHODS
/*
Push - DEPRECATED: Push a value onto a JSON array.
*/
func (g *Container) Push(target string, value interface{}) error {
if mmap, ok := g.Data().(map[string]interface{}); ok {
arrayTarget := mmap[target]
if array, ok := arrayTarget.([]interface{}); ok {
mmap[target] = append(array, value)
} else {
return ErrNotArray
}
} else {
return ErrNotObj
}
return nil
}
/*
RemoveElement - DEPRECATED: Remove a value from a JSON array.
*/
func (g *Container) RemoveElement(target string, index int) error {
if index < 0 {
return ErrOutOfBounds
}
if mmap, ok := g.Data().(map[string]interface{}); ok {
arrayTarget := mmap[target]
if array, ok := arrayTarget.([]interface{}); ok {
if index < len(array) {
mmap[target] = append(array[:index], array[index+1:]...)
} else {
return ErrOutOfBounds
}
} else {
return ErrNotArray
}
} else {
return ErrNotObj
}
return nil
}
/*
GetElement - DEPRECATED: Get the desired element from a JSON array
*/
func (g *Container) GetElement(target string, index int) *Container {
if index < 0 {
return &Container{nil}
}
if mmap, ok := g.Data().(map[string]interface{}); ok {
arrayTarget := mmap[target]
if array, ok := arrayTarget.([]interface{}); ok {
if index < len(array) {
return &Container{array[index]}
}
}
}
return &Container{nil}
}
/*
CountElements - DEPRECATED: Count the elements of a JSON array, returns -1 if the target is not an
array
*/
func (g *Container) CountElements(target string) int {
if mmap, ok := g.Data().(map[string]interface{}); ok {
arrayTarget := mmap[target]
if array, ok := arrayTarget.([]interface{}); ok {
return len(array)
}
}
return -1
}
/*---------------------------------------------------------------------------------------------------
*/
//--------------------------------------------------------------------------------------------------

View file

@ -28,39 +28,42 @@ Examples
Here is an example of using the package:
func SlowMethod() {
// Profiling the runtime of a method
defer metrics.MeasureSince([]string{"SlowMethod"}, time.Now())
}
```go
func SlowMethod() {
// Profiling the runtime of a method
defer metrics.MeasureSince([]string{"SlowMethod"}, time.Now())
}
// Configure a statsite sink as the global metrics sink
sink, _ := metrics.NewStatsiteSink("statsite:8125")
metrics.NewGlobal(metrics.DefaultConfig("service-name"), sink)
// Emit a Key/Value pair
metrics.EmitKey([]string{"questions", "meaning of life"}, 42)
// Configure a statsite sink as the global metrics sink
sink, _ := metrics.NewStatsiteSink("statsite:8125")
metrics.NewGlobal(metrics.DefaultConfig("service-name"), sink)
// Emit a Key/Value pair
metrics.EmitKey([]string{"questions", "meaning of life"}, 42)
```
Here is an example of setting up an signal handler:
// Setup the inmem sink and signal handler
inm := metrics.NewInmemSink(10*time.Second, time.Minute)
sig := metrics.DefaultInmemSignal(inm)
metrics.NewGlobal(metrics.DefaultConfig("service-name"), inm)
```go
// Setup the inmem sink and signal handler
inm := metrics.NewInmemSink(10*time.Second, time.Minute)
sig := metrics.DefaultInmemSignal(inm)
metrics.NewGlobal(metrics.DefaultConfig("service-name"), inm)
// Run some code
inm.SetGauge([]string{"foo"}, 42)
inm.EmitKey([]string{"bar"}, 30)
// Run some code
inm.SetGauge([]string{"foo"}, 42)
inm.EmitKey([]string{"bar"}, 30)
inm.IncrCounter([]string{"baz"}, 42)
inm.IncrCounter([]string{"baz"}, 1)
inm.IncrCounter([]string{"baz"}, 80)
inm.IncrCounter([]string{"baz"}, 42)
inm.IncrCounter([]string{"baz"}, 1)
inm.IncrCounter([]string{"baz"}, 80)
inm.AddSample([]string{"method", "wow"}, 42)
inm.AddSample([]string{"method", "wow"}, 100)
inm.AddSample([]string{"method", "wow"}, 22)
inm.AddSample([]string{"method", "wow"}, 42)
inm.AddSample([]string{"method", "wow"}, 100)
inm.AddSample([]string{"method", "wow"}, 22)
....
....
```
When a signal comes in, output like the following will be dumped to stderr:

View file

@ -30,7 +30,15 @@ const (
Latitude string = "^[-+]?([1-8]?\\d(\\.\\d+)?|90(\\.0+)?)$"
Longitude string = "^[-+]?(180(\\.0+)?|((1[0-7]\\d)|([1-9]?\\d))(\\.\\d+)?)$"
DNSName string = `^([a-zA-Z0-9]{1}[a-zA-Z0-9_-]{1,62}){1}(\.[a-zA-Z0-9]{1}[a-zA-Z0-9_-]{1,62})*$`
URL string = `^((ftp|https?):\/\/)?(\S+(:\S*)?@)?((([1-9]\d?|1\d\d|2[01]\d|22[0-3])(\.(1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.([0-9]\d?|1\d\d|2[0-4]\d|25[0-4]))|(([a-zA-Z0-9]([a-zA-Z0-9-]+)?[a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*)|((www\.)?))?(([a-zA-Z\x{00a1}-\x{ffff}0-9]+-?-?)*[a-zA-Z\x{00a1}-\x{ffff}0-9]+)(?:\.([a-zA-Z\x{00a1}-\x{ffff}]{1,}))?))(:(\d{1,5}))?((\/|\?|#)[^\s]*)?$`
IP string = `(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))`
URLSchema string = `((ftp|tcp|udp|wss?|https?):\/\/)`
URLUsername string = `(\S+(:\S*)?@)`
Hostname string = ``
URLPath string = `((\/|\?|#)[^\s]*)`
URLPort string = `(:(\d{1,5}))`
URLIP string = `([1-9]\d?|1\d\d|2[01]\d|22[0-3])(\.(1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.([0-9]\d?|1\d\d|2[0-4]\d|25[0-4]))`
URLSubdomain string = `((www\.)|([a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*))`
URL string = `^` + URLSchema + `?` + URLUsername + `?` + `((` + URLIP + `|(\[` + IP + `\])|(([a-zA-Z0-9]([a-zA-Z0-9-]+)?[a-zA-Z0-9]([-\.][a-zA-Z0-9]+)*)|(` + URLSubdomain + `?))?(([a-zA-Z\x{00a1}-\x{ffff}0-9]+-?-?)*[a-zA-Z\x{00a1}-\x{ffff}0-9]+)(?:\.([a-zA-Z\x{00a1}-\x{ffff}]{1,}))?))` + URLPort + `?` + URLPath + `?$`
SSN string = `^\d{3}[- ]?\d{2}[- ]?\d{4}$`
WinPath string = `^[a-zA-Z]:\\(?:[^\\/:*?"<>|\r\n]+\\)*[^\\/:*?"<>|\r\n]*$`
UnixPath string = `^((?:\/[a-zA-Z0-9\.\:]+(?:_[a-zA-Z0-9\:\.]+)*(?:\-[\:a-zA-Z0-9\.]+)*)+\/?)$`

View file

@ -496,6 +496,12 @@ func IsIPv6(str string) bool {
return ip != nil && strings.Contains(str, ":")
}
// IsCIDR check if the string is an valid CIDR notiation (IPV4 & IPV6)
func IsCIDR(str string) bool {
_, _, err := net.ParseCIDR(str)
return err == nil
}
// IsMAC check if a string is valid MAC address.
// Possible MAC formats:
// 01:23:45:67:89:ab

View file

@ -2,7 +2,6 @@ package client
import (
"fmt"
"io/ioutil"
"net/http/httputil"
"github.com/aws/aws-sdk-go/aws"
@ -104,8 +103,7 @@ func logRequest(r *request.Request) {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.Body.Seek(r.BodyStart, 0)
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
r.ResetBody()
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))

View file

@ -137,9 +137,6 @@ type Config struct {
// accelerate enabled. If the bucket is not enabled for accelerate an error
// will be returned. The bucket name must be DNS compatible to also work
// with accelerate.
//
// Not compatible with UseDualStack requests will fail if both flags are
// specified.
S3UseAccelerate *bool
// Set this to `true` to disable the EC2Metadata client from overriding the
@ -185,6 +182,19 @@ type Config struct {
// the delay of a request see the aws/client.DefaultRetryer and
// aws/request.Retryer.
SleepDelay func(time.Duration)
// DisableRestProtocolURICleaning will not clean the URL path when making rest protocol requests.
// Will default to false. This would only be used for empty directory names in s3 requests.
//
// Example:
// sess, err := session.NewSession(&aws.Config{DisableRestProtocolURICleaning: aws.Bool(true))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("//foo//bar//moo"),
// })
DisableRestProtocolURICleaning *bool
}
// NewConfig returns a new Config pointer that can be chained with builder
@ -406,6 +416,10 @@ func mergeInConfig(dst *Config, other *Config) {
if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}
if other.DisableRestProtocolURICleaning != nil {
dst.DisableRestProtocolURICleaning = other.DisableRestProtocolURICleaning
}
}
// Copy will return a shallow copy of the Config object. If any additional

View file

@ -10,9 +10,11 @@ import (
"regexp"
"runtime"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
@ -67,6 +69,34 @@ var SDKVersionUserAgentHandler = request.NamedHandler{
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// ValidateReqSigHandler is a request handler to ensure that the request's
// signature doesn't expire before it is sent. This can happen when a request
// is built and signed signficantly before it is sent. Or signficant delays
// occur whne retrying requests that would cause the signature to expire.
var ValidateReqSigHandler = request.NamedHandler{
Name: "core.ValidateReqSigHandler",
Fn: func(r *request.Request) {
// Unsigned requests are not signed
if r.Config.Credentials == credentials.AnonymousCredentials {
return
}
signedTime := r.Time
if !r.LastSignedAt.IsZero() {
signedTime = r.LastSignedAt
}
// 10 minutes to allow for some clock skew/delays in transmission.
// Would be improved with aws/aws-sdk-go#423
if signedTime.Add(10 * time.Minute).After(time.Now()) {
return
}
fmt.Println("request expired, resigning")
r.Sign()
},
}
// SendHandler is a request handler to send service request using HTTP client.
var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *request.Request) {
var err error

View file

@ -34,7 +34,7 @@ var (
//
// Example of ChainProvider to be used with an EnvProvider and EC2RoleProvider.
// In this example EnvProvider will first check if any credentials are available
// vai the environment variables. If there are none ChainProvider will check
// via the environment variables. If there are none ChainProvider will check
// the next Provider in the list, EC2RoleProvider in this case. If EC2RoleProvider
// does not return any credentials ChainProvider will return the error
// ErrNoValidProvidersFoundInChain

View file

@ -111,7 +111,7 @@ func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
}, nil
}
// A ec2RoleCredRespBody provides the shape for unmarshalling credential
// A ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses.
type ec2RoleCredRespBody struct {
// Success State

View file

@ -72,6 +72,7 @@ func Handlers() request.Handlers {
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
handlers.Build.AfterEachFn = request.HandlerListStopOnError
handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
handlers.Send.PushBackNamed(corehandlers.ValidateReqSigHandler)
handlers.Send.PushBackNamed(corehandlers.SendHandler)
handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler)

View file

@ -3,6 +3,7 @@ package ec2metadata
import (
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"time"
@ -27,6 +28,27 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
return output.Content, req.Send()
}
// GetUserData returns the userdata that was configured for the service. If
// there is no user-data setup for the EC2 instance a "NotFoundError" error
// code will be returned.
func (c *EC2Metadata) GetUserData() (string, error) {
op := &request.Operation{
Name: "GetUserData",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "user-data"),
}
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
req.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
if r.HTTPResponse.StatusCode == http.StatusNotFound {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})
return output.Content, req.Send()
}
// GetDynamicData uses the path provided to request information from the EC2
// instance metadata service for dynamic data. The content will be returned
// as a string, or error if the request failed.
@ -111,7 +133,7 @@ func (c *EC2Metadata) Available() bool {
return true
}
// An EC2IAMInfo provides the shape for unmarshalling
// An EC2IAMInfo provides the shape for unmarshaling
// an IAM info from the metadata API
type EC2IAMInfo struct {
Code string
@ -120,7 +142,7 @@ type EC2IAMInfo struct {
InstanceProfileID string
}
// An EC2InstanceIdentityDocument provides the shape for unmarshalling
// An EC2InstanceIdentityDocument provides the shape for unmarshaling
// an instance identity document
type EC2InstanceIdentityDocument struct {
DevpayProductCodes []string `json:"devpayProductCodes"`

View file

@ -9,7 +9,7 @@ import (
// with retrying requests
type offsetReader struct {
buf io.ReadSeeker
lock sync.RWMutex
lock sync.Mutex
closed bool
}
@ -21,7 +21,8 @@ func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader {
return reader
}
// Close is a thread-safe close. Uses the write lock.
// Close will close the instance of the offset reader's access to
// the underlying io.ReadSeeker.
func (o *offsetReader) Close() error {
o.lock.Lock()
defer o.lock.Unlock()
@ -29,10 +30,10 @@ func (o *offsetReader) Close() error {
return nil
}
// Read is a thread-safe read using a read lock.
// Read is a thread-safe read of the underlying io.ReadSeeker
func (o *offsetReader) Read(p []byte) (int, error) {
o.lock.RLock()
defer o.lock.RUnlock()
o.lock.Lock()
defer o.lock.Unlock()
if o.closed {
return 0, io.EOF
@ -41,6 +42,14 @@ func (o *offsetReader) Read(p []byte) (int, error) {
return o.buf.Read(p)
}
// Seek is a thread-safe seeking operation.
func (o *offsetReader) Seek(offset int64, whence int) (int64, error) {
o.lock.Lock()
defer o.lock.Unlock()
return o.buf.Seek(offset, whence)
}
// CloseAndCopy will return a new offsetReader with a copy of the old buffer
// and close the old buffer.
func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader {

View file

@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"reflect"
@ -42,6 +41,12 @@ type Request struct {
LastSignedAt time.Time
built bool
// Need to persist an intermideant body betweend the input Body and HTTP
// request body because the HTTP Client's transport can maintain a reference
// to the HTTP request's body after the client has returned. This value is
// safe to use concurrently and rewraps the input Body for each HTTP request.
safeBody *offsetReader
}
// An Operation is the service API operation to be made.
@ -135,8 +140,8 @@ func (r *Request) SetStringBody(s string) {
// SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.HTTPRequest.Body = newOffsetReader(reader, 0)
r.Body = reader
r.ResetBody()
}
// Presign returns the request's signed URL. Error will be returned
@ -220,6 +225,24 @@ func (r *Request) Sign() error {
return r.Error
}
// ResetBody rewinds the request body backto its starting position, and
// set's the HTTP Request body reference. When the body is read prior
// to being sent in the HTTP request it will need to be rewound.
func (r *Request) ResetBody() {
if r.safeBody != nil {
r.safeBody.Close()
}
r.safeBody = newOffsetReader(r.Body, r.BodyStart)
r.HTTPRequest.Body = r.safeBody
}
// GetBody will return an io.ReadSeeker of the Request's underlying
// input body with a concurrency safe wrapper.
func (r *Request) GetBody() io.ReadSeeker {
return r.safeBody
}
// Send will send the request returning error if errors are encountered.
//
// Send will sign the request prior to sending. All Send Handlers will
@ -231,6 +254,8 @@ func (r *Request) Sign() error {
//
// readLoop() and getConn(req *Request, cm connectMethod)
// https://github.com/golang/go/blob/master/src/net/http/transport.go
//
// Send will not close the request.Request's body.
func (r *Request) Send() error {
for {
if aws.BoolValue(r.Retryable) {
@ -239,21 +264,15 @@ func (r *Request) Send() error {
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
}
var body io.ReadCloser
if reader, ok := r.HTTPRequest.Body.(*offsetReader); ok {
body = reader.CloseAndCopy(r.BodyStart)
} else {
if r.Config.Logger != nil {
r.Config.Logger.Log("Request body type has been overwritten. May cause race conditions")
}
r.Body.Seek(r.BodyStart, 0)
body = ioutil.NopCloser(r.Body)
}
// The previous http.Request will have a reference to the r.Body
// and the HTTP Client's Transport may still be reading from
// the request's body even though the Client's Do returned.
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
r.ResetBody()
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, body)
// Closing response body to ensure that no response body is leaked
// between retry attempts.
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
// Closing response body. Since we are setting a new request to send off, this
// response will get squashed and leaked.
r.HTTPResponse.Body.Close()
}
}
@ -281,7 +300,6 @@ func (r *Request) Send() error {
debugLogReqError(r, "Send Request", true, err)
continue
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {

View file

@ -66,7 +66,7 @@ through code instead of being driven by environment variables only.
Use NewSessionWithOptions when you want to provide the config profile, or
override the shared config state (AWS_SDK_LOAD_CONFIG).
// Equivalent to session.New
// Equivalent to session.NewSession()
sess, err := session.NewSessionWithOptions(session.Options{})
// Specify profile to load for the session's config

View file

@ -2,7 +2,7 @@ package session
import (
"fmt"
"os"
"io/ioutil"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
@ -105,12 +105,13 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
files := make([]sharedConfigFile, 0, len(filenames))
for _, filename := range filenames {
if _, err := os.Stat(filename); os.IsNotExist(err) {
// Trim files from the list that don't exist.
b, err := ioutil.ReadFile(filename)
if err != nil {
// Skip files which can't be opened and read for whatever reason
continue
}
f, err := ini.Load(filename)
f, err := ini.Load(b)
if err != nil {
return nil, SharedConfigLoadError{Filename: filename}
}

View file

@ -0,0 +1,24 @@
// +build go1.5
package v4
import (
"net/url"
"strings"
)
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.EscapedPath()
}
if len(uri) == 0 {
uri = "/"
}
return uri
}

View file

@ -0,0 +1,24 @@
// +build !go1.5
package v4
import (
"net/url"
"strings"
)
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.Path
}
if len(uri) == 0 {
uri = "/"
}
return uri
}

View file

@ -2,6 +2,48 @@
//
// Provides request signing for request that need to be signed with
// AWS V4 Signatures.
//
// Standalone Signer
//
// Generally using the signer outside of the SDK should not require any additional
// logic when using Go v1.5 or higher. The signer does this by taking advantage
// of the URL.EscapedPath method. If your request URI requires additional escaping
// you many need to use the URL.Opaque to define what the raw URI should be sent
// to the service as.
//
// The signer will first check the URL.Opaque field, and use its value if set.
// The signer does require the URL.Opaque field to be set in the form of:
//
// "//<hostname>/<path>"
//
// // e.g.
// "//example.com/some/path"
//
// The leading "//" and hostname are required or the URL.Opaque escaping will
// not work correctly.
//
// If URL.Opaque is not set the signer will fallback to the URL.EscapedPath()
// method and using the returned value. If you're using Go v1.4 you must set
// URL.Opaque if the URI path needs escaping. If URL.Opaque is not set with
// Go v1.5 the signer will fallback to URL.Path.
//
// AWS v4 signature validation requires that the canonical string's URI path
// element must be the URI escaped form of the HTTP request's path.
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
//
// The Go HTTP client will perform escaping automatically on the request. Some
// of these escaping may cause signature validation errors because the HTTP
// request differs from the URI path or query that the signature was generated.
// https://golang.org/pkg/net/url/#URL.EscapedPath
//
// Because of this, it is recommended that when using the signer outside of the
// SDK that explicitly escaping the request prior to being signed is preferable,
// and will help prevent signature validation errors. This can be done by setting
// the URL.Opaque or URL.RawPath. The SDK will use URL.Opaque first and then
// call URL.EscapedPath() if Opaque is not set.
//
// Test `TestStandaloneSign` provides a complete example of using the signer
// outside of the SDK and pre-escaping the URI path.
package v4
import (
@ -120,6 +162,15 @@ type Signer struct {
// request's query string.
DisableHeaderHoisting bool
// Disables the automatic escaping of the URI path of the request for the
// siganture's canonical string's path. For services that do not need additional
// escaping then use this to disable the signer escaping the path.
//
// S3 is an example of a service that does not need additional escaping.
//
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool
// currentTimeFn returns the time value which represents the current time.
// This value should only be used for testing. If it is nil the default
// time.Now will be used.
@ -151,6 +202,8 @@ type signingCtx struct {
ExpireTime time.Duration
SignedHeaderVals http.Header
DisableURIPathEscaping bool
credValues credentials.Value
isPresign bool
formattedTime string
@ -236,22 +289,18 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
}
ctx := &signingCtx{
Request: r,
Body: body,
Query: r.URL.Query(),
Time: signTime,
ExpireTime: exp,
isPresign: exp != 0,
ServiceName: service,
Region: region,
Request: r,
Body: body,
Query: r.URL.Query(),
Time: signTime,
ExpireTime: exp,
isPresign: exp != 0,
ServiceName: service,
Region: region,
DisableURIPathEscaping: v4.DisableURIPathEscaping,
}
if ctx.isRequestSigned() {
if !v4.Credentials.IsExpired() && currentTimeFn().Before(ctx.Time.Add(10*time.Minute)) {
// If the request is already signed, and the credentials have not
// expired, and the request is not too old ignore the signing request.
return ctx.SignedHeaderVals, nil
}
ctx.Time = currentTimeFn()
ctx.handlePresignRemoval()
}
@ -359,6 +408,10 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
v4.Logger = req.Config.Logger
v4.DisableHeaderHoisting = req.NotHoist
v4.currentTimeFn = curTimeFn
if name == "s3" {
// S3 service should not have any escaping applied
v4.DisableURIPathEscaping = true
}
})
signingTime := req.Time
@ -366,7 +419,9 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
signingTime = req.LastSignedAt
}
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.Body, name, region, req.ExpireTime, signingTime)
signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
name, region, req.ExpireTime, signingTime,
)
if err != nil {
req.Error = err
req.SignedHeaderVals = nil
@ -512,18 +567,15 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
}
func (ctx *signingCtx) buildCanonicalString() {
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
uri := ctx.Request.URL.Opaque
if uri != "" {
uri = "/" + strings.Join(strings.Split(uri, "/")[3:], "/")
} else {
uri = ctx.Request.URL.Path
}
if uri == "" {
uri = "/"
query := ctx.Query
for key := range query {
sort.Strings(query[key])
}
ctx.Request.URL.RawQuery = strings.Replace(query.Encode(), "+", "%20", -1)
if ctx.ServiceName != "s3" {
uri := getURIPath(ctx.Request.URL)
if !ctx.DisableURIPathEscaping {
uri = rest.EscapePath(uri, false)
}

View file

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "1.4.14"
const SDKVersion = "1.5.6"

View file

@ -1,7 +1,7 @@
// Package endpoints validates regional endpoints for services.
package endpoints
//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate go run -tags codegen ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate gofmt -s -w endpoints_map.go
import (

View file

@ -23,6 +23,10 @@
"us-gov-west-1/ec2metadata": {
"endpoint": "http://169.254.169.254/latest"
},
"*/budgets": {
"endpoint": "budgets.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/cloudfront": {
"endpoint": "cloudfront.amazonaws.com",
"signingRegion": "us-east-1"

View file

@ -18,6 +18,10 @@ var endpointsMap = endpointStruct{
"*/*": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"*/budgets": {
Endpoint: "budgets.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudfront": {
Endpoint: "cloudfront.amazonaws.com",
SigningRegion: "us-east-1",

View file

@ -1,7 +1,7 @@
// Package ec2query provides serialization of AWS EC2 requests and responses.
package ec2query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/ec2.json build_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/ec2.json build_test.go
import (
"net/url"

View file

@ -1,6 +1,6 @@
package ec2query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/ec2.json unmarshal_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/ec2.json unmarshal_test.go
import (
"encoding/xml"

View file

@ -2,8 +2,8 @@
// requests and responses.
package jsonrpc
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/json.json build_test.go
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/json.json build_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go
import (
"encoding/json"

View file

@ -1,7 +1,7 @@
// Package query provides serialization of AWS query requests, and responses.
package query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go
import (
"net/url"

View file

@ -1,6 +1,6 @@
package query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go
import (
"encoding/xml"

View file

@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
@ -92,7 +93,7 @@ func buildLocationElements(r *request.Request, v reflect.Value) {
}
r.HTTPRequest.URL.RawQuery = query.Encode()
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path)
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path, aws.BoolValue(r.Config.DisableRestProtocolURICleaning))
}
func buildBody(r *request.Request, v reflect.Value) {
@ -193,13 +194,15 @@ func buildQueryString(query url.Values, v reflect.Value, name string) error {
return nil
}
func updatePath(url *url.URL, urlPath string) {
func updatePath(url *url.URL, urlPath string, disableRestProtocolURICleaning bool) {
scheme, query := url.Scheme, url.RawQuery
hasSlash := strings.HasSuffix(urlPath, "/")
// clean up path
urlPath = path.Clean(urlPath)
if !disableRestProtocolURICleaning {
urlPath = path.Clean(urlPath)
}
if hasSlash && !strings.HasSuffix(urlPath, "/") {
urlPath += "/"
}

View file

@ -2,8 +2,8 @@
// requests and responses.
package restxml
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/rest-xml.json build_test.go
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-xml.json unmarshal_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/rest-xml.json build_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-xml.json unmarshal_test.go
import (
"bytes"

File diff suppressed because it is too large Load diff

Some files were not shown because too many files have changed in this diff Show more