diff --git a/builtin/credential/userpass/path_user_password.go b/builtin/credential/userpass/path_user_password.go index 4398a7a23..d2d2d534d 100644 --- a/builtin/credential/userpass/path_user_password.go +++ b/builtin/credential/userpass/path_user_password.go @@ -3,6 +3,8 @@ package userpass import ( "fmt" + "golang.org/x/crypto/bcrypt" + "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -33,11 +35,34 @@ func pathUserPassword(b *backend) *framework.Path { func (b *backend) pathUserPasswordUpdate( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + + username := d.Get("username").(string) + + userEntry, err := b.user(req.Storage, username) + if err != nil { + return nil, err + } + + err = b.updateUserPassword(req, d, userEntry) + if err != nil { + return nil, err + } + + return nil, b.setUser(req.Storage, username, userEntry) +} + +func (b *backend) updateUserPassword(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error { password := d.Get("password").(string) if password == "" { - return nil, fmt.Errorf("missing password") + return fmt.Errorf("missing password") } - return b.userCreateUpdate(req, d) + // Generate a hash of the password + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + userEntry.PasswordHash = hash + return nil } const pathUserPasswordHelpSyn = ` diff --git a/builtin/credential/userpass/path_user_policies.go b/builtin/credential/userpass/path_user_policies.go index 5a84372b9..d3124ce4f 100644 --- a/builtin/credential/userpass/path_user_policies.go +++ b/builtin/credential/userpass/path_user_policies.go @@ -1,6 +1,8 @@ package userpass import ( + "strings" + "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -30,7 +32,29 @@ func pathUserPolicies(b *backend) *framework.Path { func (b *backend) pathUserPoliciesUpdate( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - return b.userCreateUpdate(req, d) + + username := d.Get("username").(string) + + userEntry, err := b.user(req.Storage, username) + if err != nil { + return nil, err + } + + err = b.updateUserPolicies(req, d, userEntry) + if err != nil { + return nil, err + } + + return nil, b.setUser(req.Storage, username, userEntry) +} + +func (b *backend) updateUserPolicies(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error { + policies := strings.Split(d.Get("policies").(string), ",") + for i, p := range policies { + policies[i] = strings.TrimSpace(p) + } + userEntry.Policies = policies + return nil } const pathUserPoliciesHelpSyn = ` diff --git a/builtin/credential/userpass/path_users.go b/builtin/credential/userpass/path_users.go index 33bf27b25..bc81018e4 100644 --- a/builtin/credential/userpass/path_users.go +++ b/builtin/credential/userpass/path_users.go @@ -7,7 +7,6 @@ import ( "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" - "golang.org/x/crypto/bcrypt" ) func pathUsers(b *backend) *framework.Path { @@ -132,33 +131,36 @@ func (b *backend) userCreateUpdate(req *logical.Request, d *framework.FieldData) userEntry = &UserEntry{} } - // Set/update the values of UserEntry only if fields are supplied - if passwordRaw, ok := d.GetOk("password"); ok { - // Generate a hash of the password - hash, err := bcrypt.GenerateFromPassword([]byte(passwordRaw.(string)), bcrypt.DefaultCost) + // "password" will always be set here + err = b.updateUserPassword(req, d, userEntry) + if err != nil { + return nil, err + } + + if _, ok := d.GetOk("policies"); ok { + err = b.updateUserPolicies(req, d, userEntry) if err != nil { return nil, err } - userEntry.PasswordHash = hash } - if policiesRaw, ok := d.GetOk("policies"); ok { - policies := strings.Split(policiesRaw.(string), ",") - for i, p := range policies { - policies[i] = strings.TrimSpace(p) - } - userEntry.Policies = policies + ttlStr := "" + if ttlStrRaw, ok := d.GetOk("ttl"); ok { + ttlStr = ttlStrRaw.(string) + } else if req.Operation == logical.CreateOperation { + ttlStr = d.Get("ttl").(string) } - _, ttlSet := d.GetOk("ttl") - _, maxTTLSet := d.GetOk("max_ttl") - if ttlSet || maxTTLSet { - ttlStr := d.Get("ttl").(string) - maxTTLStr := d.Get("max_ttl").(string) - userEntry.TTL, userEntry.MaxTTL, err = b.SanitizeTTL(ttlStr, maxTTLStr) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil - } + maxTTLStr := "" + if maxTTLStrRaw, ok := d.GetOk("max_ttl"); ok { + maxTTLStr = maxTTLStrRaw.(string) + } else if req.Operation == logical.CreateOperation { + maxTTLStr = d.Get("max_ttl").(string) + } + + userEntry.TTL, userEntry.MaxTTL, err = b.SanitizeTTL(ttlStr, maxTTLStr) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil } return nil, b.setUser(req.Storage, username, userEntry)