Refactor updating user values

This commit is contained in:
vishalnayak 2016-03-16 13:39:20 -04:00
parent 533b136fe7
commit 239ad4ad7e
3 changed files with 75 additions and 24 deletions

View File

@ -3,6 +3,8 @@ package userpass
import ( import (
"fmt" "fmt"
"golang.org/x/crypto/bcrypt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -33,11 +35,34 @@ func pathUserPassword(b *backend) *framework.Path {
func (b *backend) pathUserPasswordUpdate( func (b *backend) pathUserPasswordUpdate(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) { req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
username := d.Get("username").(string)
userEntry, err := b.user(req.Storage, username)
if err != nil {
return nil, err
}
err = b.updateUserPassword(req, d, userEntry)
if err != nil {
return nil, err
}
return nil, b.setUser(req.Storage, username, userEntry)
}
func (b *backend) updateUserPassword(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error {
password := d.Get("password").(string) password := d.Get("password").(string)
if password == "" { if password == "" {
return nil, fmt.Errorf("missing password") return fmt.Errorf("missing password")
} }
return b.userCreateUpdate(req, d) // Generate a hash of the password
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
userEntry.PasswordHash = hash
return nil
} }
const pathUserPasswordHelpSyn = ` const pathUserPasswordHelpSyn = `

View File

@ -1,6 +1,8 @@
package userpass package userpass
import ( import (
"strings"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
@ -30,7 +32,29 @@ func pathUserPolicies(b *backend) *framework.Path {
func (b *backend) pathUserPoliciesUpdate( func (b *backend) pathUserPoliciesUpdate(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) { req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
return b.userCreateUpdate(req, d)
username := d.Get("username").(string)
userEntry, err := b.user(req.Storage, username)
if err != nil {
return nil, err
}
err = b.updateUserPolicies(req, d, userEntry)
if err != nil {
return nil, err
}
return nil, b.setUser(req.Storage, username, userEntry)
}
func (b *backend) updateUserPolicies(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error {
policies := strings.Split(d.Get("policies").(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
userEntry.Policies = policies
return nil
} }
const pathUserPoliciesHelpSyn = ` const pathUserPoliciesHelpSyn = `

View File

@ -7,7 +7,6 @@ import (
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
"golang.org/x/crypto/bcrypt"
) )
func pathUsers(b *backend) *framework.Path { func pathUsers(b *backend) *framework.Path {
@ -132,33 +131,36 @@ func (b *backend) userCreateUpdate(req *logical.Request, d *framework.FieldData)
userEntry = &UserEntry{} userEntry = &UserEntry{}
} }
// Set/update the values of UserEntry only if fields are supplied // "password" will always be set here
if passwordRaw, ok := d.GetOk("password"); ok { err = b.updateUserPassword(req, d, userEntry)
// Generate a hash of the password if err != nil {
hash, err := bcrypt.GenerateFromPassword([]byte(passwordRaw.(string)), bcrypt.DefaultCost) return nil, err
}
if _, ok := d.GetOk("policies"); ok {
err = b.updateUserPolicies(req, d, userEntry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userEntry.PasswordHash = hash
} }
if policiesRaw, ok := d.GetOk("policies"); ok { ttlStr := ""
policies := strings.Split(policiesRaw.(string), ",") if ttlStrRaw, ok := d.GetOk("ttl"); ok {
for i, p := range policies { ttlStr = ttlStrRaw.(string)
policies[i] = strings.TrimSpace(p) } else if req.Operation == logical.CreateOperation {
} ttlStr = d.Get("ttl").(string)
userEntry.Policies = policies
} }
_, ttlSet := d.GetOk("ttl") maxTTLStr := ""
_, maxTTLSet := d.GetOk("max_ttl") if maxTTLStrRaw, ok := d.GetOk("max_ttl"); ok {
if ttlSet || maxTTLSet { maxTTLStr = maxTTLStrRaw.(string)
ttlStr := d.Get("ttl").(string) } else if req.Operation == logical.CreateOperation {
maxTTLStr := d.Get("max_ttl").(string) maxTTLStr = d.Get("max_ttl").(string)
userEntry.TTL, userEntry.MaxTTL, err = b.SanitizeTTL(ttlStr, maxTTLStr) }
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil userEntry.TTL, userEntry.MaxTTL, err = b.SanitizeTTL(ttlStr, maxTTLStr)
} if err != nil {
return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil
} }
return nil, b.setUser(req.Storage, username, userEntry) return nil, b.setUser(req.Storage, username, userEntry)