Expand HMAC support in Salt; require an identifier be passed in to specify type but allow generation with and without. Add a StaticSalt ID for testing functions. Fix bugs; unit tests pass.

This commit is contained in:
Jeff Mitchell 2015-09-18 17:36:42 -04:00
parent b655f6b858
commit 5dde76fa1c
11 changed files with 114 additions and 67 deletions

View File

@ -32,4 +32,4 @@ type BackendConfig struct {
}
// Factory is the factory function to create an audit backend.
type Factory func(BackendConfig) (Backend, error)
type Factory func(*BackendConfig) (Backend, error)

View File

@ -16,7 +16,7 @@ import (
//
// The structure is modified in-place.
func Hash(salter *salt.Salt, raw interface{}) error {
fn := salter.GetHMAC
fn := salter.GetIdentifiedHMAC
switch s := raw.(type) {
case *logical.Auth:
@ -86,17 +86,6 @@ func HashStructure(s interface{}, cb HashCallback) (interface{}, error) {
// a value.
type HashCallback func(string) string
// HashSHA1 returns a HashCallback that hashes data with SHA1 and
// with an optional salt. If salt is a blank string, no salt is used.
/*
func HashSHA1(salt string) HashCallback {
return func(v string) (string, error) {
hashed := sha1.Sum([]byte(v + salt))
return "sha1:" + hex.EncodeToString(hashed[:]), nil
}
}
*/
// hashWalker implements interfaces for the reflectwalk package
// (github.com/mitchellh/reflectwalk) that can be used to automatically
// replace primitives with a hashed value.

View File

@ -1,11 +1,13 @@
package audit
import (
"crypto/sha256"
"fmt"
"reflect"
"testing"
"time"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
"github.com/mitchellh/copystructure"
)
@ -88,7 +90,7 @@ func TestHash(t *testing.T) {
}{
{
&logical.Auth{ClientToken: "foo"},
&logical.Auth{ClientToken: "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"},
&logical.Auth{ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a"},
},
{
&logical.Request{
@ -98,7 +100,7 @@ func TestHash(t *testing.T) {
},
&logical.Request{
Data: map[string]interface{}{
"foo": "sha1:62cdb7020ff920e5aa642c3d4066950dd1f01f4d",
"foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
},
},
},
@ -110,7 +112,7 @@ func TestHash(t *testing.T) {
},
&logical.Response{
Data: map[string]interface{}{
"foo": "sha1:62cdb7020ff920e5aa642c3d4066950dd1f01f4d",
"foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
},
},
},
@ -133,14 +135,22 @@ func TestHash(t *testing.T) {
IssueTime: now,
},
ClientToken: "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33",
ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a",
},
},
}
localSalt, err := salt.NewSalt(nil, &salt.Config{
HMAC: sha256.New,
HMACType: "hmac-sha256",
StaticSalt: "foo",
})
if err != nil {
t.Fatalf("Error instantiating salt: %s", err)
}
for _, tc := range cases {
input := fmt.Sprintf("%#v", tc.Input)
if err := Hash(tc.Input); err != nil {
if err := Hash(localSalt, tc.Input); err != nil {
t.Fatalf("err: %s\n\n%s", err, input)
}
if !reflect.DeepEqual(tc.Input, tc.Output) {
@ -176,8 +186,8 @@ func TestHashWalker(t *testing.T) {
}
for _, tc := range cases {
output, err := HashStructure(tc.Input, func(string) (string, error) {
return replaceText, nil
output, err := HashStructure(tc.Input, func(string) string {
return replaceText
})
if err != nil {
t.Fatalf("err: %s\n\n%#v", err, tc.Input)
@ -187,14 +197,3 @@ func TestHashWalker(t *testing.T) {
}
}
}
func TestHashSHA1(t *testing.T) {
fn := HashSHA1("")
result, err := fn("foo")
if err != nil {
t.Fatalf("err: %s", err)
}
if result != "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33" {
t.Fatalf("bad: %#v", result)
}
}

View File

@ -13,7 +13,7 @@ import (
"github.com/mitchellh/copystructure"
)
func Factory(conf audit.BackendConfig) (audit.Backend, error) {
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
if conf.Salt == nil {
return nil, fmt.Errorf("Nil salt passed in")
}

View File

@ -12,7 +12,7 @@ import (
"github.com/mitchellh/copystructure"
)
func Factory(conf audit.BackendConfig) (audit.Backend, error) {
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
if conf.Salt == nil {
return nil, fmt.Errorf("Nil salt passed in")
}

View File

@ -27,6 +27,7 @@ type Salt struct {
salt string
generated bool
hmac hash.Hash
hmacType string
}
type HashFunc func([]byte) []byte
@ -44,6 +45,14 @@ type Config struct {
// HMAC allows specification of a hash function to use for
// the HMAC helpers
HMAC func() hash.Hash
// String prepended to HMAC strings for identification.
// Required if using HMAC
HMACType string
// A static string to use if set. If not set, one will be
// generated and persisted. This value will *not* be persisted.
StaticSalt string
}
// NewSalt creates a new salt based on the configuration
@ -64,35 +73,49 @@ func NewSalt(view logical.Storage, config *Config) (*Salt, error) {
config: config,
}
// Look for the salt
raw, err := view.Get(config.Location)
if err != nil {
return nil, fmt.Errorf("failed to read salt: %v", err)
}
var raw *logical.StorageEntry
var err error
if config.StaticSalt != "" {
s.salt = config.StaticSalt
} else {
if view != nil {
// Look for the salt
raw, err = view.Get(config.Location)
if err != nil {
return nil, fmt.Errorf("failed to read salt: %v", err)
}
// Restore the salt if it exists
if raw != nil {
s.salt = string(raw.Value)
// Restore the salt if it exists
if raw != nil {
s.salt = string(raw.Value)
}
}
}
// Generate a new salt if necessary
if s.salt == "" {
s.salt = uuid.GenerateUUID()
s.generated = true
raw = &logical.StorageEntry{
Key: config.Location,
Value: []byte(s.salt),
}
if err := view.Put(raw); err != nil {
return nil, fmt.Errorf("failed to persist salt: %v", err)
if view != nil {
raw = &logical.StorageEntry{
Key: config.Location,
Value: []byte(s.salt),
}
if err := view.Put(raw); err != nil {
return nil, fmt.Errorf("failed to persist salt: %v", err)
}
}
}
if config.HMAC != nil {
if len(config.HMACType) == 0 {
return nil, fmt.Errorf("HMACType must be defined")
}
s.hmac = hmac.New(config.HMAC, []byte(s.salt))
if s.hmac == nil {
return nil, fmt.Errorf("failed to instantiate HMAC function")
}
s.hmacType = config.HMACType
}
return s, nil
@ -104,7 +127,7 @@ func (s *Salt) SaltID(id string) string {
return SaltID(s.salt, id, s.config.HashFunc)
}
// SaltIDandHMAC is used to apply a salt and hash function to an ID to make sure
// GetHMAC is used to apply a salt and hash function to an ID to make sure
// it is not reversible, with an additional HMAC
func (s *Salt) GetHMAC(id string) string {
if s.hmac == nil {
@ -112,7 +135,19 @@ func (s *Salt) GetHMAC(id string) string {
}
s.hmac.Reset()
s.hmac.Write([]byte(id))
return string(s.hmac.Sum(nil))
return hex.EncodeToString(s.hmac.Sum(nil))
}
// GetIdentifiedHMAC is used to apply a salt and hash function to an ID to make sure
// it is not reversible, with an additional HMAC, and ID prepended
func (s *Salt) GetIdentifiedHMAC(id string) string {
if s.hmac == nil {
return ""
}
s.hmac.Reset()
s.hmac.Write([]byte(id))
return s.hmacType + ":" + hex.EncodeToString(s.hmac.Sum(nil))
}
// DidGenerate returns if the underlying salt value was generated

View File

@ -209,11 +209,12 @@ func (c *Core) newAuditBackend(t string, view logical.Storage, conf map[string]s
salter, err := salt.NewSalt(view, &salt.Config{
HashFunc: salt.SHA256Hash,
HMAC: sha256.New,
HMACType: "hmac-sha256",
})
if err != nil {
return nil, fmt.Errorf("[ERR] core: unable to generate salt: %v", err)
}
return f(audit.BackendConfig{
return f(&audit.BackendConfig{
Salt: salter,
Config: conf,
})

View File

@ -15,6 +15,7 @@ import (
)
type NoopAudit struct {
Config *audit.BackendConfig
ReqErr error
ReqAuth []*logical.Auth
Req []*logical.Request
@ -44,8 +45,10 @@ func (n *NoopAudit) LogResponse(a *logical.Auth, r *logical.Request, re *logical
func TestCore_EnableAudit(t *testing.T) {
c, key, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
return &NoopAudit{}, nil
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
me := &MountEntry{
@ -66,8 +69,10 @@ func TestCore_EnableAudit(t *testing.T) {
AuditBackends: make(map[string]audit.Factory),
DisableMlock: true,
}
conf.AuditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
return &NoopAudit{}, nil
conf.AuditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
c2, err := NewCore(conf)
if err != nil {
@ -94,8 +99,10 @@ func TestCore_EnableAudit(t *testing.T) {
func TestCore_DisableAudit(t *testing.T) {
c, key, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
return &NoopAudit{}, nil
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
err := c.disableAudit("foo")

View File

@ -841,7 +841,10 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) {
// Create a noop audit backend
noop := &NoopAudit{}
c, _, root := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
noop = &NoopAudit{
Config: config,
}
return noop, nil
}
@ -920,7 +923,10 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) {
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
return noopBack, nil
}
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
noop = &NoopAudit{
Config: config,
}
return noop, nil
}

View File

@ -521,8 +521,10 @@ func TestSystemBackend_policyCRUD(t *testing.T) {
func TestSystemBackend_enableAudit(t *testing.T) {
c, b, _ := testCoreSystemBackend(t)
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
return &NoopAudit{}, nil
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
req := logical.TestRequest(t, logical.WriteOperation, "audit/foo")
@ -552,8 +554,10 @@ func TestSystemBackend_enableAudit_invalid(t *testing.T) {
func TestSystemBackend_auditTable(t *testing.T) {
c, b, _ := testCoreSystemBackend(t)
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
return &NoopAudit{}, nil
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
req := logical.TestRequest(t, logical.WriteOperation, "audit/foo")
@ -586,8 +590,10 @@ func TestSystemBackend_auditTable(t *testing.T) {
func TestSystemBackend_disableAudit(t *testing.T) {
c, b, _ := testCoreSystemBackend(t)
c.auditBackends["noop"] = func(map[string]string) (audit.Backend, error) {
return &NoopAudit{}, nil
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
req := logical.TestRequest(t, logical.WriteOperation, "audit/foo")

View File

@ -57,8 +57,10 @@ oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
// TestCore returns a pure in-memory, uninitialized core for testing.
func TestCore(t *testing.T) *Core {
noopAudits := map[string]audit.Factory{
"noop": func(audit.BackendConfig) (audit.Backend, error) {
return new(noopAudit), nil
"noop": func(config *audit.BackendConfig) (audit.Backend, error) {
return &noopAudit{
Config: config,
}, nil
},
}
noopBackends := make(map[string]logical.Factory)
@ -240,7 +242,9 @@ func AddTestLogicalBackend(name string, factory logical.Factory) error {
return nil
}
type noopAudit struct{}
type noopAudit struct {
Config *audit.BackendConfig
}
func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error {
return nil